# OTT vs. POT#

The Python Optimal Transport (POT) toolbox paved the way for much progress in OT. POT implements several OT solvers (LP and regularized), and is complemented with various tools (e.g., barycenters, domain adaptation, Gromov-Wasserstein distances, sliced W, etc.).

The goal of this notebook is to compare the performance OTT's and POT's Sinkhorn solvers. OTT benefits from just-in-time compilation, which should give it an edge.

The comparisons carried out below have limitations: minor modifications in the setup (e.g., data distributions, tolerance thresholds, acceleratorβ¦) could have an impact on these results. Feel free to change these settings and experiment by yourself!

import sys

!pip install -q git+https://github.com/ott-jax/ott@main
!pip install -q POT

import timeit

import jax
import jax.numpy as jnp
import numpy as np
import ot

import matplotlib.pyplot as plt
import mpl_toolkits.axes_grid1

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

plt.rc("font", size=20)


## Regularized OT in a nutshell#

We consider two probability measures $$\mu,\nu$$ compared with the squared-Euclidean distance, $$c(x,y)=\|x-y\|^2$$. These measures are discrete and of the same size in this notebook:

$\mu=\sum_{i=1}^n a_i\delta_{x_i}, \nu =\sum_{j=1}^n b_j\delta_{y_j},$

to define the OT problem in its primal form, $$$\min_{P \in U(a,b)} \langle C, P \rangle - \varepsilon H(P).$$$

where $$U(a,b):=\{P \in \mathbf{R}_+^{n\times n}, P\mathbf{1}_{n}=b, P^T\mathbf{1}_n=b\}$$, and $$C = [ \|x_i - y_j \|^2 ]_{i,j}\in \mathbf{R}_+^{n\times n}$$.

That problem is equivalent to the following dual form, $$$\max_{f, g} \langle a, f \rangle + \langle b, g \rangle - \varepsilon \langle e^{f/\varepsilon},Ke^{g/\varepsilon} \rangle.$$$

These two problems are solved by OTT and POT using the Sinkhorn iterations with a simple initialization for $$u$$, and subsequent updates $$v \leftarrow a / K^Tu, u \leftarrow b / Kv$$, where $$K:=e^{-C/\varepsilon}$$.

Upon convergence to fixed points $$u^*, v^*$$, one has $$$P^*=D(u^*)KD(v^*)$$$$or, alternatively,$$$$f^*, g^* = \varepsilon \log(u^*), \varepsilon\log(v^*)$$$

## OTT and POT implementation#

Both toolboxes carry out Sinkhorn updates using either the formulas above directly (this corresponds to lse_mode=False in OTT and method='sinkhorn' in POT) or using slightly slower but more robust approaches:

OTT relies on log-space iterations (lse_mode=True), whereas POT, uses a stabilization trick (method='sinkhorn_stabilized') designed to avoid numerical overflows, while still benefiting from the speed given by matrix vector products.

The default behavior of OTT and POT is to carry out these updates until $$\|u\circ Kv - a\|_2 + \|v\circ K^Tu - b\|_2$$ is smaller than the user-defined threshold.

## Common API for OTT and POT#

We will compare in our experiments OTT vs. POT in their more stable setups (lse_mode and stabilized). We define a common API for both, making sure their results are comparable. That API takes as inputs the measuresβ info, the targeted $$\varepsilon$$ value and the threshold used to terminate the algorithm. We set a maximum of 1000 iterations for both.

def solve_ot(a, b, x, y, π, threshold):
_, log = ot.sinkhorn(
a,
b,
ot.dist(x, y),
π,
stopThr=threshold,
method="sinkhorn_stabilized",
log=True,
numItermax=1000,
)
f, g = π * log["logu"], π * log["logv"]
f, g = f - np.mean(f), g + np.mean(
f
)  # center variables, useful if one wants to compare them
reg_ot = (
np.sum(f * a) + np.sum(g * b) if log["err"][-1] < threshold else np.nan
)
return f, g, reg_ot

@jax.jit
def solve_ott(a, b, x, y, π, threshold):
geom = pointcloud.PointCloud(x, y, epsilon=π)
prob = linear_problem.LinearProblem(geom, a=a, b=b)

solver = sinkhorn.Sinkhorn(
threshold=threshold, lse_mode=True, max_iterations=1000
)
out = solver(prob)

f, g = out.f, out.g
f, g = f - np.mean(f), g + np.mean(
f
)  # center variables, useful if one wants to compare them
reg_ot = jnp.where(out.converged, jnp.sum(f * a) + jnp.sum(g * b), jnp.nan)
return f, g, reg_ot


To test both solvers, we run simulations using a random seed to generate random point clouds of size $$n$$. Random generation is carried out using PRNGKey(), to ensure reproducibility. A solver provides three pieces of info: the function (using our simple common API), its numerical environment and its name.

dim = 3

def run_simulation(rng, n, π, threshold, solver_spec):
#  setting global variables helps avoir a timeit bug.
global solver_
global a, b, x, y

# extract specificities of solver.
solver_, env, name = solver_spec

# draw data at random using JAX
rng, *rngs = jax.random.split(rng, 5)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (n, dim)) + 0.1
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (n,))
a = a / jnp.sum(a)
b = b / jnp.sum(b)

# map to numpy if needed
if env == "np":
a, b, x, y = map(np.array, (a, b, x, y))

timeit_res = %timeit -o solver_(a, b, x, y, π, threshold)
out = solver_(a, b, x, y, π, threshold)
exec_time = np.nan if np.isnan(out[-1]) else timeit_res.best
return exec_time, out


Defines the two solvers used in this experiment:

POT = (solve_ot, "np", "POT")
OTT = (solve_ott, "jax", "OTT")


## Run simulations with varying $$n$$ and $$\varepsilon$$#

We run simulations by setting the regularization strength $$\varepsilon$$ to either $$10^{-2}$$ or $$10^{-1}$$.

We consider $$n$$ between sizes $$2^{8}= 256$$ and $$2^{12}= 4096$$. We do not go higher, because POT runs into out-of-memory errors for $$2^{13}=8192$$ RAM. OTT can avoid these by setting the flag batch_size to, e.g., 1024, as done in the tutorial for grids, and also handled by the GeomLoss toolbox. We leave the comparison with GeomLoss to a different notebook.

When %timeit outputs execution time, notice the warning message highlighting the fact that, for OTT, at least one run took significantly longer. That run is that doing the JIT pre-compilation of the procedure, suitable for that particular problem size $$n$$. Once pre-compiled, subsequent runs are order of magnitudes faster, thanks to the jit() decorator added to solve_ott.

rng = jax.random.PRNGKey(0)
solvers = (POT, OTT)
n_range = 2 ** np.arange(8, 13)
π_range = 10 ** np.arange(-2.0, 0.0)

threshold = 1e-2

exec_time = {}
reg_ot = {}
for solver_spec in solvers:
solver, env, name = solver_spec
print("----- ", name)
exec_time[name] = np.ones((len(n_range), len(π_range))) * np.nan
reg_ot[name] = np.ones((len(n_range), len(π_range))) * np.nan
for i, n in enumerate(n_range):
for j, π in enumerate(π_range):
t, out = run_simulation(rng, n, π, threshold, solver_spec)
exec_time[name][i, j] = t
reg_ot[name][i, j] = out[-1]

-----  POT
10 loops, best of 5: 43.7 ms per loop
100 loops, best of 5: 11.9 ms per loop
1 loop, best of 5: 230 ms per loop
10 loops, best of 5: 41.4 ms per loop
1 loop, best of 5: 33.4 s per loop
10 loops, best of 5: 155 ms per loop
1 loop, best of 5: 2min 13s per loop
1 loop, best of 5: 367 ms per loop
1 loop, best of 5: 6min 21s per loop
1 loop, best of 5: 1.22 s per loop
-----  OTT
The slowest run took 66.78 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 11.2 ms per loop
1000 loops, best of 5: 1.04 ms per loop
The slowest run took 128.37 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 6.12 ms per loop
1000 loops, best of 5: 1.08 ms per loop
The slowest run took 94.84 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 8.95 ms per loop
1000 loops, best of 5: 1.42 ms per loop
The slowest run took 33.90 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 24 ms per loop
100 loops, best of 5: 3.47 ms per loop
The slowest run took 8.19 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 112 ms per loop
100 loops, best of 5: 14.3 ms per loop


## Plot results in terms of time and difference in objective#

When the algorithm does not converge within the maximal number of 1000 iterations, or runs into numerical issues, the solver returns a NaN and that point does not appear in the plot.

list_legend = []
fig = plt.figure(figsize=(14, 8))

for solver_spec, marker, col in zip(solvers, ("p", "o"), ("blue", "red")):
solver, env, name = solver_spec
p = plt.plot(
exec_time[name],
marker=marker,
color=col,
markersize=16,
markeredgecolor="k",
lw=3,
)
p[0].set_linestyle("dotted")
p[1].set_linestyle("solid")
list_legend += [name + r"  $\varepsilon$=" + f"{π:.2g}" for π in π_range]

plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.legend(list_legend)
plt.yscale("log")
plt.xlabel("dimension $n$")
plt.ylabel("time (s)")
plt.title(
r"Execution Time vs Dimension for OTT and POT for two $\varepsilon$ values"
)
plt.show()


For good measure, we also show the differences in objectives between the two solvers. We subtract the objective returned by POT to that returned by OTT.

Since the problem is evaluated in its dual form, a higher objective is better, and therefore a positive difference denotes better performance for OTT. White areas stand for values for which POT did not converge (either because it has exhausted the maximal number of iterations or experienced numerical issues).

fig = plt.figure(figsize=(12, 8))
ax = plt.gca()
im = ax.imshow(reg_ot["OTT"].T - reg_ot["POT"].T)
plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.yticks(ticks=np.arange(len(π_range)), labels=π_range)
plt.xlabel("dimension $n$")
plt.ylabel(r"regularization $\varepsilon$")
plt.title("Gap in objective, >0 when OTT is better")
divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)
plt.colorbar(im, cax=cax)
plt.show()

for name in ("POT", "OTT"):
print("----", name)
print("Objective")
print(reg_ot[name])
print("Execution Time")
print(exec_time[name])

---- POT
Objective
[[-0.00862313 -0.79116929]
[-0.02666368 -0.93283839]
[        nan -1.07958862]
[        nan -1.22432204]
[        nan -1.36762311]]
Time
[[0.04367424 0.01185102]
[0.22960342 0.04137421]
[       nan 0.15465033]
[       nan 0.3669143 ]
[       nan 1.21968372]]
---- OTT
Objective
[[-0.00783848 -0.79117149]
[-0.02610656 -0.93283963]
[-0.05083928 -1.07959068]
[-0.06328616 -1.21402502]
[-0.07956241 -1.35710597]]
Time
[[0.01124264 0.00103751]
[0.00612156 0.00107929]
[0.00895449 0.00142238]
[0.02404206 0.00346715]
[0.11208566 0.01432985]]