Unbalanced Optimal Transport
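This tutorial shows how to use ott to compute the solution of an unbalanced optimal transport (OT) problem with the Sinkhorn algorithm. The unbalanced OT problem with entropic regularization is defined as:
\[
\min_{P \in \mathbb{R}_+^{n \times m}} \; \langle C, P \rangle + \rho_a \mathrm{KL}(P \mathbf{1}_m \,|\, a) + \rho_b \mathrm{KL}(P^\top \mathbf{1}_n \,|\, b) + \varepsilon \, \mathrm{KL}(P \,|\, a b^\top),
\]
where \(C\) is the \(n \times m\) cost matrix, \(a\) and \(b\) are the source and target weights, and \(\mathrm{KL}\) is the (generalized) Kullback-Leibler divergence. \(\rho_a\) and \(\rho_b\) weight the penalties on deviations from the source and target marginal constraints, and \(\varepsilon\) controls the entropic regularization.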
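The following minimal NumPy sketch evaluates the objective for a given plan \(P\). It assumes the formulation \(\langle C, P\rangle + \rho_a \mathrm{KL}(P\mathbf{1}_m|a) + \rho_b \mathrm{KL}(P^\top\mathbf{1}_n|b) + \varepsilon \mathrm{KL}(P|ab^\top)\) with the generalized divergence \(\mathrm{KL}(p|q) = \sum_i p_i \log(p_i/q_i) - p_i + q_i\). The helpers kl and unbalanced_objective are hypothetical illustrations, not part of ott.

```python
import numpy as np


def kl(p, q):
    """Generalized KL divergence sum(p * log(p / q) - p + q), with 0 * log(0) = 0."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    pos = p > 0  # entries with p == 0 only contribute +q
    return np.sum(p[pos] * np.log(p[pos] / q[pos])) - p.sum() + q.sum()


def unbalanced_objective(P, C, a, b, eps, rho_a, rho_b):
    """Entropic unbalanced OT objective (hypothetical helper, not an ott API)."""
    transport_cost = np.sum(C * P)
    source_penalty = rho_a * kl(P.sum(axis=1), a)  # P 1_m compared with a
    target_penalty = rho_b * kl(P.sum(axis=0), b)  # P^T 1_n compared with b
    entropic_term = eps * kl(P, np.outer(a, b))
    return transport_cost + source_penalty + target_penalty + entropic_term


# Toy example: 2 source points, 3 target points, uniform weights.
a = np.full(2, 1 / 2)
b = np.full(3, 1 / 3)
C = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0]])

# The independent coupling a b^T meets both marginals, so only <C, P> = 7/6 remains.
print(unbalanced_objective(np.outer(a, b), C, a, b, eps=0.1, rho_a=1.0, rho_b=1.0))
# The empty plan pays the full marginal penalties: rho_a + rho_b + eps = 2.1.
print(unbalanced_objective(np.zeros((2, 3)), C, a, b, eps=0.1, rho_a=1.0, rho_b=1.0))
```

The two extreme plans show the trade-off. The independent coupling pays only the transport cost. The empty plan pays nothing for transport but is charged for every unit of mass it fails to match.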
The unbalanced OT problem is solved with the Sinkhorn algorithm, whose steps are detailed in [Frogner et al., 2015].
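For intuition, here is a minimal log-domain NumPy sketch of generalized Sinkhorn iterations. It targets the KL-penalized formulation, with marginal penalties weighted by \(\rho_a\) and \(\rho_b\) and the entropic term \(\varepsilon\,\mathrm{KL}(P|ab^\top)\). It is not ott's implementation. Each dual update is the balanced Sinkhorn update multiplied by the factor \(\rho/(\varepsilon+\rho)\). With \(\rho = \infty\) that factor is 1, and the usual balanced update is recovered. The function name unbalanced_sinkhorn and the fixed iteration count are illustrative assumptions.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def unbalanced_sinkhorn(a, b, C, eps, rho_a=np.inf, rho_b=np.inf, n_iters=1000):
    """Log-domain generalized Sinkhorn (illustrative sketch, not ott's solver).

    Returns the plan P with P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps).
    """
    # Factor rho / (eps + rho); it equals 1 for rho = inf (exact marginal).
    tau_a = 1.0 if np.isinf(rho_a) else rho_a / (eps + rho_a)
    tau_b = 1.0 if np.isinf(rho_b) else rho_b / (eps + rho_b)
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iters):
        # Balanced update of the source potential f, scaled by tau_a.
        f = -tau_a * eps * logsumexp((g[None, :] - C) / eps + log_b[None, :], axis=1)
        # Balanced update of the target potential g, scaled by tau_b.
        g = -tau_b * eps * logsumexp((f[:, None] - C) / eps + log_a[:, None], axis=0)
    log_P = (f[:, None] + g[None, :] - C) / eps + log_a[:, None] + log_b[None, :]
    return np.exp(log_P)


# One source point and one target point, with cost c = 4 between them.
a, b, C = np.ones(1), np.ones(1), np.array([[4.0]])
print(unbalanced_sinkhorn(a, b, C, eps=0.1))  # balanced: all the mass is moved (P = 1)
print(unbalanced_sinkhorn(a, b, C, eps=0.1, rho_a=1.0, rho_b=1.0))  # about 0.149
```

The single-pair case has a closed form. Minimizing \(cp + (2\rho + \varepsilon)(p \log p - p + 1)\) over \(p \ge 0\) gives \(p = \exp(-c/(\varepsilon + 2\rho))\), which the iterations reproduce. The farther apart the points, the less mass the unbalanced plan moves between them.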
Instead of \(\rho_a\) and \(\rho_b\), ott's solver uses the parameters \(\tau_a = \rho_a /(\varepsilon+ \rho_a)\) and \(\tau_b = \rho_b /(\varepsilon+ \rho_b)\). Setting \(\tau_a = 1\) (resp. \(\tau_b = 1\)) corresponds to \(\rho_a = \infty\) (resp. \(\rho_b = \infty\)), i.e., enforcing the source (resp. target) marginal constraint exactly. Setting both to \(1\) recovers the balanced LinearProblem.
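The map between the two parameterizations is easy to invert: \(\rho = \varepsilon \tau / (1 - \tau)\). The plain-Python sketch below converts in both directions. The helper names are hypothetical, not ott functions. The sketch also assumes that the epsilon passed to the geometry is used as an absolute value, not rescaled by the cost.

```python
import math


def tau_from_rho(rho: float, epsilon: float) -> float:
    """tau = rho / (epsilon + rho); rho = inf (exact marginal) gives tau = 1."""
    if math.isinf(rho):
        return 1.0
    return rho / (epsilon + rho)


def rho_from_tau(tau: float, epsilon: float) -> float:
    """Inverse map: rho = epsilon * tau / (1 - tau); tau = 1 gives rho = inf."""
    if tau == 1.0:
        return math.inf
    return epsilon * tau / (1.0 - tau)


# Values used later in this tutorial: epsilon = 1e-3 and tau_a = tau_b = 0.999.
print(rho_from_tau(0.999, 1e-3))  # about 0.999
print(tau_from_rho(math.inf, 1e-3))  # 1.0, the balanced case
```

Under that assumption, \(\tau = 0.999\) with \(\varepsilon = 10^{-3}\) corresponds to a finite \(\rho \approx 0.999\). The marginal constraints are therefore penalized rather than strictly enforced.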
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot
from ott.tools.gaussian_mixture import gaussian_mixture
Generate source and target distributions
Let us first generate source and target distributions that correspond to the illustrative example of [Séjourné et al., 2022] (Figure 4).
def generate_data(
    rng: jax.Array, *, means: jnp.ndarray, cov: jnp.ndarray, n_samples: int
) -> jnp.ndarray:
    """Sample n_samples points from an equally weighted Gaussian mixture."""
    gmm = gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
        mean=means,
        cov=cov,
        component_weights=jnp.ones(len(means)) / len(means),
    )
    return gmm.sample(rng, n_samples)
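Sampling from an equally weighted mixture amounts to two-stage (ancestral) sampling. First draw a component index uniformly, then draw from that component's Gaussian. The plain NumPy sketch below illustrates this idea. It is an assumption about what sampling amounts to, not ott's implementation, and it will not reproduce the random numbers of gmm.sample. The name sample_gmm_numpy is hypothetical.

```python
import numpy as np


def sample_gmm_numpy(rng, means, covs, n_samples):
    """Ancestral sampling from an equally weighted Gaussian mixture (sketch)."""
    n_components, dim = means.shape
    # Stage 1: pick a component for each sample with equal probability.
    components = rng.integers(n_components, size=n_samples)
    # Stage 2: x = mean_k + L_k z, with L_k the Cholesky factor of cov_k and z ~ N(0, I).
    chol = np.linalg.cholesky(covs)  # batched, shape (n_components, dim, dim)
    z = rng.standard_normal((n_samples, dim))
    return means[components] + np.einsum("nij,nj->ni", chol[components], z)


means = np.array([[-1.6, -1.6], [-1, 0.25], [0.25, -1], [0.25, 1], [1, 0.25]])
covs = np.array([0.01 * np.identity(2)] * 5)
samples = sample_gmm_numpy(np.random.default_rng(0), means, covs, n_samples=40)
print(samples.shape)  # (40, 2)
```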
rng, rng_source, rng_target = jax.random.split(jax.random.key(0), 3)
means_source = jnp.array(
[
[-1.6, -1.6],
[-1, 0.25],
[0.25, -1],
[0.25, 1],
[1, 0.25],
]
)
cov_source = jnp.array([0.01 * jnp.identity(2)] * 5)
x = generate_data(rng_source, means=means_source, cov=cov_source, n_samples=40)
means_target = jnp.array(
[
[1.6, 1.6],
[-1, -0.25],
[-0.25, -1],
[-0.25, 1],
[1, -0.25],
]
)
cov_target = jnp.array([0.01 * jnp.identity(2)] * 5)
y = generate_data(rng_target, means=means_target, cov=cov_target, n_samples=42)
Visualize the data
plt.scatter(
x[:, 0],
x[:, 1],
s=200,
edgecolors="k",
marker="o",
label="Source samples",
)
plt.scatter(
y[:, 0],
y[:, 1],
s=200,
edgecolors="k",
marker="X",
label="Target samples",
)
plt.legend()
plt.show()
Both the source and target distributions are mixtures of five Gaussians. The four central modes of the source can be naturally matched to the four central modes of the target. In each distribution, the fifth mode lies far from the others and can be seen as an outlier.
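The claim that the fifth modes are outliers can be checked on the mode centers defined above. The NumPy sketch below computes, for each mode center, the distance to the nearest mode center of the other distribution. The variable names mirror the tutorial's but are otherwise illustrative.

```python
import numpy as np

means_source = np.array([[-1.6, -1.6], [-1, 0.25], [0.25, -1], [0.25, 1], [1, 0.25]])
means_target = np.array([[1.6, 1.6], [-1, -0.25], [-0.25, -1], [-0.25, 1], [1, -0.25]])

# Pairwise Euclidean distances between source and target mode centers, shape (5, 5).
dist = np.linalg.norm(means_source[:, None, :] - means_target[None, :, :], axis=-1)
nearest_target = dist.min(axis=1)  # for each source mode, distance to the closest target mode
nearest_source = dist.min(axis=0)  # for each target mode, distance to the closest source mode
print(np.round(nearest_target, 3))  # about [1.477, 0.5, 0.5, 0.5, 0.5]
print(np.round(nearest_source, 3))  # about [1.477, 0.5, 0.5, 0.5, 0.5]
```

Each outlier mode is roughly three times farther from its closest counterpart than any central mode is from its own.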
Let us first visualize the result of balanced OT in this setting.
Balanced OT mapping
We first define the geometry of the problem and the Sinkhorn solver.
geom = pointcloud.PointCloud(x, y, epsilon=1e-3)
solver = sinkhorn.Sinkhorn()
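A PointCloud geometry defines the cost matrix \(C\) by applying a cost function to pairs of points. The NumPy sketch below assumes the squared Euclidean cost, which is our understanding of ott's default cost_fn for PointCloud. The helper sq_euclidean_cost is hypothetical. The sketch also shows why a small \(\varepsilon\) calls for log-domain computations: the kernel \(\exp(-C/\varepsilon)\) underflows in float64. By default, ott's Sinkhorn solver works in the log domain (its lse_mode option).

```python
import numpy as np


def sq_euclidean_cost(x, y):
    """C[i, j] = ||x_i - y_j||^2 for point clouds x of shape (n, d) and y of shape (m, d)."""
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff**2, axis=-1)


x_toy = np.array([[0.0, 0.0], [1.0, 0.0]])
y_toy = np.array([[0.0, 1.0], [3.0, 0.0]])
C = sq_euclidean_cost(x_toy, y_toy)
print(C)  # rows [1, 9] and [2, 4]

# With epsilon = 1e-3, every entry of exp(-C / epsilon) underflows to 0.0 here,
# so updates based directly on this kernel would break down.
print(np.exp(-C / 1e-3))
```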
Solving the balanced OT problem amounts to setting tau_a = tau_b = 1.0 in the LinearProblem.
# define a balanced linear problem associated with the geometry defined above
ot_prob = linear_problem.LinearProblem(geom, tau_a=1.0, tau_b=1.0)
# solve the OT problem
ot = solver(ot_prob)
# plot the computed transport plan
plott = plot.Plot(threshold=1e-2)
_ = plott(ot)
In the solution of the balanced OT problem, the outliers interact with the other points. Because both marginal constraints must hold exactly, the mass of each outlier mode has to be transported somewhere. This is undesirable in applications where only clusters that are close to each other should interact. In the following, we will see that such a mapping can be computed by solving an unbalanced OT problem.
Unbalanced OT mapping
We now define an unbalanced OT problem with tau_a = tau_b = 0.999 to relax both the source and the target marginal constraints.
# define an unbalanced linear problem
ot_prob = linear_problem.LinearProblem(geom, tau_a=0.999, tau_b=0.999)
ot = solver(ot_prob)
plott = plot.Plot(threshold=1e-2)
_ = plott(ot)
This solution is more natural for this example: only clusters that are close to each other interact.