Sinkhorn divergence gradient flows#
Let \(\mathrm{OT_\varepsilon}(\alpha, \beta)\) denote the entropy-regularized OT cost between two distributions \(\alpha\) and \(\beta\). One issue with \(\mathrm{OT_\varepsilon}\) is that \(\mathrm{OT_\varepsilon}(\alpha, \alpha)\) is not equal to 0: the regularization introduces an entropic bias (we check this numerically below).
The Sinkhorn divergence, defined in [Genevay et al., 2018] as \(\mathrm{S}_\varepsilon(\alpha, \beta) = \mathrm{OT_\varepsilon}(\alpha, \beta) - \frac{1}{2}\mathrm{OT_\varepsilon}(\alpha, \alpha) - \frac{1}{2}\mathrm{OT_\varepsilon}(\beta, \beta)\) removes this entropic bias.
In this tutorial we showcase the advantage of removing the entropic bias using gradient flows on 2-D distributions, as done in [Feydy et al., 2019] and following the Point Clouds tutorial.
Imports#
from typing import Any, Callable
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from IPython import display
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence
Sampling Source/Target Distributions#
Let us start by defining simple source and target distributions.
key1, key2 = jax.random.split(jax.random.key(0), 2)
x = 0.25 * jax.random.normal(key1, (25, 2)) # Source
y = 0.5 * jax.random.normal(key2, (50, 2)) + jnp.array((6, 0)) # Target
plt.scatter(x[:, 0], x[:, 1], edgecolors="k", marker="o", label="x", s=200)
plt.scatter(y[:, 0], y[:, 1], edgecolors="k", marker="X", label="y", s=200)
plt.legend(fontsize=15)
plt.show()
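As a quick, optional check of the entropic bias discussed in the introduction, we can evaluate \(\mathrm{OT_\varepsilon}(x, x)\) on the source samples and assemble \(\mathrm{S}_\varepsilon(x, y)\) directly from its definition. The snippet below is only a sketch: the helper reg_ot and the choice \(\varepsilon = 1.0\) are ours for illustration, not part of the flow itself.
def reg_ot(a, b, epsilon=1.0):
    # Entropy-regularized OT cost between two point clouds (uniform weights).
    geom = pointcloud.PointCloud(a, b, epsilon=epsilon)
    return sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom)).reg_ot_cost

print("OT_eps(x, x):", reg_ot(x, x))  # strictly positive: the entropic bias
# Sinkhorn divergence between x and y, assembled from its definition.
print("S_eps(x, y): ", reg_ot(x, y) - 0.5 * reg_ot(x, x) - 0.5 * reg_ot(y, y))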
Gradient Flow of a Divergence#
The code below performs gradient descent to move the points of a point cloud x in a way that minimizes a divergence to another point cloud y, divergence(x, y, epsilon).
def gradient_flow(
    x: jnp.ndarray,
    y: jnp.ndarray,
    divergence: Callable[[jnp.ndarray, jnp.ndarray, float], tuple[float, Any]],
    num_iter: int = 500,
    lr: float = 0.2,
    dump_every: int = 50,
    epsilon: float = None,
):
    """Compute an entropic Wasserstein (possibly debiased) gradient flow."""
    ots = []
    # Apply the jax.value_and_grad operator and jit that function.
    divergence_vg = jax.jit(jax.value_and_grad(divergence, has_aux=True))

    # Perform gradient descent on `x`.
    for i in range(num_iter + 1):
        (cost, ot), grad_x = divergence_vg(x, y, epsilon)
        assert ot.converged
        x = x - grad_x * lr  # Perform a gradient descent step.
        if i % dump_every == 0:
            ots.append(ot)  # Save the current state of the optimization.
    return ots
def display_animation(ots, plot_class=plot.Plot):
    """Display an animation of the gradient flow."""
    plott = plot_class(show_lines=False)
    anim = plott.animate(ots, frame_rate=4)
    html = display.HTML(anim.to_jshtml())
    display.display(html)
    plt.close()
Gradient Flow of \(\mathrm{OT}_\varepsilon\)#
We set the divergence to be the regularized OT cost.
def reg_ot_cost(x, y, epsilon=None):
    """Return the regularized OT cost and the Sinkhorn output for two point clouds."""
    geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
    ot = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
    return ot.reg_ot_cost, ot
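Before launching the flow, reg_ot_cost can be evaluated once on the samples as an optional sanity check, simply inspecting the returned cost and convergence flag:
cost, out = reg_ot_cost(x, y)
print(cost, out.converged)  # finite cost, converged == True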
For the default value of \(\varepsilon\), the gradient flow behaves as expected:
# Compute and display the gradient flow for the regularized OT cost.
ots = gradient_flow(x, y, reg_ot_cost)
display_animation(ots)
But for a larger \(\varepsilon\), the entropic bias becomes significant and the source distribution collapses:
# Compute and display the gradient flow for a larger value of epsilon.
ots = gradient_flow(x, y, reg_ot_cost, epsilon=1.0)
display_animation(ots)
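A small illustrative check (ours, not part of the original tutorial) makes the source of the collapse visible: even if the source cloud already coincided with the target, the gradient of \(\mathrm{OT_\varepsilon}\) with respect to its first argument would generally remain non-zero for large \(\varepsilon\), so the target is not a stationary point of the biased flow; subtracting the self-terms in \(\mathrm{S}_\varepsilon\) removes this residual force.
# Gradient of OT_eps(y, y) w.r.t. the first cloud, for the large epsilon above.
(cost_yy, _), grad_yy = jax.value_and_grad(reg_ot_cost, has_aux=True)(y, y, 1.0)
print(cost_yy, jnp.abs(grad_yy).max())  # non-zero gradient: the entropic bias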