# OTT vs. POT

The Python Optimal Transport (POT) toolbox paved the way for much progress in OT. `POT` implements several OT solvers (LP and regularized), and is complemented with various tools (barycenters, domain adaptation, Gromov-Wasserstein distances, sliced W, etc.). The coverage of `OTT` is currently far smaller than that of `POT`.

With that disclaimer in mind, the goal of this notebook is to compare the performance of their Sinkhorn solvers. `OTT` benefits from just-in-time compilation, which should give it an edge. `OTT` is also differentiable w.r.t. its inputs, but since `POT` is not, that aspect is not considered here.

The comparisons carried out below have limitations: minor modifications in the setup (e.g. data distributions, tolerance thresholds, type of accelerator…) could have an impact on these results. Feel free to change these settings and experiment yourself!

This notebook was run on Colab using a GPU.

## Installs toolboxes

We install the two toolboxes first…


```
!pip install ott-jax
!pip install POT
```

… and import them, along with their numerical environments, `jax` and `numpy`.


```
# import JAX and OTT
import jax
import jax.numpy as jnp
import ott
from ott.geometry import pointcloud
from ott.core import sinkhorn
# import OT, from POT
import numpy as np
import ot
# misc
import matplotlib.pyplot as plt
plt.rc('font', size = 20)
import mpl_toolkits.axes_grid1
import timeit
```

## Regularized OT in a nutshell

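We consider two probability measures \(\mu,\nu\) compared with the squared-Euclidean distance, \(c(x,y)=\|x-y\|^2\). These measures are discrete and of the same size in this notebook:

\[\mu = \sum_{i=1}^n a_i \delta_{x_i}, \quad \nu = \sum_{j=1}^n b_j \delta_{y_j},\]

to define the OT problem in its primal form,

\[\min_{P \in U(a,b)} \langle P, C \rangle - \varepsilon H(P),\]

where \(U(a,b):=\{P \in \mathbf{R}_+^{n\times n}, P\mathbf{1}_{n}=a, P^T\mathbf{1}_n=b\}\), \(C = [ \|x_i - y_j \|^2 ]_{i,j}\in \mathbf{R}_+^{n\times n}\), and \(H(P) := -\sum_{i,j} P_{ij}(\log P_{ij} - 1)\) is the entropy of \(P\).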

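That problem is equivalent to the following dual form,

\[\max_{f, g \in \mathbf{R}^n} \langle a, f \rangle + \langle b, g \rangle - \varepsilon \langle e^{f/\varepsilon}, K e^{g/\varepsilon} \rangle.\]

These two problems are solved by `OTT` and `POT` using the *Sinkhorn iterations*, which start from a simple initialization for \(u\) and then alternate the updates \(v \leftarrow b / K^Tu, u \leftarrow a / Kv\) (divisions are elementwise), where \(K:=e^{-C/\varepsilon}\).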

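Upon convergence to fixed points \(u^*, v^*\), one has, writing \(D(z)\) for the diagonal matrix with diagonal \(z\),

\[P^* = D(u^*) K D(v^*)\]

or, alternatively,

\[f^* = \varepsilon \log u^*, \quad g^* = \varepsilon \log v^*.\]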
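As a minimal NumPy sketch of the iterations described above (not the code of either toolbox), the snippet below runs a fixed number of updates on a tiny random problem. It assumes the convention \(P = D(u) K D(v)\), so that \(u\) is tied to the row marginal \(a\) and \(v\) to the column marginal \(b\); the function name `sinkhorn_plain` and the problem sizes are illustrative choices.

```python
import numpy as np

def sinkhorn_plain(a, b, C, eps, n_iter=1000):
    """Kernel-space Sinkhorn: returns the scalings u, v and the kernel K."""
    K = np.exp(-C / eps)        # Gibbs kernel K = exp(-C / eps)
    u = np.ones_like(a)         # simple initialization for u
    for _ in range(n_iter):
        v = b / (K.T @ u)       # match the column marginal b
        u = a / (K @ v)         # match the row marginal a
    return u, v, K

rng = np.random.default_rng(0)
n, eps = 5, 0.5
x = rng.uniform(size=(n, 3))
y = rng.uniform(size=(n, 3)) + 0.1
a = rng.uniform(size=n); a /= a.sum()
b = rng.uniform(size=n); b /= b.sum()
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)   # squared-Euclidean costs

u, v, K = sinkhorn_plain(a, b, C, eps)
P = u[:, None] * K * v[None, :]            # transport plan D(u) K D(v)
f, g = eps * np.log(u), eps * np.log(v)    # dual potentials
```

The plan can equivalently be written from the potentials as `np.exp((f[:, None] + g[None, :] - C) / eps)`, which is the form used by log-space solvers.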

## OTT and POT implementation

Both toolboxes carry out Sinkhorn updates using either the formulas above directly (this corresponds to `lse_mode=False` in `OTT` and `method='sinkhorn'` in `POT`) or using slightly slower but more robust approaches: `OTT` relies on log-space iterations (`lse_mode=True`), whereas `POT` uses a stabilization trick, enabled with the `method='sinkhorn_stabilized'` flag, designed to avoid numerical overflows while still benefiting from the speed given by matrix-vector products.
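To illustrate why the robust variants matter, here is a small NumPy sketch of log-space iterations. It is an illustrative rewrite of the updates in terms of the potentials \(f = \varepsilon \log u\) and \(g = \varepsilon \log v\), not `OTT`'s actual implementation; the helper names `logsumexp` and `sinkhorn_log` and the toy data are assumptions made for the example. The point clouds are placed far apart and \(\varepsilon\) is small, so the kernel \(K\) underflows to zero and the direct formulas would divide by zero.

```python
import numpy as np

def logsumexp(M, axis):
    """Numerically stable log(sum(exp(M))) along `axis`."""
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis=axis)

def sinkhorn_log(a, b, C, eps, n_iter=200):
    """Log-space Sinkhorn on the dual potentials f, g (K is never formed)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        # log-space counterparts of v <- b / K^T u and u <- a / K v
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    return f, g

rng = np.random.default_rng(0)
n, eps = 4, 1e-3
x = rng.uniform(size=(n, 3))
y = rng.uniform(size=(n, 3)) + 2.0     # far apart: every cost exceeds 3
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)

K = np.exp(-C / eps)                   # underflows: all entries are 0.0
f, g = sinkhorn_log(a, b, C, eps)
P = np.exp((f[:, None] + g[None, :] - C) / eps)
```

The potentials stay finite even though `K` is identically zero in floating point. With such a small \(\varepsilon\), many more iterations may be needed before the column marginals are also matched, which is consistent with the "slower but more robust" trade-off.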

The default behaviour of `OTT` and `POT` is to carry out these updates until \(\|u\circ Kv - a\|_2 + \|v\circ K^Tu - b\|_2\) is smaller than the user-defined `threshold`.
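The stopping rule can be sketched in NumPy as follows. This is a simplified illustration under the assumption that the error is checked after every iteration; the names `marginal_error` and `sinkhorn_with_threshold` are hypothetical and do not belong to either toolbox.

```python
import numpy as np

def marginal_error(u, v, K, a, b):
    """Sum of the 2-norms of the row and column marginal violations."""
    return (np.linalg.norm(u * (K @ v) - a)
            + np.linalg.norm(v * (K.T @ u) - b))

def sinkhorn_with_threshold(a, b, K, threshold, max_iterations=1000):
    """Iterate until the marginal error drops below `threshold`."""
    u, v = np.ones_like(a), np.ones_like(b)
    for it in range(1, max_iterations + 1):
        v = b / (K.T @ u)
        u = a / (K @ v)
        if marginal_error(u, v, K, a, b) < threshold:
            return u, v, it, True          # converged
    return u, v, max_iterations, False     # iteration budget exhausted

rng = np.random.default_rng(0)
n, eps = 6, 0.5
x = rng.uniform(size=(n, 3))
y = rng.uniform(size=(n, 3)) + 0.1
a = rng.uniform(size=n); a /= a.sum()
b = rng.uniform(size=n); b /= b.sum()
K = np.exp(-((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / eps)

_, _, it_loose, ok_loose = sinkhorn_with_threshold(a, b, K, threshold=1e-2)
u, v, it_tight, ok_tight = sinkhorn_with_threshold(a, b, K, threshold=1e-8)
```

A tighter `threshold` can only increase the iteration count, which is why the tolerance is one of the settings that affects the timings reported in this notebook.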

## Common API for `OTT` and `POT`

In our experiments we compare `OTT` and `POT` in their more stable setups (`lse_mode` and `stabilized`). We define a common API for both, making sure their results are comparable. That API takes as inputs the measures' info, the target regularization \(\varepsilon\) (the `π` argument in the code) and the `threshold` used to terminate the algorithm. We set a maximum of 1000 iterations for both.


```
def solve_ot(a, b, x, y, π, threshold):
    """POT wrapper: stabilized Sinkhorn on the squared-Euclidean cost."""
    _, log = ot.sinkhorn(a, b, ot.dist(x, y), π, stopThr=threshold,
                         method='sinkhorn_stabilized', log=True,
                         numItermax=1000)
    f, g = π * log['logu'], π * log['logv']
    # center variables, useful if one wants to compare them
    f, g = f - np.mean(f), g + np.mean(f)
    # return NaN when the last recorded error is not below the threshold
    reg_ot = np.sum(f * a) + np.sum(g * b) if log['err'][-1] < threshold else np.nan
    return f, g, reg_ot


@jax.jit
def solve_ott(a, b, x, y, π, threshold):
    """OTT wrapper: log-space Sinkhorn on a point-cloud geometry."""
    out = sinkhorn.sinkhorn(pointcloud.PointCloud(x, y, epsilon=π),
                            a, b, threshold=threshold, lse_mode=True, jit=False,
                            max_iterations=1000)
    f, g = out.f, out.g
    # center variables, useful if one wants to compare them
    f, g = f - jnp.mean(f), g + jnp.mean(f)
    # return NaN when the solver did not converge
    reg_ot = jnp.where(out.converged, jnp.sum(f * a) + jnp.sum(g * b), jnp.nan)
    return f, g, reg_ot
```
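Both wrappers center the potentials before returning them. The reason is that dual potentials are only defined up to a constant: replacing \((f, g)\) by \((f - c, g + c)\) leaves \(\langle a, f \rangle + \langle b, g \rangle\) unchanged because \(a\) and \(b\) both sum to one. The following NumPy sketch checks this on random stand-in vectors (the values are illustrative, not solver outputs).

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
a = rng.uniform(size=n); a /= a.sum()
b = rng.uniform(size=n); b /= b.sum()
# random stand-ins for a pair of dual potentials
f, g = rng.normal(size=n), rng.normal(size=n)

def dual_linear_term(f, g):
    """The quantity <a, f> + <b, g> reported by both wrappers."""
    return np.sum(f * a) + np.sum(g * b)

# same centering as in the wrappers: shift by the mean of f
c = np.mean(f)
f_c, g_c = f - c, g + c

before = dual_linear_term(f, g)
after = dual_linear_term(f_c, g_c)
```

Centering therefore does not change the reported objective; it only picks one representative pair, so that the potentials returned by the two solvers can be compared entry by entry.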

To test both solvers, we run simulations using a random seed to generate random point clouds of size \(n\). Random generation is carried out using `jax.random`, to ensure reproducibility. A solver specification provides three pieces of info: the function (using our simple common API), its numerical environment and its name.


```
dim = 3

def run_simulation(rng, n, π, threshold, solver_spec):
    # setting global variables helps avoid a timeit bug.
    global solver_
    global a, b, x, y
    # extract specificities of solver.
    solver_, env, name = solver_spec

    # draw data at random using JAX
    rng, *rngs = jax.random.split(rng, 5)
    x = jax.random.uniform(rngs[0], (n, dim))
    y = jax.random.uniform(rngs[1], (n, dim)) + 0.1
    a = jax.random.uniform(rngs[2], (n,))
    b = jax.random.uniform(rngs[3], (n,))
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)

    # map to numpy if needed
    if env == 'np':
        a, b, x, y = map(np.array, (a, b, x, y))

    # time the solver, then run it once more to collect its output
    timeit_res = %timeit -o solver_(a, b, x, y, π, threshold)
    out = solver_(a, b, x, y, π, threshold)
    # discard the timing when the solver did not converge
    exec_time = np.nan if np.isnan(out[-1]) else timeit_res.best
    return exec_time, out
```

We define the two solvers used in this experiment:


```
POT = (solve_ot, 'np', 'POT')
OTT = (solve_ott, 'jax', 'OTT')
```

## Runs simulations with varying \(n\) and \(\varepsilon\)

We run simulations by setting the regularization strength \(\varepsilon\) (`π` in the code) to either \(10^{-2}\) or \(10^{-1}\).

We consider \(n\) between sizes \(2^{8}= 256\) and \(2^{12}= 4096\). We do not go higher, because `POT` runs into out-of-memory errors for \(2^{13}=8192\) in this RAM-restricted Colab environment. `OTT` can avoid these by setting the flag `online` to `True`, as done in the tutorial for grids, and as also handled by the GeomLoss toolbox. We leave the comparison with `geomloss` to a different notebook.
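To give a sense of scale, a dense \(n \times n\) matrix stored in 8-byte floats takes \(8n^2\) bytes, i.e. 512 MiB for \(n = 8192\), before counting the kernel and any temporaries. The NumPy sketch below illustrates the general idea behind an online computation: recompute blocks of the cost on the fly and apply the kernel to a vector without ever storing the full matrix. It is an assumption-laden illustration, not `OTT`'s implementation of `online`; the function names and the block size are hypothetical.

```python
import numpy as np

def dense_cost_bytes(n, itemsize=8):
    """Memory needed to store one dense n x n matrix."""
    return n * n * itemsize

def kernel_matvec_blockwise(x, y, v, eps, block=128):
    """Compute K @ v, with K = exp(-C / eps), one block of rows at a time."""
    out = np.empty(x.shape[0])
    for start in range(0, x.shape[0], block):
        xb = x[start:start + block]
        # cost block between a slice of x and all of y
        Cb = ((xb[:, None, :] - y[None, :, :]) ** 2).sum(-1)
        out[start:start + block] = np.exp(-Cb / eps) @ v
    return out

rng = np.random.default_rng(0)
n, eps = 300, 0.1
x = rng.uniform(size=(n, 3))
y = rng.uniform(size=(n, 3)) + 0.1
v = rng.uniform(size=n)

# dense reference, affordable here because n is small
K_dense = np.exp(-((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / eps)
Kv_dense = K_dense @ v
Kv_block = kernel_matvec_blockwise(x, y, v, eps)
```

The blockwise product trades extra computation (the cost is recomputed at every iteration) for a memory footprint that grows with `block * n` instead of `n * n`.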

When `%timeit` outputs execution time, **notice the warning message** highlighting the fact that, for `OTT`, at least one run took significantly longer. That run is the one doing the **JIT pre-compilation** of the procedure for that particular problem size \(n\). Once pre-compiled, subsequent runs are orders of magnitude faster, thanks to the `@jax.jit` decorator added to `solve_ott`.
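Outside of a notebook, the same effect can be made explicit by timing the first call separately from the later ones. The sketch below uses only the standard library; `FakeJitted` is a hypothetical stand-in that merely simulates a one-time compilation cost with `time.sleep`, and `time_first_and_best` is an illustrative helper, not part of `timeit` or JAX.

```python
import time

def time_first_and_best(fn, *args, repeat=5):
    """Return (seconds of the first call, best seconds over `repeat` later calls)."""
    t0 = time.perf_counter()
    fn(*args)
    first = time.perf_counter() - t0
    later = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        later.append(time.perf_counter() - t0)
    return first, min(later)

class FakeJitted:
    """Hypothetical stand-in for a jitted function: slow first call, fast afterwards."""
    def __init__(self):
        self.compiled = False

    def __call__(self, n):
        if not self.compiled:
            time.sleep(0.05)        # pretend one-time compilation cost
            self.compiled = True
        return sum(range(n))

first, best = time_first_and_best(FakeJitted(), 1000)
```

Reporting the best of the later calls mirrors what `%timeit` prints. With real JAX code, results are dispatched asynchronously, so a careful benchmark would typically also call `block_until_ready()` on the output before stopping the clock.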


```
rng = jax.random.PRNGKey(0)
solvers = (POT, OTT)
n_range = 2 ** np.arange(8, 13)
π_range = 10 ** np.arange(-2.0, 0.0)
threshold = 1e-2

exec_time = {}
reg_ot = {}
for solver_spec in solvers:
    solver, env, name = solver_spec
    print('----- ', name)
    # one entry per (n, regularization) pair, NaN until filled in
    exec_time[name] = np.ones((len(n_range), len(π_range))) * np.nan
    reg_ot[name] = np.ones((len(n_range), len(π_range))) * np.nan
    for i, n in enumerate(n_range):
        for j, π in enumerate(π_range):
            # `t_exec` avoids shadowing the built-in `exec`
            t_exec, out = run_simulation(rng, n, π, threshold, solver_spec)
            exec_time[name][i, j] = t_exec
            reg_ot[name][i, j] = out[-1]
```

```
----- POT
10 loops, best of 5: 43.7 ms per loop
100 loops, best of 5: 11.9 ms per loop
1 loop, best of 5: 230 ms per loop
10 loops, best of 5: 41.4 ms per loop
1 loop, best of 5: 33.4 s per loop
10 loops, best of 5: 155 ms per loop
1 loop, best of 5: 2min 13s per loop
1 loop, best of 5: 367 ms per loop
1 loop, best of 5: 6min 21s per loop
1 loop, best of 5: 1.22 s per loop
----- OTT
The slowest run took 66.78 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 11.2 ms per loop
1000 loops, best of 5: 1.04 ms per loop
The slowest run took 128.37 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 6.12 ms per loop
1000 loops, best of 5: 1.08 ms per loop
The slowest run took 94.84 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 8.95 ms per loop
1000 loops, best of 5: 1.42 ms per loop
The slowest run took 33.90 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 24 ms per loop
100 loops, best of 5: 3.47 ms per loop
The slowest run took 8.19 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 112 ms per loop
100 loops, best of 5: 14.3 ms per loop
```

## Plots results in terms of time and difference in objective

When the algorithm does not converge within the maximal number of 1000 iterations, or runs into numerical issues, the solver returns a NaN and that point does not appear in the plot.


```
list_legend = []
fig = plt.figure(figsize=(14,8))
for solver_spec, marker, col in zip(solvers,('p','o'), ('blue','red')):
    solver, env, name = solver_spec
    p = plt.plot(exec_time[name], marker=marker, color=col,
                 markersize=16, markeredgecolor='k', lw=3)
    p[0].set_linestyle('dotted')
    p[1].set_linestyle('solid')
    list_legend += [name + r' $\varepsilon $=' + "{:.2g}".format(π) for π in π_range]
plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.legend(list_legend)
plt.yscale('log')
plt.xlabel('dimension $n$')
plt.ylabel('time (s)')
plt.title(r'Execution Time vs Dimension for OTT and POT for two $\varepsilon$ values')
plt.show()
```

For good measure, we also show the differences in *objectives* between the two solvers. We subtract the objective returned by `POT` from that returned by `OTT`.

Since the problem is evaluated in its dual form, a *higher* objective is *better*, and therefore a positive difference denotes better performance for `OTT`. White areas stand for values for which `POT` did not converge (either because it exhausted the maximal number of iterations or experienced numerical issues).


```
fig = plt.figure(figsize=(12,8))
ax = plt.gca()
im = ax.imshow(reg_ot['OTT'].T - reg_ot['POT'].T)
plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.yticks(ticks=np.arange(len(π_range)), labels=π_range)
plt.xlabel('dimension $n$')
plt.ylabel(r'regularization $\varepsilon$')
plt.title('Gap in objective, >0 when OTT is better')
divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
plt.colorbar(im, cax=cax)
plt.show()
```


```
for name in ('POT', 'OTT'):
    print('----', name)
    print('Objective')
    print(reg_ot[name])
    print('Time')
    print(exec_time[name])
```

```
---- POT
Objective
[[-0.00862313 -0.79116929]
[-0.02666368 -0.93283839]
[ nan -1.07958862]
[ nan -1.22432204]
[ nan -1.36762311]]
Time
[[0.04367424 0.01185102]
[0.22960342 0.04137421]
[ nan 0.15465033]
[ nan 0.3669143 ]
[ nan 1.21968372]]
---- OTT
Objective
[[-0.00783848 -0.79117149]
[-0.02610656 -0.93283963]
[-0.05083928 -1.07959068]
[-0.06328616 -1.21402502]
[-0.07956241 -1.35710597]]
Time
[[0.01124264 0.00103751]
[0.00612156 0.00107929]
[0.00895449 0.00142238]
[0.02404206 0.00346715]
[0.11208566 0.01432985]]
```
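As a quick way to read these tables, the sketch below copies the printed values into NumPy arrays (rows are \(n = 256, \dots, 4096\), columns are \(\varepsilon = 0.01, 0.1\)) and computes the objective gap and the speed-up of `OTT` over `POT`. The numbers are those of the single Colab run reported above; the variable names are illustrative.

```python
import numpy as np

nan = np.nan
# values copied from the printed output of this run
obj_pot = np.array([[-0.00862313, -0.79116929], [-0.02666368, -0.93283839],
                    [nan, -1.07958862], [nan, -1.22432204], [nan, -1.36762311]])
obj_ott = np.array([[-0.00783848, -0.79117149], [-0.02610656, -0.93283963],
                    [-0.05083928, -1.07959068], [-0.06328616, -1.21402502],
                    [-0.07956241, -1.35710597]])
time_pot = np.array([[0.04367424, 0.01185102], [0.22960342, 0.04137421],
                     [nan, 0.15465033], [nan, 0.3669143], [nan, 1.21968372]])
time_ott = np.array([[0.01124264, 0.00103751], [0.00612156, 0.00107929],
                     [0.00895449, 0.00142238], [0.02404206, 0.00346715],
                     [0.11208566, 0.01432985]])

gap = obj_ott - obj_pot                  # > 0 when OTT reaches a higher dual value
speedup = time_pot / time_ott            # NaN wherever POT did not converge
n_failed = int(np.isnan(obj_pot).sum())  # settings where POT returned NaN

print('POT failures:', n_failed)
print('largest |gap|:', np.nanmax(np.abs(gap)))
print('speed-up range:', np.nanmin(speedup), np.nanmax(speedup))
```

On these recorded numbers, `POT` fails to converge in the three largest problems at \(\varepsilon = 0.01\), the objectives differ by at most about \(10^{-2}\) wherever both solvers converged, and `OTT` is faster by a factor ranging from roughly 4 to over 100.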