Point Clouds#
In this tutorial, we cover how to solve OT problems between two point clouds by instantiating a PointCloud geometry.
import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from IPython import display
import ott
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
Create a PointCloud#
def create_points(rng: jax.random.PRNGKeyArray, n: int, m: int, d: int):
    # Sample two point clouds: x from a shifted Gaussian, y from a uniform.
    rngs = jax.random.split(rng, 3)
    x = jax.random.normal(rngs[0], (n, d)) + 1
    y = jax.random.uniform(rngs[1], (m, d))
    return x, y
rng = jax.random.PRNGKey(0)
n, m, d = 13, 17, 2
x, y = create_points(rng, n=n, m=m, d=d)
geom = pointcloud.PointCloud(x, y)
Compute the regularized optimal transport#
To compute the transport matrix between the two point clouds, one first defines a PointCloud geometry.
A PointCloud geometry holds two arrays of vectors (supporting the two measures of interest), along with a cost function (a CostFn object, set by default to SqEuclidean) and, possibly, an Epsilon regularization parameter.
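For instance, the same kind of geometry can be instantiated with an explicit cost function and a custom regularization strength; the epsilon value below is purely illustrative, not tuned.
# A sketch: same point clouds, explicit cost and an illustrative epsilon.
geom_custom = pointcloud.PointCloud(
    x, y, cost_fn=costs.SqEuclidean(), epsilon=1e-2
)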
This geometry object defines a LinearProblem object, which contains all the data needed to instantiate a linear OT problem (see the Gromov-Wasserstein tutorial for quadratic OT problems).
We can then call a Sinkhorn solver to solve that problem and compute the OT between these point clouds. Note that all weights are assumed to be uniform in this notebook, but non-uniform weights can be passed as a=..., b=... arguments when defining the LinearProblem below.
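As an illustration, arbitrary non-uniform weights (not used in the rest of this notebook) could be passed as follows:
# Illustrative non-uniform weights; each histogram must sum to 1.
a = jnp.arange(1.0, n + 1)
a = a / jnp.sum(a)
b = jnp.ones(m) / m
prob_weighted = linear_problem.LinearProblem(geom, a=a, b=b)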
# Define a linear problem with that cost structure.
ot_prob = linear_problem.LinearProblem(geom)
# Create a Sinkhorn solver
solver = sinkhorn.Sinkhorn()
# Solve OT problem
ot = solver(ot_prob)
# The ot object contains many things, among which the regularized OT cost
print(
    " Sinkhorn has converged: ",
    ot.converged,
    "\n",
    "Error upon last iteration: ",
    ot.errors[(ot.errors > -1)][-1],
    "\n",
    "Sinkhorn required ",
    jnp.sum(ot.errors > -1),
    " iterations to converge. \n",
    "Entropy regularized OT cost: ",
    ot.reg_ot_cost,
    "\n",
    "OT cost (without entropy): ",
    jnp.sum(ot.matrix * ot.geom.cost_matrix),
)
Sinkhorn has converged: True
Error upon last iteration: 0.00068787485
Sinkhorn required 6 iterations to converge.
Entropy regularized OT cost: 1.4432423
OT cost (without entropy): 1.2848572
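If needed, the solver's convergence threshold and iteration budget can be adjusted when instantiating Sinkhorn; the values below are only illustrative.
# A sketch: a tighter threshold and a larger iteration budget (illustrative values).
tight_solver = sinkhorn.Sinkhorn(threshold=1e-5, max_iterations=5000)
ot_tight = tight_solver(ot_prob)
print("Converged:", ot_tight.converged, " cost:", ot_tight.reg_ot_cost)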
The ot object contains several functions and properties, notably a simple way to instantiate, if needed, the OT matrix.
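It also exposes, for instance, the dual potentials computed by Sinkhorn; a quick sketch:
# Dual potentials returned by Sinkhorn, one value per point in x and in y.
f, g = ot.f, ot.g
print(f.shape, g.shape)  # (n,) and (m,) respectively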
# you can instantiate the OT matrix
P = ot.matrix
plt.imshow(P, cmap="Purples")
plt.colorbar();

You can also instantiate a Plot object to help visualize the transport in two dimensions.
plott = ott.tools.plot.Plot()
_ = plott(ot)

OT gradient flows#
OTT returns quantities that are differentiable. In the following example, we leverage the gradients to move \(n\) points in a way that minimizes the overall regularized OT cost, given a ground cost function.
We start by defining a minimal optimization loop that does fixed step-size gradient descent and records various ot objects along the way for plotting. By choosing various cost functions, we can then plot different types of gradient flows for the point cloud in \(x\). See also the Sinkhorn divergence gradient flows tutorial.
def optimize(
    x: jnp.ndarray,
    y: jnp.ndarray,
    num_iter: int = 300,
    dump_every: int = 5,
    learning_rate: float = 0.2,
    **kwargs,  # passed to the pointcloud.PointCloud geometry
):
    # Wrapper function that returns OT cost and OT output given a geometry.
    def reg_ot_cost(geom):
        out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
        return out.reg_ot_cost, out

    # Apply the jax.value_and_grad operator. Note that we make explicit that
    # we only wish to compute gradients w.r.t. the first output,
    # using the has_aux flag. We also jit that function.
    reg_ot_cost_vg = jax.jit(jax.value_and_grad(reg_ot_cost, has_aux=True))

    # Run a naive, fixed step-size, gradient descent on locations `x`.
    ots = []
    for i in range(0, num_iter + 1):
        geom = pointcloud.PointCloud(x, y, **kwargs)
        (reg_ot_cost, ot), geom_g = reg_ot_cost_vg(geom)
        assert ot.converged
        x = x - geom_g.x * learning_rate
        if i % dump_every == 0:
            ots.append(ot)
    return ots
# Helper function to plot successively the optimal transports
def plot_ots(ots):
    fig = plt.figure(figsize=(8, 5))
    plott = ott.tools.plot.Plot(fig=fig)
    anim = plott.animate(ots, frame_rate=4)
    html = display.HTML(anim.to_jshtml())
    display.display(html)
    plt.close()
\(W_2^2\) Gradient Flow#
plot_ots(
    optimize(
        x,
        y,
        num_iter=100,
        epsilon=1e-2,
        cost_fn=costs.SqEuclidean(),
    )
)
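The same loop can be rerun with a different ground cost to obtain a different flow; as an illustration, the sketch below uses a cosine cost, keeping the (illustrative) settings from above. Convergence behavior will depend on the chosen epsilon.
# A sketch: gradient flow under a cosine ground cost (illustrative settings).
plot_ots(
    optimize(
        x,
        y,
        num_iter=100,
        epsilon=1e-2,
        cost_fn=costs.Cosine(),
    )
)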