Getting Started
This short tutorial covers a basic use case for ott:

Compute an optimal transport distance between two point clouds using the PointCloud geometry, solved using the Sinkhorn algorithm.

Showcase the seamless integration with JAX, to differentiate through that cost and plot the gradient flow that morphs the first point cloud into the second.
Imports and toy data definition
import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn
ott is built on top of JAX, so we use JAX to instantiate all variables. We generate two 2-dimensional random point clouds of \(7\) and \(11\) points, respectively, and store them in variables x and y:
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5
x_old = x  # keep a copy of the source points for the gradient-flow demo below
Because these point clouds are 2-dimensional, we can use scatter plots to illustrate them.
x_args = {
    "s": 80,
    "label": r"source $x$",
    "marker": "s",
    "edgecolor": "k",
    "alpha": 0.75,
}
y_args = {"s": 80, "label": r"target $y$", "edgecolor": "k", "alpha": 0.75}
plt.figure(figsize=(9, 6))
plt.scatter(x[:, 0], x[:, 1], **x_args)
plt.scatter(y[:, 0], y[:, 1], **y_args)
plt.legend()
plt.show()

Optimal transport with ott
We will now use ott to compute the optimal transport between x and y. To do so, we first create a geom object that stores the geometry (a.k.a. the ground cost) between x and y:
geom = pointcloud.PointCloud(x, y)
geom contains the two datasets x and y, as well as a cost_fn that measures distances between pairs of points. Here, we use the default settings, so the cost_fn is SqEuclidean, the usual squared-Euclidean distance.
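As an aside, a different ground cost can be passed explicitly when building the geometry; a minimal sketch, assuming the (non-squared) Euclidean cost class available in ott.geometry.costs:
from ott.geometry import costs

# sketch: use the non-squared Euclidean distance as ground cost
geom_euclidean = pointcloud.PointCloud(x, y, cost_fn=costs.Euclidean())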
In order to compute the optimal transport corresponding to geom, we use the Sinkhorn algorithm. The Sinkhorn algorithm has a regularization hyperparameter epsilon. ott stores that parameter in geom, and defaults to one twentieth of the mean cost between all points in x and y. While it is also possible to set probability weights a for each point in x (and b for y), these are uniform by default.
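If needed, epsilon can also be set explicitly when the geometry is instantiated; a minimal sketch, with a value chosen purely for illustration:
# sketch: override the default epsilon with an explicit value
geom_small_eps = pointcloud.PointCloud(x, y, epsilon=1e-2)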
solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom, a=None, b=None)
As a small note: the computations here are jitted, meaning that the second time the solver is run it will be much faster:
ot = solve_fn(geom)
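One way to see this is a rough wall-clock comparison; a sketch using Python's standard time module (exact numbers depend on hardware):
import time

# sketch: the compiled solver is cached, so this run skips compilation
start = time.perf_counter()
solve_fn(geom).reg_ot_cost.block_until_ready()
print(f"cached run: {time.perf_counter() - start:.4f}s")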
The output object ot contains the solution of the optimal transport problem. This includes the optimal coupling matrix, whose entry [i, j] indicates how much of the mass of point x[i] is moved towards y[j].
plt.figure(figsize=(10, 6))
plt.imshow(ot.matrix)
plt.colorbar()
plt.title("Optimal Coupling Matrix")
plt.show()
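Because a and b were left uniform, the row and column sums of the coupling approximately recover those uniform weights (up to the solver's convergence threshold); a quick sanity check one may run (a sketch, not part of the original walkthrough):
# sketch: rows should each sum to a = 1/7, columns to b = 1/11
print(ot.matrix.sum(axis=1))  # approximately jnp.ones(n_x) / n_x
print(ot.matrix.sum(axis=0))  # approximately jnp.ones(n_y) / n_y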

ot stores many more things, notably a lower as well as an upper bound on the “true” squared 2-Wasserstein metric between x and y (the gap between these two bounds can be made arbitrarily small by decreasing epsilon when geom is instantiated).
print(
    f"2-Wasserstein: Lower bound = {ot.dual_cost:.3f}, upper = {ot.primal_cost:.3f}"
)
2-Wasserstein: Lower bound = 0.597, upper = 1.038
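The regularized objective itself, the quantity we will differentiate in the next section, is exposed as ot.reg_ot_cost:
print(ot.reg_ot_cost)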
Automatic differentiation using JAX
We finish this quick tour by illustrating one of the main features of ott: it can be seamlessly integrated into differentiable, end-to-end architectures built using JAX (see also Sinkhorn Divergence Hessians for an example exploiting implicit differentiation). We provide a simple use case where we differentiate the (regularized) OT transport cost w.r.t. x, by defining a function that takes x and y as input and outputs their regularized OT cost.
def reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> float:
    geom = pointcloud.PointCloud(x, y)
    ot = sinkhorn.solve(geom)
    return ot.reg_ot_cost
Obtaining the gradient function of reg_ot_cost is as easy as making a call to jax.grad() on reg_ot_cost, e.g. jax.grad(reg_ot_cost). We use jax.value_and_grad() below to also store the value of the output itself. Note that by default, JAX only computes the gradient w.r.t. the first argument of reg_ot_cost, here x.
# value and gradient *function*
r_ot = jax.value_and_grad(reg_ot_cost)
# evaluate it at `(x, y)`.
cost, grad_x = r_ot(x, y)
assert grad_x.shape == x.shape
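Should gradients w.r.t. both inputs be needed, plain JAX handles this via the argnums argument of jax.grad() (a minimal sketch):
# sketch: differentiate w.r.t. both x and y using standard JAX
grad_both = jax.grad(reg_ot_cost, argnums=(0, 1))
grad_x2, grad_y = grad_both(x, y)
assert grad_y.shape == y.shape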
grad_x is a matrix that has the same size as x. Updating x with the opposite of that gradient decreases the loss. This process can be done iteratively, following a gradient flow, to push x closer to y.
step = 2.0
x = x_old
quiv_args = {"scale": 1, "angles": "xy", "scale_units": "xy", "width": 0.005}
f, axes = plt.subplots(1, 3, sharey=True, sharex=True, figsize=(12, 4))
for iteration, ax in enumerate(axes):
    cost, grad_x = r_ot(x, y)
    ax.scatter(x[:, 0], x[:, 1], **x_args)
    ax.quiver(
        x[:, 0],
        x[:, 1],
        -step * grad_x[:, 0],
        -step * grad_x[:, 1],
        **quiv_args,
    )
    ax.scatter(y[:, 0], y[:, 1], **y_args)
    ax.set_title(f"iter: {iteration}, cost: {cost:.3f}")
    # gradient-descent step on the point cloud
    x -= step * grad_x

Going further
This tutorial gave you a glimpse of the most basic features of ott and how they integrate with JAX. ott implements many other functionalities, including improved execution and extensions of the basic optimal transport problem, such as:
More general cost functions in Point Clouds,
How to use a progress bar in Tracking progress of ott.solvers,
Regularization and acceleration of Sinkhorn solvers in Sinkhorn in All Flavors,
Gromov-Wasserstein, to compare distributions defined on incomparable spaces,
Low-rank Sinkhorn for faster solvers that exploit a low-rank factorization of coupling matrices,
Wasserstein barycenters, as in GMM Barycenters or Sinkhorn Barycenters,
Differentiable sorting in Soft Sorting,
Neural solvers in Neural Dual Solver, to estimate maps in functional form.