# Optimal Transport Tools (OTT) documentation

Code is hosted on GitHub. To install, clone that repo or simply run `pip install ott-jax`.

## Intro

OTT is a JAX package that bundles a few utilities to compute and differentiate the solution to optimal transport problems. OTT can help you compute Wasserstein distances between weighted clouds of points (or histograms), using a cost (e.g. a distance) between individual points.

To that end, OTT uses various implementations of the Sinkhorn algorithm [1, 2, 3]. These implementations take advantage of several JAX features, such as just-in-time (JIT) compilation, auto-vectorization (VMAP), and both automatic and implicit differentiation. A few tutorial snippets are provided below, along with different use-cases, notably for single-cell genomics data [4].

## Packages

There are currently three packages, `geometry`, `core` and `tools`, playing the following roles:

- `geometry` defines classes that describe *two point clouds* paired with a *cost* function (simpler geometries are also implemented, such as that defined by points supported on a multi-dimensional grid with a separable cost [5]). The design choice in OTT is that cost functions and algorithms should operate independently: if a particular cost function allows for faster computations (e.g. the squared-Euclidean distance when comparing grids), this should be taken advantage of not at the level of optimizers, but at the level of the problem's description. Geometry objects are therefore only considered as arguments that describe OT problems handled in `core`, using subroutines provided by geometries;
- `core` helps define an OT problem first (linear, quadratic, barycenters). These problems are then solved using the Sinkhorn algorithm and its variants, the main workhorse to solve OT in this package, including variants that can compute Gromov-Wasserstein distances or barycenters of several measures;
- `tools` provides an interface to exploit OT solutions, as produced by `core` functions. Such tasks include instantiating OT matrices, computing approximations to Wasserstein distances [6, 7], or computing differentiable sort and quantile operations [8].

### Point clouds

We cover in this tutorial the instantiation and use of a `PointCloud` geometry.

A `PointCloud` geometry holds two arrays of vectors, endowed with a cost function. Such a geometry should cover most users’ needs.

We further show differentiation through optimal transport as an example of optimization that leverages first-order gradients.

This notebook can be run on either Jupyter Notebook or Colab (which requires running `!pip install ott-jax` first).

```
[1]:
```

```
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import ott
from ott.geometry import pointcloud
from ott.core import sinkhorn
from ott.tools import transport
```

#### Creates a PointCloud geometry

```
[3]:
```

```
def create_points(rng, n, m, d):
    rngs = jax.random.split(rng, 3)
    x = jax.random.normal(rngs[0], (n, d)) + 1
    y = jax.random.uniform(rngs[1], (m, d))
    a = jnp.ones((n,)) / n
    b = jnp.ones((m,)) / m
    return x, y, a, b

rng = jax.random.PRNGKey(0)
n, m, d = 12, 14, 2
x, y, a, b = create_points(rng, n=n, m=m, d=d)
```

#### Computes the regularized optimal transport

To compute the transport matrix between the two point clouds, one can define a `PointCloud` geometry (which by default uses `ott.geometry.costs.Euclidean` as cost function), then call the `sinkhorn` function, and build the transport matrix from the optimized potentials.

```
[8]:
```

```
geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
out = sinkhorn.sinkhorn(geom, a, b)
P = geom.transport_from_potentials(out.f, out.g)
```

A more concise syntax to compute the optimal transport matrix is to use `transport.solve`. Note how weights are assumed to be uniform if no parameters `a` and `b` are passed to `transport.solve`.

```
[18]:
```

```
ot = transport.solve(x, y, a=a, b=b, epsilon=1e-2)
```

#### Visualizes the transport

```
[10]:
```

```
plt.imshow(ot.matrix, cmap='Purples')
plt.colorbar();
```

```
[11]:
```

```
plott = ott.tools.plot.Plot()
_ = plott(ot)
```

#### Differentiation through Optimal Transport

OTT returns quantities that are differentiable. In the following example, we leverage the gradients to move `N` points in a way that minimizes the overall regularized OT cost, given a ground cost function, here the squared Euclidean distance.

```
[26]:
```

```
def optimize(x: jnp.ndarray,
             y: jnp.ndarray,
             a: jnp.ndarray,
             b: jnp.ndarray,
             cost_fn=ott.geometry.costs.Euclidean(),
             num_iter: int = 101,
             dump_every: int = 10,
             learning_rate: float = 0.2):
    # Value and gradient of the regularized OT cost w.r.t. the geometry.
    reg_ot_cost_vg = jax.value_and_grad(jax.jit(
        lambda geom, a, b: ott.core.sinkhorn.sinkhorn(geom, a, b).reg_ot_cost),
        argnums=0)
    ot = transport.solve(x, y, a=a, b=b, cost_fn=cost_fn, epsilon=1e-2, jit=True)
    result = [ot]
    for i in range(1, num_iter + 1):
        reg_ot_cost, geom_g = reg_ot_cost_vg(ot.geom, ot.a, ot.b)
        x = x - geom_g.x * learning_rate
        ot = transport.solve(x, y, a=a, b=b, cost_fn=cost_fn, epsilon=1e-2, jit=True)
        if i % dump_every == 0:
            result.append(ot)
    return result
```

```
[25]:
```

```
from IPython import display
ots = optimize(x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.Euclidean())
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
anim = plott.animate(ots, frame_rate=4)
html = display.HTML(anim.to_jshtml())
display.display(html)
plt.close()
```

We could use another cost function, in this case the cosine distance, to achieve another kind of dynamics in the optimization.

```
[27]:
```

```
ots = optimize(x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.Cosine())
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
anim = plott.animate(ots, frame_rate=8)
html = display.HTML(anim.to_jshtml())
display.display(html)
plt.close()
```

### Grid geometry

In this tutorial, we cover how to instantiate and use `Grid`.

`Grid` is a geometry that is useful when the probability measures are supported on a \(d\)-dimensional Cartesian grid, i.e. a Cartesian product of \(d\) lists of values, each list \(i\) being of size \(n_i\). The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, as in \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\) where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\).

The advantage of using `Grid` over `PointCloud` for such cases is that the computational cost is \(O(N^{(1+1/d)})\) instead of \(O(N^2)\), where \(N\) is the total number of points in the grid.
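This gain comes from never materializing the \(N \times N\) kernel: a separable kernel factors as a Kronecker product of small per-axis kernels, which can be applied one axis at a time. The following NumPy sketch illustrates the principle (it is not OTT's implementation) on a 2-dimensional grid:

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2 = 5, 7                      # grid of N = 35 points
K1 = rng.random((n1, n1))          # per-axis kernels, e.g. exp(-cost_i / eps)
K2 = rng.random((n2, n2))
v = rng.random(n1 * n2)            # a vector over the full grid

# Naive application: build the N x N kernel explicitly, O(N^2) storage and time.
full = np.kron(K1, K2) @ v

# Separable application: reshape to the grid, apply one axis at a time,
# costing O(N * (n1 + n2)) instead of O(N^2).
separable = (K1 @ v.reshape(n1, n2) @ K2.T).ravel()

assert np.allclose(full, separable)
```

With \(d\) axes of size \(n_i \approx N^{1/d}\), applying each factor in turn costs \(O(N \sum_i n_i) = O(d\, N^{1+1/d})\), which is the complexity quoted above.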

```
[1]:
```

```
import jax
import jax.numpy as jnp
import numpy as np
from ott.core import sinkhorn
from ott.geometry import costs
from ott.geometry import grid
from ott.geometry import pointcloud
```

#### Uses `Grid` with the argument `x`

In this example, the argument `x` is a list of \(3\) vectors, of varying sizes \(\{n_1, n_2, n_3\}\), that describe the locations of the grid. The resulting grid is the Cartesian product of these vectors. `a` and `b` are two histograms in a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube \([0, 1]^3\).

```
[2]:
```

```
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)
grid_size = (5, 6, 7)
x = [jax.random.uniform(keys[0], (grid_size[0],)),
     jax.random.uniform(keys[1], (grid_size[1],)),
     jax.random.uniform(keys[2], (grid_size[2],))]
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```

Instantiates `Grid` and calculates the regularized optimal transport cost.

```
[3]:
```

```
geom = grid.Grid(x=x, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
```

```
Regularised optimal transport cost = 0.30520981550216675
```

#### Uses `Grid` with the argument `grid_size`

In this example, the grid is described as points regularly sampled in \([0, 1]\). `a` and `b` are two histograms in a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube \([0, 1]^3\).

```
[4]:
```

```
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6, 7)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```

Instantiates `Grid` and calculates the regularized optimal transport cost.

```
[5]:
```

```
geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
```

```
Regularised optimal transport cost = 0.3816334307193756
```

#### Varies the cost function in each dimension

Instead of the squared Euclidean distance, we will use a squared Mahalanobis distance, where the covariance matrix is diagonal. This example illustrates the possibility of choosing a cost function for each dimension.

```
[6]:
```

```
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```

For the Mahalanobis distance, we use as covariance matrix the diagonal 2x2 matrix with \([1/2, 1]\) on the diagonal. To do so, we create an additional `costs.CostFn`.

```
[7]:
```

```
@jax.tree_util.register_pytree_node_class
class EuclideanTimes2(costs.CostFn):
    """The cost function corresponding to the squared Euclidean distance times 2."""

    def norm(self, x):
        return jnp.sum(x ** 2, axis=-1) * 2

    def pairwise(self, x, y):
        return -2 * jnp.sum(x * y) * 2

cost_fns = [EuclideanTimes2(), costs.Euclidean()]
```

Instantiates `Grid` and calculates the regularized optimal transport cost.

```
[8]:
```

```
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
```

```
Regularised optimal transport cost = 0.3241420388221741
```

#### Compares runtime between using `Grid` and `PointCloud`

The squared Euclidean distance is an example of a separable cost for which it is possible to use `Grid` instead of `PointCloud`. In this case, using `Grid` over `PointCloud` as geometry in the context of regularized optimal transport presents a computational advantage, as the computational cost of applying a kernel in Sinkhorn steps is of the order of \(O(N^{(1+1/d)})\) instead of the naive \(O(N^2)\) complexity, where \(N\) is the total number of points in the grid and \(d\) its dimension. In this example, we can see that for the same grid size and points, the computational runtime of Sinkhorn with `Grid` is smaller than with `PointCloud`.

```
[9]:
```

```
epsilon = 0.1
grid_size = (50, 50, 50)
rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
xyz = jnp.stack([
    jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
    jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
    jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
]).transpose()
# Instantiates PointCloud with argument 'online=True'
geometry_pointcloud = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon, online=True)
# Runs on GPU
%timeit sinkhorn.sinkhorn(geometry_grid, a=a, b=b).reg_ot_cost.block_until_ready()
out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b)
print(f'Regularised optimal transport cost using Grid = {out_grid.reg_ot_cost}\n')
%timeit sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b).reg_ot_cost.block_until_ready()
out_pointcloud = sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b)
print(f'Regularised optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}')
```

```
1 loops, best of 3: 35.5 ms per loop
Regularised optimal transport cost using Grid = 0.34500643610954285
1 loops, best of 3: 11.4 s per loop
Regularised optimal transport cost using PointCloud = 0.34500643610954285
```

### OTT vs. POT

The Python Optimal Transport (POT) toolbox paved the way for much progress in OT. `POT` implements several OT solvers (LP and regularized), and is complemented with various tools (barycenters, domain adaptation, Gromov-Wasserstein distances, sliced Wasserstein, etc.). The coverage of `OTT` is currently far smaller than that of `POT`.

With that disclaimer in mind, the goal of this notebook is to compare the performance of their Sinkhorn solvers. `OTT` benefits from just-in-time compilation, which should give it an edge. `OTT` is also differentiable w.r.t. its inputs, but since `POT` is not, that aspect is not considered here.

The comparisons carried out below have limitations: minor modifications in the setup (e.g. data distributions, tolerance thresholds, type of accelerator…) could have an impact on these results. Feel free to change these settings and experiment by yourself!

This notebook was run on Colab using a GPU.

#### Installs toolboxes

We install the two toolboxes first…

```
[ ]:
```

```
!pip install ott-jax
!pip install POT
```

… and import them, along with their numerical environments, `jax` and `numpy`.

```
[3]:
```

```
# import JAX and OTT
import jax
import jax.numpy as jnp
import ott
from ott.geometry import pointcloud
from ott.core import sinkhorn
# import OT, from POT
import numpy as np
import ot
# misc
import matplotlib.pyplot as plt
plt.rc('font', size = 20)
import mpl_toolkits.axes_grid1
import timeit
```

#### Regularized OT in a nutshell

We consider two probability measures \(\mu,\nu\) compared with the squared-Euclidean distance, \(c(x,y)=\|x-y\|^2\). These measures are discrete and of the same size \(n\) in this notebook, which allows us to define the OT problem in its primal form,

\[\min_{P \in U(a,b)} \langle C, P \rangle - \varepsilon H(P),\]

where \(U(a,b):=\{P \in \mathbf{R}_+^{n\times n}, P\mathbf{1}_{n}=a, P^T\mathbf{1}_n=b\}\), \(H(P) := -\sum_{i,j} P_{ij}(\log P_{ij} - 1)\), and \(C = [ \|x_i - y_j \|^2 ]_{i,j}\in \mathbf{R}_+^{n\times n}\).

That problem is equivalent to the following dual form,

\[\max_{f, g \in \mathbf{R}^n} \langle f, a \rangle + \langle g, b \rangle - \varepsilon \sum_{i,j} e^{(f_i + g_j - C_{ij})/\varepsilon}.\]

These two problems are solved by `OTT` and `POT` using the *Sinkhorn iterations* on the scaling vectors \(u := e^{f/\varepsilon}\) and \(v := e^{g/\varepsilon}\): starting from a simple initialization for \(u\), alternate the updates \(v \leftarrow b / K^Tu, u \leftarrow a / Kv\), where \(K:=e^{-C/\varepsilon}\).

Upon convergence to fixed points \(u^\star, v^\star\), one has the optimal transport plan

\[P^\star = \text{diag}(u^\star)\, K\, \text{diag}(v^\star),\]

or, alternatively, the optimal dual potentials \(f^\star = \varepsilon \log u^\star\) and \(g^\star = \varepsilon \log v^\star\), whose dual objective \(\langle f^\star, a \rangle + \langle g^\star, b \rangle\) is the quantity reported as the regularized OT cost below.
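These updates can be sketched end-to-end in a few lines of NumPy (a didactic sketch of plain kernel-mode Sinkhorn, not the code of either toolbox):

```python
import numpy as np

def sinkhorn_np(a, b, C, eps, threshold=1e-8, max_iterations=5000):
    """Plain Sinkhorn: alternate scalings until the marginals are matched."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iterations):
        v = b / (K.T @ u)
        u = a / (K @ v)
        # Stop once the marginal violations fall under the threshold.
        err = np.linalg.norm(u * (K @ v) - a) + np.linalg.norm(v * (K.T @ u) - b)
        if err < threshold:
            break
    return u[:, None] * K * v[None, :]  # P = diag(u) K diag(v)

rng = np.random.default_rng(0)
n = 20
x, y = rng.random((n, 2)), rng.random((n, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared-Euclidean costs
a = b = np.full(n, 1 / n)
P = sinkhorn_np(a, b, C, eps=0.1)
assert np.allclose(P.sum(axis=1), a, atol=1e-6)     # row marginals match a
assert np.allclose(P.sum(axis=0), b, atol=1e-6)     # column marginals match b
```

Both toolboxes implement refinements of exactly this loop (log-space updates, momentum, scheduling), discussed next.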

#### OTT and POT implementation

Both toolboxes carry out Sinkhorn updates using either the formulas above directly (this corresponds to `lse_mode=False` in `OTT` and `method='sinkhorn'` in `POT`) or using slightly slower but more robust approaches: `OTT` relies on log-space iterations (`lse_mode=True`), whereas `POT` uses a stabilization trick, enabled with the `method='sinkhorn_stabilized'` flag, designed to avoid numerical overflows while still benefitting from the speed given by matrix-vector products.

The default behaviour of `OTT` and `POT` is to carry out these updates until \(\|u\circ Kv - a\|_2 + \|v\circ K^Tu - b\|_2\) is smaller than the user-defined `threshold`.
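In log space, the same fixed-point updates act on the potentials \(f = \varepsilon\log u\) and \(g = \varepsilon\log v\) through log-sum-exp reductions, which avoids computing \(e^{-C/\varepsilon}\) directly. A NumPy sketch of the idea (not OTT's actual `lse_mode` code):

```python
import numpy as np

def logsumexp(M, axis):
    # Stable log-sum-exp, the workhorse of log-space Sinkhorn updates.
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def sinkhorn_lse(a, b, C, eps, n_iter=2000):
    """Sinkhorn updates on the potentials f, g; exponentials only at the end."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    # The potentials stay finite even when np.exp(-C / eps) would underflow.
    return np.exp((f[:, None] + g[None, :] - C) / eps)

rng = np.random.default_rng(0)
n = 20
x, y = rng.random((n, 2)), rng.random((n, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = b = np.full(n, 1 / n)
P = sinkhorn_lse(a, b, C, eps=0.05)
assert np.allclose(P.sum(axis=1), a)
assert np.allclose(P.sum(axis=0), b, atol=1e-4)
```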

#### Common API for `OTT` and `POT`

We will compare in our experiments `OTT` vs. `POT` in their more stable setups (`lse_mode=True` and `sinkhorn_stabilized`, respectively). We define a common API for both, making sure their results are comparable. That API takes as inputs the measures’ info, the targeted 𝜀 value and the `threshold` used to terminate the algorithm. We set a maximum of 1000 iterations for both.

```
[6]:
```

```
def solve_ot(a, b, x, y, 𝜀, threshold):
    _, log = ot.sinkhorn(a, b, ot.dist(x, y), 𝜀, stopThr=threshold,
                         method='sinkhorn_stabilized', log=True,
                         numItermax=1000)
    f, g = 𝜀 * log['logu'], 𝜀 * log['logv']
    f, g = f - np.mean(f), g + np.mean(f)  # center variables, useful if one wants to compare them
    reg_ot = np.sum(f * a) + np.sum(g * b) if log['err'][-1] < threshold else np.nan
    return f, g, reg_ot


@jax.jit
def solve_ott(a, b, x, y, 𝜀, threshold):
    out = sinkhorn.sinkhorn(pointcloud.PointCloud(x, y, epsilon=𝜀),
                            a, b, threshold=threshold, lse_mode=True, jit=False,
                            max_iterations=1000)
    f, g = out.f, out.g
    f, g = f - np.mean(f), g + np.mean(f)  # center variables, useful if one wants to compare them
    reg_ot = jnp.where(out.converged, jnp.sum(f * a) + jnp.sum(g * b), jnp.nan)
    return f, g, reg_ot
```

To test both solvers, we run simulations using a random seed to generate random point clouds of size \(n\). Random generation is carried out using `jax.random`, to ensure reproducibility. A solver specification provides three pieces of info: the function (using our simple common API), its numerical environment, and its name.

```
[7]:
```

```
dim = 3

def run_simulation(rng, n, 𝜀, threshold, solver_spec):
    # Setting global variables helps avoid a timeit bug.
    global solver_
    global a, b, x, y
    # Extract the specifics of the solver.
    solver_, env, name = solver_spec
    # Draw data at random using JAX.
    rng, *rngs = jax.random.split(rng, 5)
    x = jax.random.uniform(rngs[0], (n, dim))
    y = jax.random.uniform(rngs[1], (n, dim)) + 0.1
    a = jax.random.uniform(rngs[2], (n,))
    b = jax.random.uniform(rngs[3], (n,))
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    # Map to numpy if needed.
    if env == 'np':
        a, b, x, y = map(np.array, (a, b, x, y))
    timeit_res = %timeit -o solver_(a, b, x, y, 𝜀, threshold)
    out = solver_(a, b, x, y, 𝜀, threshold)
    exec_time = np.nan if np.isnan(out[-1]) else timeit_res.best
    return exec_time, out
```

Defines the two solvers used in this experiment:

```
[8]:
```

```
POT = (solve_ot, 'np', 'POT')
OTT = (solve_ott, 'jax', 'OTT')
```

#### Runs simulations with varying \(n\) and \(\varepsilon\)

We run simulations by setting the regularization strength 𝜀 to either \(10^{-2}\) or \(10^{-1}\).

We consider sizes \(n\) between \(2^{8}= 256\) and \(2^{12}= 4096\). We do not go higher, because `POT` runs into out-of-memory errors for \(2^{13}=8192\) in this RAM-restricted Colab environment. `OTT` can avoid these by setting the flag `online` to `True`, as done in the tutorial for grids; this is also handled by the GeomLoss toolbox. We leave the comparison with `geomloss` to a different notebook.

When `%timeit` outputs execution time, **notice the warning message** highlighting the fact that, for `OTT`, at least one run took significantly longer. That run is the one doing the **JIT pre-compilation** of the procedure, suited to that particular problem size \(n\). Once pre-compiled, subsequent runs are orders of magnitude faster, thanks to the `@jax.jit` decorator added to `solve_ott`.

```
[6]:
```

```
rng = jax.random.PRNGKey(0)
solvers = (POT, OTT)
n_range = 2 ** np.arange(8, 13)
𝜀_range = 10 ** np.arange(-2.0, 0.0)
threshold = 1e-2
exec_time = {}
reg_ot = {}
for solver_spec in solvers:
    solver, env, name = solver_spec
    print('----- ', name)
    exec_time[name] = np.ones((len(n_range), len(𝜀_range))) * np.nan
    reg_ot[name] = np.ones((len(n_range), len(𝜀_range))) * np.nan
    for i, n in enumerate(n_range):
        for j, 𝜀 in enumerate(𝜀_range):
            exec, out = run_simulation(rng, n, 𝜀, threshold, solver_spec)
            exec_time[name][i, j] = exec
            reg_ot[name][i, j] = out[-1]
```

```
----- POT
10 loops, best of 5: 43.7 ms per loop
100 loops, best of 5: 11.9 ms per loop
1 loop, best of 5: 230 ms per loop
10 loops, best of 5: 41.4 ms per loop
1 loop, best of 5: 33.4 s per loop
10 loops, best of 5: 155 ms per loop
1 loop, best of 5: 2min 13s per loop
1 loop, best of 5: 367 ms per loop
1 loop, best of 5: 6min 21s per loop
1 loop, best of 5: 1.22 s per loop
----- OTT
The slowest run took 66.78 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 11.2 ms per loop
1000 loops, best of 5: 1.04 ms per loop
The slowest run took 128.37 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 6.12 ms per loop
1000 loops, best of 5: 1.08 ms per loop
The slowest run took 94.84 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 8.95 ms per loop
1000 loops, best of 5: 1.42 ms per loop
The slowest run took 33.90 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 24 ms per loop
100 loops, best of 5: 3.47 ms per loop
The slowest run took 8.19 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 112 ms per loop
100 loops, best of 5: 14.3 ms per loop
```

#### Plots results in terms of time and difference in objective

When the algorithm does not converge within the maximal number of 1000 iterations, or runs into numerical issues, the solver returns a NaN and that point does not appear in the plot.

```
[24]:
```

```
list_legend = []
fig = plt.figure(figsize=(14,8))
for solver_spec, marker, col in zip(solvers, ('p', 'o'), ('blue', 'red')):
    solver, env, name = solver_spec
    p = plt.plot(exec_time[name], marker=marker, color=col,
                 markersize=16, markeredgecolor='k', lw=3)
    p[0].set_linestyle('dotted')
    p[1].set_linestyle('solid')
    list_legend += [name + r' $\varepsilon $=' + "{:.2g}".format(𝜀) for 𝜀 in 𝜀_range]
plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.legend(list_legend)
plt.yscale('log')
plt.xlabel('dimension $n$')
plt.ylabel('time (s)')
plt.title(r'Execution Time vs Dimension for OTT and POT for two $\varepsilon$ values')
plt.show()
```

For good measure, we also show the differences in *objectives* between the two solvers. We subtract the objective returned by `POT` from that returned by `OTT`.

Since the problem is evaluated in its dual form, a *higher* objective is *better*, and therefore a positive difference denotes better performance for `OTT`. White areas stand for values for which `POT` did not converge (either because it exhausted the maximal number of iterations or experienced numerical issues).

```
[21]:
```

```
fig = plt.figure(figsize=(12,8))
ax = plt.gca()
im = ax.imshow(reg_ot['OTT'].T - reg_ot['POT'].T)
plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.yticks(ticks=np.arange(len(𝜀_range)), labels=𝜀_range)
plt.xlabel('dimension $n$')
plt.ylabel(r'regularization $\varepsilon$')
plt.title('Gap in objective, >0 when OTT is better')
divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
plt.colorbar(im, cax=cax)
plt.show()
```

```
[45]:
```

```
for name in ('POT', 'OTT'):
    print('----', name)
    print('Objective')
    print(reg_ot[name])
    print('Execution Time')
    print(exec_time[name])
```

```
---- POT
Objective
[[-0.00862313 -0.79116929]
[-0.02666368 -0.93283839]
[ nan -1.07958862]
[ nan -1.22432204]
[ nan -1.36762311]]
Time
[[0.04367424 0.01185102]
[0.22960342 0.04137421]
[ nan 0.15465033]
[ nan 0.3669143 ]
[ nan 1.21968372]]
---- OTT
Objective
[[-0.00783848 -0.79117149]
[-0.02610656 -0.93283963]
[-0.05083928 -1.07959068]
[-0.06328616 -1.21402502]
[-0.07956241 -1.35710597]]
Time
[[0.01124264 0.00103751]
[0.00612156 0.00107929]
[0.00895449 0.00142238]
[0.02404206 0.00346715]
[0.11208566 0.01432985]]
```

### Sinkhorn in all flavors

We provide in this example a detailed walk-through of some of the functionalities of the `sinkhorn` algorithm, including the computation of the `sinkhorn_divergence`. This colab has all you need to recreate all plots.

```
[ ]:
```

```
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import ott
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from ott.core.sinkhorn import sinkhorn
from ott.geometry.pointcloud import PointCloud
from ott.geometry.geometry import Geometry
```

#### From Texts to Word Histograms

We adapt a keras NLP tutorial to preprocess raw text (here a subset of texts from the newsgroup20 database) and turn them into histograms of word embeddings. See the colab for the detailed pre-processing.

This helps us recover \(653\) histograms supported on up to \(4000\) words, each word represented by a vector in dimension 50.

```
[ ]:
```

```
# X contains 4000 word embeddings in dimension 50; HIST is a 653 x 4000 (row-normalized) matrix of histograms.
print(f'{HIST.shape[0]} texts supported on up to {HIST.shape[1]} words of dimension {X.shape[1]}')
```

```
653 texts supported on up to 4000 words of dimension 50
```

#### Pairwise Sinkhorn Divergences

Before setting a value for `epsilon`, let’s get a feel of what the point cloud of embeddings looks like in terms of distances.

```
[ ]:
```

```
geom = PointCloud(X)
print('median:', geom.median_cost_matrix, ' mean:', geom.mean_cost_matrix, ' max:', jnp.max(geom.cost_matrix))
```

```
median: 0.40351653 mean: 0.41272432 max: 1.4388262
```

Store the \(4000 \times 4000\) cost matrix once and for all.

```
[ ]:
```

```
cost = geom.cost_matrix
```

We now define a jitted version of `sinkhorn_divergence` using a double `vmap` to compute in one go the pairwise *matrix* of Sinkhorn divergences between two sets of histograms. Jitting is very important to achieve efficiency; don’t forget to wrap whatever you do with a `jax.jit` if you want to run at scale. Note also how we set a higher convergence `threshold` (the default would be `1e-3`) to ensure slightly faster execution. Finally, the triple instantiation of `cost` is not a bug: it just describes that the computation of each Sinkhorn divergence will trigger the computation of three regularized OT problems using `sinkhorn`. In principle, each could have its own cost matrix. In this case they are all shared, since the histograms are supported on the same set of words by design.

```
[ ]:
```

```
sink_div = jax.jit(jax.vmap(
    lambda HIST_1, HIST_2, cost, epsilon: jax.vmap(
        lambda hist_1, hist_2, cost, epsilon: sinkhorn_divergence(
            Geometry, cost, cost, cost, epsilon=epsilon, a=hist_1, b=hist_2,
            sinkhorn_kwargs={'threshold': 1e-2}).divergence,
        in_axes=[0, None, None, None])(HIST_1, HIST_2, cost, epsilon),
    in_axes=[None, 0, None, None]))
```
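To make the "three regularized OT problems" concrete, here is a NumPy sketch of the divergence \(\overline{W}_\varepsilon(a,b) := W_\varepsilon(a,b) - \tfrac{1}{2} W_\varepsilon(a,a) - \tfrac{1}{2} W_\varepsilon(b,b)\), each term computed by a log-space Sinkhorn loop on one shared cost matrix (a didactic sketch, not the `sinkhorn_divergence` implementation):

```python
import numpy as np

def logsumexp(M, axis):
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def reg_ot_cost(a, b, C, eps, n_iter=1000):
    # Dual objective <f, a> + <g, b> after log-space Sinkhorn iterations.
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    return f @ a + g @ b

def sinkhorn_div(a, b, C, eps):
    # Three regularized OT problems sharing one cost matrix, as in the text.
    return (reg_ot_cost(a, b, C, eps)
            - .5 * reg_ot_cost(a, a, C, eps)
            - .5 * reg_ot_cost(b, b, C, eps))

rng = np.random.default_rng(0)
n = 30
support = rng.random((n, 2))        # shared support, like the shared word embeddings
C = ((support[:, None, :] - support[None, :, :]) ** 2).sum(-1)
a = rng.random(n); a /= a.sum()     # two histograms on that support
b = rng.random(n); b /= b.sum()
assert abs(sinkhorn_div(a, a, C, eps=0.1)) < 1e-12   # divergence of a with itself is ~0
assert sinkhorn_div(a, b, C, eps=0.1) > 0            # positive for distinct histograms
```

The two correction terms are what makes the divergence vanish when comparing a histogram with itself.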

When setting `epsilon` to `None`, the algorithm will default to 1/20th of the mean cost described in the geometry. This is no magic number, but rather a simple statistic of the scale of the problem. We recommend that you tune `epsilon` yourself, but using `None` might avoid common issues (such as running `sinkhorn` with a very small `epsilon` while the cost matrices are large).

```
[ ]:
```

```
print('Default epsilon is: ', geom.epsilon)
```

```
Default epsilon is: 0.020636216
```

We now compute a pairwise 30 x 30 matrix of Sinkhorn divergences (900 divergences in total), picking two sets of 30 different texts.

```
[ ]:
```

```
HIST_a = jnp.array(HIST[0:30])
HIST_b = jnp.array(HIST[-30:])
print(HIST_a.shape, HIST_b.shape, cost.shape)
```

```
(30, 4000) (30, 4000) (4000, 4000)
```

Dry run with a large epsilon value to force jit compilation before computing timings. This only makes sense within this tutorial.

```
[ ]:
```

```
DIV = sink_div(HIST_a, HIST_b, cost, 10000)
```

We now carry out divergence computations and plot their matrix for various `epsilon`.

```
[ ]:
```

```
DIV, ran_in = [], []
epsilons = [None, 5e-2, 1e-1]
for epsilon in epsilons:
    tic = time.perf_counter()
    DIV.append(sink_div(HIST_a, HIST_b, cost, epsilon).block_until_ready())
    toc = time.perf_counter()
    ran_in.append(toc - tic)
```

Notice how a smaller `epsilon` has a huge impact on time (far longer). Larger `epsilon` values result in less spiked values with, however, a similar relative pattern. As `epsilon` grows, the `sinkhorn_divergence` converges to a quantity directly related to the energy distance / MMD. Times below were obtained with a Tesla T4 card.

```
[ ]:
```

```
fig, axes = plt.subplots(1, 3, figsize=(12, 6))
fig.tight_layout()
axes = [axes[0], axes[1], axes[2]]
vmin = min([jnp.min(div) for div in DIV])
vmax = max([jnp.max(div) for div in DIV])
for epsilon, DIV_, ran_in_, ax_ in zip(epsilons, DIV, ran_in, axes):
    im = ax_.imshow(DIV_, vmin=vmin, vmax=vmax)
    eps = f' ({geom.epsilon:.4f})' if epsilon is None else ''
    ax_.set_title(r'$\varepsilon$ = ' + str(epsilon) + eps + f'\n {ran_in_:.2f} s')
    ax_.axis('off')
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
plt.show()
```

#### The impact of \(\varepsilon\) on convergence

We study in more detail how `epsilon` impacts the convergence of the algorithm. We now restrict our attention to the `sinkhorn` algorithm (not the divergence). We first define a helper, `my_sinkorn`, to handle computations of `sinkhorn` with suitable parameters for this notebook.

```
[ ]:
```

```
import functools

my_sinkorn = functools.partial(
    sinkhorn,
    inner_iterations=1,    # recompute the error at every iteration, for plots.
    max_iterations=10000,  # more iterations than the default setting, to see full curves.
    jit=True)              # force jit.
```

We now select two text histograms, aiming for texts that are supported on more than 1000 words each.

```
[ ]:
```

```
ind = jnp.argsort(jnp.sum(jnp.array(HIST) > 0, axis=1))
a, b = HIST[ind[-2]], HIST[ind[-1]]
print(f'Histogram a supported on {jnp.sum(a >0)} words, b on {jnp.sum(b >0)} words')
```

```
Histogram a supported on 1121 words, b on 1162 words
```

We start by looking more closely into the time needed for `sinkhorn` to converge for various `epsilon` values.

```
[ ]:
```

```
out_eps, leg_eps = [], []
epsilons = [1e-3, .3 * 1e-2, 1e-2, .3 * 1e-1, 1e-1]
ran_in = np.zeros((len(epsilons),))
for i, epsilon in enumerate(epsilons):
    tic = time.perf_counter()
    out_eps.append(my_sinkorn(Geometry(cost, epsilon=epsilon), a, b))
    toc = time.perf_counter()
    ran_in[i] = toc - tic
    leg_eps.append(r'$\varepsilon$' + f'= {epsilon}, cost = {out_eps[-1].reg_ot_cost:.2f}')
```

These execution times can then be plotted, resulting in the following graph:

```
[ ]:
```

```
plt.plot(epsilons, ran_in, marker='s', markersize=10, linewidth=3)
```

We now take a closer look at the actual convergence curves of the error of the `sinkhorn` algorithm (i.e. the marginal error). We introduce a `plot_results` function to visualize this convergence (see the colab).

We can now look more closely into `epsilon`’s impact. Obviously, convergence is slower with smaller regularization: there is a tradeoff between speed and how close to the original LP solution we want to be. In the absence of a strong opinion on how small the regularization should be, we advise that you start using a larger `epsilon`, since this makes your life substantially easier!

```
[ ]:
```

```
plot_results(out_eps, leg_eps, title=r'Iterations needed to converge for various $\varepsilon$', xlabel='iterations', ylabel='error')
```

#### Speeding up Sinkhorn

##### Fixed Momentum

Thibault et al. proposed to use a momentum term to (hopefully) accelerate the convergence of the Sinkhorn algorithm. This is controlled by the `momentum` parameter when calling `sinkhorn`. We vary that parameter along with various `epsilon` regularization strengths. As can be seen below, a `momentum` parameter larger than 1.0 (also known as extrapolation or overrelaxation) helps, but can also be more unstable.
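In terms of the scaling updates seen earlier, a fixed momentum \(m\) overrelaxes each step, e.g. \(u \leftarrow u^{1-m}\,(a / Kv)^{m}\), which for \(m=1\) recovers the plain update. A NumPy sketch of overrelaxed Sinkhorn (an illustration of the idea, not OTT's exact implementation, which applies momentum to the potentials):

```python
import numpy as np

def sinkhorn_momentum(a, b, C, eps, momentum=1.0, n_iter=500):
    """Kernel-mode Sinkhorn with overrelaxed (momentum) updates."""
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        # momentum = 1.0 recovers the plain updates u = a/(Kv), v = b/(K^T u).
        u = u ** (1 - momentum) * (a / (K @ v)) ** momentum
        v = v ** (1 - momentum) * (b / (K.T @ u)) ** momentum
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
n = 20
x, y = rng.random((n, 2)), rng.random((n, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = b = np.full(n, 1 / n)
for mom in (1.0, 1.1):  # plain updates, then mild overrelaxation
    P = sinkhorn_momentum(a, b, C, eps=0.1, momentum=mom)
    assert np.allclose(P.sum(axis=1), a, atol=1e-4)
    assert np.allclose(P.sum(axis=0), b, atol=1e-4)
```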

We first compute baseline curves for three \(\varepsilon\) values:

```
[ ]:
```

```
epsilons = [1e-4, 1e-3, 1e-2]
out_baseline, leg_baseline = [], []
for epsilon in epsilons:
    out_baseline.append(my_sinkorn(Geometry(cost, epsilon=epsilon), a, b))
    leg_baseline.append('Baseline')
```

We now test `momentum` values lower and larger than 1, running the computations first.

```
[ ]:
```

```
out_mom, leg_mom = [], []
for i, epsilon in enumerate(epsilons):
    out_mom.append([out_baseline[i]])  # initialize with baseline
    leg_mom.append([leg_baseline[i]])  # initialize with baseline
    for mom in [.8, 1.05, 1.1, 1.3]:
        out_mom[i].append(my_sinkorn(
            Geometry(cost, epsilon=epsilon), a, b, momentum=mom))
        leg_mom[i].append(f'Momentum : {mom}')
```

Plot them next.

```
[ ]:
```

```
for i, epsilon in enumerate(epsilons):
    plot_results(out_mom[i], leg_mom[i], title=r'Fixed Momentum, $\varepsilon$=' + str(epsilon), xlabel='iterations', ylabel='error')
```

You might have noticed in the first set of curves that the values for `momentum` 1.1 and 1.3 are not displayed. For that small \(\varepsilon=0.0001\), the error diverged from the first update.

```
[ ]:
```

```
[out_mom[0][3].errors], [out_mom[0][4].errors] # Computation diverges from first iteration for small epsilon, high momentum.
```

```
([DeviceArray([inf, -1., -1., ..., -1., -1., -1.], dtype=float32)],
[DeviceArray([inf, -1., -1., ..., -1., -1., -1.], dtype=float32)])
```

##### Adaptive Momentum

Lehmann et al. propose a simple rule to update the momentum term adaptively after a few Sinkhorn iterations, by tracking the convergence of the algorithm to compute a momentum parameter. That value is computed after `chg_momentum_from` iterations. We test this approach with various `epsilon` values.

```
[ ]:
```

```
out_chg_mom, leg_chg_mom = [], []
for i, epsilon in enumerate(epsilons):
    out_chg_mom.append([out_baseline[i]])
    leg_chg_mom.append([leg_baseline[i]])
    for chg_momentum_from in [10, 20, 50, 200, 1000]:
        out_chg_mom[i].append(my_sinkorn(
            Geometry(cost, epsilon=epsilon), a, b,
            chg_momentum_from=chg_momentum_from))
        leg_chg_mom[i].append(f'Change after {chg_momentum_from} it.')
```

As can be seen in the curves below, this seems to be a very effective and robust way to speed up the algorithm.

```
[ ]:
```

```
for i, epsilon in enumerate(epsilons):
    plot_results(out_chg_mom[i], leg_chg_mom[i], title=r'Adaptive Momentum, $\varepsilon$=' + str(epsilon), xlabel='iterations', ylabel='error')
```

##### 𝜀 decay

It is also possible to use so-called \(\varepsilon\) decay (or scheduling), which consists in starting the Sinkhorn iterations with a large \(\varepsilon\) value that is progressively decreased using a multiplicative update.
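The schedule can be sketched as follows: begin at `init * epsilon`, multiply the current value by `decay` at every iteration, flooring at the target `epsilon`, while warm-starting the potentials (a NumPy sketch of the scheduling idea; in OTT this is controlled by the `init` and `decay` arguments of `Geometry`):

```python
import numpy as np

def logsumexp(M, axis):
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def sinkhorn_eps_decay(a, b, C, eps, init=50., decay=.9, n_iter=500):
    """Log-space Sinkhorn with a multiplicative epsilon schedule."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    eps_t = init * eps                   # start with a large regularization ...
    for _ in range(n_iter):
        eps_t = max(eps, decay * eps_t)  # ... and shrink it, floored at the target.
        g = eps_t * np.log(b) - eps_t * logsumexp((f[:, None] - C) / eps_t, axis=0)
        f = eps_t * np.log(a) - eps_t * logsumexp((g[None, :] - C) / eps_t, axis=1)
    return np.exp((f[:, None] + g[None, :] - C) / eps)

rng = np.random.default_rng(0)
n = 20
x, y = rng.random((n, 2)), rng.random((n, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = b = np.full(n, 1 / n)
P = sinkhorn_eps_decay(a, b, C, eps=0.05)
assert np.allclose(P.sum(axis=1), a)
assert np.allclose(P.sum(axis=0), b, atol=1e-4)
```

The early, heavily regularized iterations are cheap and well conditioned, and their potentials warm-start the later iterations at the target \(\varepsilon\).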

```
[ ]:
```

```
out_scaling, leg_scaling = [], []
for i, epsilon in enumerate(epsilons):
    out_scaling.append([out_baseline[i]])
    leg_scaling.append([leg_baseline[i]])
    for decay in [.8, .95]:
        for init in [5, 50, 100]:
            out_scaling[i].append(
                my_sinkorn(Geometry(cost, epsilon=epsilon,
                                    init=init * epsilon, decay=decay), a, b))
            leg_scaling[i].append(rf'Decay: {decay}, Init: {init} $\varepsilon$')
```

```
[ ]:
```

```
for i, epsilon in enumerate(epsilons):
    plot_results(out_scaling[i], leg_scaling[i], title=r'Decay, $\varepsilon$=' + str(epsilon), xlabel='iterations', ylabel='error')
```

##### Anderson acceleration

Using Anderson acceleration on the Sinkhorn algorithm provides mixed results, worsening performance for smaller `epsilon` regularization, and slightly improving it as the regularization gets larger.

```
[ ]:
```

```
out_anderson, leg_anderson = [], []
for i, epsilon in enumerate(epsilons):
    out_anderson.append([out_baseline[i]])
    leg_anderson.append([leg_baseline[i]])
    for anderson_acceleration in [3, 5, 8, 15]:
        out_anderson[i].append(my_sinkorn(Geometry(cost, epsilon=epsilon), a, b,
                                          anderson_acceleration=anderson_acceleration))
        leg_anderson[i].append(f'Anderson Acceleration: {anderson_acceleration}')
```

```
[ ]:
```

```
for i, epsilon in enumerate(epsilons):
    plot_results(out_anderson[i], leg_anderson[i], title=r'Anderson Acceleration, $\varepsilon$=' + str(epsilon), xlabel='iterations', ylabel='error')
```

##### Decay and momentum

An interesting direction to accelerate convergence is to update the momentum after the decay schedule has converged.

```
[ ]:
```

```
out_mixed, leg_mixed = [], []
for i, epsilon in enumerate(epsilons):
    out_mixed.append([out_baseline[i]])
    leg_mixed.append([leg_baseline[i]])
    for decay, init, chg_momentum_from in [[.5, 10, 10],
                                           [.7, 5, 20], [.9, 10, 50],
                                           [.99, 2, 100]]:
        out_mixed[i].append(
            my_sinkorn(Geometry(cost, epsilon=epsilon,
                                init=init * epsilon, decay=
```