Getting Started#

This short tutorial covers a basic use case for ott:

  • Compute an optimal transport distance between two point clouds using the PointCloud geometry, solved with the Sinkhorn algorithm.

  • Showcase the seamless integration with JAX: differentiate through that cost and plot the gradient flow that morphs the first point cloud into the second.

Imports and toy data definition#

import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

ott is built on top of JAX, so we use JAX to instantiate all variables. We generate two 2-dimensional random point clouds of 7 and 11 points, respectively, and store them in variables x and y:

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5
x_old = x  # keep a copy of x; the gradient flow at the end of this tutorial updates x

Because these point clouds are 2-dimensional, we can use scatter plots to illustrate them.

x_args = {
    "s": 80,
    "label": r"source $x$",
    "marker": "s",
    "edgecolor": "k",
    "alpha": 0.75,
}
y_args = {"s": 80, "label": r"target $y$", "edgecolor": "k", "alpha": 0.75}
plt.figure(figsize=(9, 6))
plt.scatter(x[:, 0], x[:, 1], **x_args)
plt.scatter(y[:, 0], y[:, 1], **y_args)
plt.legend()
plt.show()
[Figure: scatter plot of the source point cloud x (squares) and the target point cloud y.]

Optimal transport with ott#

We will now use ott to compute the optimal transport between x and y. To do so, we first create a geom object that stores the geometry (a.k.a. the ground cost) between x and y:

geom = pointcloud.PointCloud(x, y)

geom contains the two datasets x and y, as well as a cost_fn used to measure the distance between points. Here, we use the default settings, so cost_fn is SqEuclidean, the usual squared-Euclidean distance.
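Other ground costs can be swapped in; here is a minimal sketch (the choice of Euclidean is ours, purely for illustration):

from ott.geometry import costs

# Same point clouds, but with the (non-squared) Euclidean distance as
# the ground cost instead of the default SqEuclidean.
geom_euclidean = pointcloud.PointCloud(x, y, cost_fn=costs.Euclidean())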

In order to compute the optimal transport corresponding to geom, we use the Sinkhorn algorithm, which has a regularization hyperparameter epsilon. ott stores that parameter in geom and, by default, sets it to one-twentieth of the mean cost between all points in x and y. While it is also possible to set probability weights a for each point in x (and b for y), these are uniform by default.
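Both defaults can be overridden; a quick sketch (the epsilon value below is an illustrative choice of ours, and the weights are spelled out explicitly even though they match the uniform default):

# Explicit probability weights (any non-negative histograms summing to 1
# would work) and an explicit regularization strength.
a = jnp.ones(n_x) / n_x
b = jnp.ones(n_y) / n_y
geom_custom = pointcloud.PointCloud(x, y, epsilon=1e-1)
ot_custom = sinkhorn.solve(geom_custom, a=a, b=b)

Below, we stick with the defaults.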

solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom, a=None, b=None)

As a small note: the computations here are jitted, meaning that the solver runs much faster the second time it is called:

ot = solve_fn(geom)
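One way to see this (a machine-dependent check of ours, not part of the original tutorial) is to time the compiled call:

import time

start = time.perf_counter()
solve_fn(geom).reg_ot_cost.block_until_ready()  # already compiled, so fast
print(f"jitted solve took {time.perf_counter() - start:.4f} s")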

The output object ot contains the solution of the optimal transport problem. This includes the optimal coupling matrix, whose entry [i, j] indicates how much of the mass of point x[i] is moved towards y[j].
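A quick sanity check of that interpretation (a sketch of ours): the coupling's row and column sums should approximately recover the marginal weights a and b, uniform here.

# Row sums ≈ a (mass leaving each source point), column sums ≈ b.
print(ot.matrix.sum(axis=1))  # each entry ≈ 1 / 7
print(ot.matrix.sum(axis=0))  # each entry ≈ 1 / 11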

plt.figure(figsize=(10, 6))
plt.imshow(ot.matrix)
plt.colorbar()
plt.title("Optimal Coupling Matrix")
plt.show()
[Figure: heatmap of the optimal coupling matrix, with colorbar.]

ot stores many more quantities, notably a lower as well as an upper bound on the “true” squared 2-Wasserstein distance between x and y (the gap between these two bounds can be made arbitrarily small by decreasing epsilon when geom is instantiated).

print(
    f"2-Wasserstein: Lower bound = {ot.dual_cost:3f}, upper = {ot.primal_cost:3f}"
)
2-Wasserstein: Lower bound = 0.596925, upper = 1.038265
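To illustrate that last point (the epsilon value below is an arbitrary choice of ours, and a smaller epsilon may require more Sinkhorn iterations to converge), re-solving with a smaller epsilon narrows the gap between the two bounds:

# A smaller epsilon tightens the primal/dual sandwich around the true cost.
geom_tight = pointcloud.PointCloud(x, y, epsilon=1e-2)
ot_tight = solve_fn(geom_tight, a=None, b=None)
print(f"gap: {ot_tight.primal_cost - ot_tight.dual_cost:.6f}")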

Automatic differentiation using JAX#

We finish this quick tour by illustrating one of the main features of ott: it can be seamlessly integrated into differentiable, end-to-end architectures built using JAX (see also Sinkhorn Divergence Hessians for an example exploiting implicit differentiation).

We provide a simple use case where we differentiate the (regularized) OT cost w.r.t. x, by defining a function that takes x and y as inputs and outputs their regularized OT cost.

def reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> float:
    geom = pointcloud.PointCloud(x, y)
    ot = sinkhorn.solve(geom)
    return ot.reg_ot_cost

Obtaining the gradient function of reg_ot_cost is as easy as making a call to jax.grad() on reg_ot_cost, e.g. jax.grad(reg_ot_cost).

We use jax.value_and_grad() below to also store the value of the output itself. Note that, by default, JAX only computes the gradient w.r.t. the first variable of reg_ot_cost, here x.

# value and gradient *function*
r_ot = jax.value_and_grad(reg_ot_cost)
# evaluate it at `(x, y)`.
cost, grad_x = r_ot(x, y)
assert grad_x.shape == x.shape

grad_x is a matrix of the same size as x. Updating x with the opposite of that gradient decreases the loss. This process can be done iteratively, following a gradient flow, to push x closer to y.

step = 2.0
x = x_old
quiv_args = {"scale": 1, "angles": "xy", "scale_units": "xy", "width": 0.005}
f, axes = plt.subplots(1, 3, sharey=True, sharex=True, figsize=(12, 4))

for iteration, ax in enumerate(axes):
    cost, grad_x = r_ot(x, y)
    ax.scatter(x[:, 0], x[:, 1], **x_args)
    ax.quiver(
        x[:, 0],
        x[:, 1],
        -step * grad_x[:, 0],
        -step * grad_x[:, 1],
        **quiv_args,
    )
    ax.scatter(y[:, 0], y[:, 1], **y_args)
    ax.set_title(f"iter: {iteration}, cost: {cost:.3f}")
    x -= step * grad_x
[Figure: three panels (iterations 0–2) showing the gradient flow, with quiver arrows pushing x towards y; each panel is titled with the iteration number and cost.]

Going further#

This tutorial gave you a glimpse of the most basic features of ott and how they integrate with JAX. ott implements many other functionalities, including improved execution and extensions of the basic optimal transport problem such as: