Unbalanced Optimal Transport

This tutorial shows how to use ott to solve an unbalanced optimal transport (OT) problem with the Sinkhorn algorithm. The unbalanced OT problem with entropic regularization is defined as:

\[ \arg \min_{P>0} \langle P,C \rangle + \varepsilon \text{KL}(P|ab^T) + \rho_a \text{KL}(P\textbf{1}_m|a) + \rho_b \text{KL}(P^T\textbf{1}_n|b) \]

where \(\rho_a\) and \(\rho_b\) penalize deviations from the marginal constraints and \(\varepsilon\) controls the strength of the entropic regularization. The unbalanced OT problem is solved with the Sinkhorn algorithm, whose steps are detailed in [Frogner et al., 2015]. Instead of \(\rho_a\) and \(\rho_b\), ott’s solver uses the parameters \(\tau_a = \rho_a /(\varepsilon+ \rho_a)\) and \(\tau_b = \rho_b /(\varepsilon+ \rho_b)\). Setting one of these parameters to \(1\) corresponds to letting the matching \(\rho_a\) or \(\rho_b\) tend to \(\infty\), which enforces that marginal constraint exactly; setting both to \(1\) recovers the balanced LinearProblem.
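The change of variables between \(\rho\) and \(\tau\) can be checked numerically. The snippet below is a minimal sketch in plain Python; the helper names rho_to_tau and tau_to_rho are illustrative and are not part of ott.

```python
import math


def rho_to_tau(rho: float, epsilon: float) -> float:
    """Map a marginal penalty rho to tau = rho / (epsilon + rho)."""
    if math.isinf(rho):
        return 1.0  # infinite penalty: the marginal constraint is enforced exactly
    return rho / (epsilon + rho)


def tau_to_rho(tau: float, epsilon: float) -> float:
    """Invert the map: rho = epsilon * tau / (1 - tau), for 0 < tau <= 1."""
    if tau == 1.0:
        return math.inf
    return epsilon * tau / (1.0 - tau)


epsilon = 1e-3
# tau grows towards 1 as the penalty rho grows towards infinity
for rho in (1e-3, 1e-1, 10.0, math.inf):
    print(f"rho={rho:g} -> tau={rho_to_tau(rho, epsilon):.6f}")
```

For a fixed \(\varepsilon\), \(\tau\) increases monotonically with \(\rho\) and reaches \(1\) only in the limit \(\rho \to \infty\). For example, with \(\varepsilon = 10^{-3}\), a value \(\tau = 0.999\) corresponds to \(\rho \approx 0.999\).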

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot
from ott.tools.gaussian_mixture import gaussian_mixture

Generate source and target distributions

Let us first generate source and target distributions that correspond to the illustrative example of [Séjourné et al., 2022] (Figure 4).

def generate_data(
    rng: jax.Array, *, means: jnp.ndarray, cov: jnp.ndarray, n_samples: int
) -> jnp.ndarray:
    """Sample n_samples points from an equally weighted Gaussian mixture."""
    gmm = gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
        mean=means,
        cov=cov,
        component_weights=jnp.ones(len(means)) / len(means),
    )
    return gmm.sample(rng, n_samples)


# independent random keys for the source and target samples
rng, rng_source, rng_target = jax.random.split(jax.random.key(0), 3)

means_source = jnp.array(
    [
        [-1.6, -1.6],
        [-1, 0.25],
        [0.25, -1],
        [0.25, 1],
        [1, 0.25],
    ]
)
cov_source = jnp.array([0.01 * jnp.identity(2)] * 5)
x = generate_data(rng_source, means=means_source, cov=cov_source, n_samples=40)

means_target = jnp.array(
    [
        [1.6, 1.6],
        [-1, -0.25],
        [-0.25, -1],
        [-0.25, 1],
        [1, -0.25],
    ]
)
cov_target = jnp.array([0.01 * jnp.identity(2)] * 5)
y = generate_data(rng_target, means=means_target, cov=cov_target, n_samples=42)

Visualize the data

plt.scatter(
    x[:, 0],
    x[:, 1],
    s=200,
    edgecolors="k",
    marker="o",
    label="Source samples",
)
plt.scatter(
    y[:, 0],
    y[:, 1],
    s=200,
    edgecolors="k",
    marker="X",
    label="Target samples",
)
plt.legend()
plt.show()

The source and target distributions are both mixtures of five Gaussians. The four central modes of the source can be naturally matched to the four central modes of the target, while the fifth mode of each distribution is far from the others and can be seen as a cluster of outliers.

Let us first visualize the result of the balanced OT in this setting.

Balanced OT mapping

We first define the geometry of the problem and the Sinkhorn solver.

geom = pointcloud.PointCloud(x, y, epsilon=1e-3)
solver = sinkhorn.Sinkhorn()

Solving the balanced OT problem is equivalent to fixing tau_a = tau_b = 1.0 in the LinearProblem.

# define a balanced linear problem associated with the geometry defined above
ot_prob = linear_problem.LinearProblem(geom, tau_a=1.0, tau_b=1.0)
# solve the OT problem
ot = solver(ot_prob)
# plot the computed transport plan
plott = plot.Plot(threshold=1e-2)
_ = plott(ot)

In the solution of the balanced OT problem, the outliers interact with the other points: because both marginals must be matched exactly, the mass of each outlier cluster has to be transported somewhere. This is unsuitable in applications where only nearby clusters should interact. In the following, we will see that a mapping with this property can be computed by solving the unbalanced OT problem.

Unbalanced OT mapping

We will now define an unbalanced OT problem with tau_a = tau_b = 0.999, which relaxes both the source and the target marginal constraints.

# define an unbalanced linear problem
ot_prob = linear_problem.LinearProblem(geom, tau_a=0.999, tau_b=0.999)
ot = solver(ot_prob)
plott = plot.Plot(threshold=1e-2)
_ = plott(ot)

This is a more natural solution in this setting: only clusters that are close to each other interact.
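The role of tau_a and tau_b can also be illustrated without ott on a toy problem. The sketch below is not ott's implementation: it uses the standard dense scaling form of the Sinkhorn iterations for KL-penalized marginals, with a fixed number of iterations and no log-domain stabilization, so it needs a fairly large epsilon. The function name sinkhorn_scaling and the three-point 1-D data are assumptions made purely for illustration.

```python
import numpy as np


def sinkhorn_scaling(C, a, b, epsilon, tau_a=1.0, tau_b=1.0, n_iters=1000):
    """Dense scaling iterations for entropic OT with relaxed marginals.

    tau_a = tau_b = 1 gives the usual balanced Sinkhorn updates.
    """
    # Gibbs kernel taken relative to the reference measure a b^T
    K = a[:, None] * np.exp(-C / epsilon) * b[None, :]
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        # exponents below 1 damp the updates, relaxing each marginal
        u = (a / (K @ v)) ** tau_a
        v = (b / (K.T @ u)) ** tau_b
    return u[:, None] * K * v[None, :]


# Toy 1-D data: two nearby pairs plus one distant outlier on each side.
x = np.array([-3.0, 0.0, 1.0])  # x[0] is the source outlier
y = np.array([0.1, 1.1, 4.0])  # y[2] is the target outlier
C = (x[:, None] - y[None, :]) ** 2  # squared Euclidean cost
a = np.full(3, 1 / 3)
b = np.full(3, 1 / 3)

P_bal = sinkhorn_scaling(C, a, b, epsilon=1.0)
P_unb = sinkhorn_scaling(C, a, b, epsilon=1.0, tau_a=0.5, tau_b=0.5)

print("balanced row masses:  ", P_bal.sum(axis=1))
print("unbalanced row masses:", P_unb.sum(axis=1))
print("unbalanced total mass:", P_unb.sum())
```

In the balanced plan every row carries its full mass of 1/3, so the source outlier is forced to send mass to a central target point. With tau_a = tau_b = 0.5 the outlier row and column keep only a small fraction of their mass, and the total transported mass drops below 1.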