OTT & POT#
The goal of this notebook is to compare OTT's Sinkhorn solver to the Python Optimal Transport (POT) implementation of Sinkhorn. POT can also use a JAX backbone, but unlike OTT, it cannot benefit from just-in-time compilation. We will see this can play a role for smaller scale problems. We also compare their APIs and highlight a few differences.
The comparisons carried out below have limitations: minor modifications in the setup (e.g., data distributions, tolerance thresholds, accelerator…) could have an impact on these results. Feel free to change these settings and experiment by yourself!
import jax
import jax.numpy as jnp
import numpy as np
import ot
import matplotlib.pyplot as plt
import mpl_toolkits.axes_grid1
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
plt.rc("font", size=20)
Regularized OT in a nutshell#
We consider two probability measures \(\mu,\nu\) compared with the squared-Euclidean distance, \(c(x,y)=\|x-y\|^2\). These measures are discrete and of the same size in this notebook:
\[\mu = \sum_{i=1}^n a_i\delta_{x_i}, \qquad \nu = \sum_{j=1}^n b_j\delta_{y_j}.\]
We use \(\varepsilon\)-regularized OT to define the OT problem in its primal form,
\[\min_{P\in U(a,b)} \langle C, P\rangle - \varepsilon H(P),\]
where \(U(a,b):=\{P \in \mathbf{R}_+^{n\times n}, P\mathbf{1}_{n}=a, P^T\mathbf{1}_n=b\}\), \(C = [ \|x_i - y_j \|^2 ]_{i,j}\in \mathbf{R}_+^{n\times n}\), and \(H\) is the Shannon entropy of \(P\), \(H(P)=-\sum_{ij} P_{ij} \left(\log P_{ij}-1\right)\).
That problem is equivalent to the following dual form,
\[\max_{f,g\in\mathbf{R}^n} \langle f, a\rangle + \langle g, b\rangle - \varepsilon \langle e^{f/\varepsilon}, K e^{g/\varepsilon}\rangle.\]
These two problems can be solved by OTT and POT using Sinkhorn iterations, with a simple initialization for \(u\), and subsequent updates \(v \leftarrow b / K^Tu\), \(u \leftarrow a / Kv\), where \(K:=e^{-C/\varepsilon}\). Upon convergence to fixed points \(u^*, v^*\), one has \(P^*=D(u^*)KD(v^*)\) or, alternatively, \(f^*, g^* = \varepsilon \log(u^*), \varepsilon\log(v^*)\).
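For intuition, here is a minimal sketch of these kernel-space iterations, written directly with jax.numpy (for illustration only; both toolboxes implement more careful, and in particular stabilized, versions of this loop):

def sinkhorn_kernel_iterations(a, b, C, ε, n_iter=100):
    # Gibbs kernel associated with the cost matrix C.
    K = jnp.exp(-C / ε)
    u = jnp.ones_like(a)
    v = jnp.ones_like(b)
    for _ in range(n_iter):
        v = b / (K.T @ u)  # enforce the second marginal
        u = a / (K @ v)  # enforce the first marginal
    # primal coupling and dual potentials recovered from the fixed point.
    P = u[:, None] * K * v[None, :]
    f, g = ε * jnp.log(u), ε * jnp.log(v)
    return P, f, g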
OTT and POT implementation#
Both toolboxes can carry out the Sinkhorn updates as described in the formulas above (this corresponds to lse_mode=False in OTT and method='sinkhorn' in POT), but most practitioners will find that doing computations in log-space yields more robust results, notably in low regularization regimes. OTT relies on log-space iterations (lse_mode=True), whereas POT uses a stabilization trick (method='sinkhorn_stabilized') to avoid numerical overflows, by regularly re-absorbing the scalings into the kernel matrix.
The default behavior of OTT is to carry out these updates until \(\|u\circ Kv - a\|_1 + \|v\circ K^Tu - b\|_1\) is smaller than a user-defined threshold. POT uses instead the \(\|\cdot\|_2\) norm of these terms. Thankfully, OTT can consider other norms by setting the norm_error parameter, set here to 2 to facilitate comparisons.
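As a small illustration of that difference (a sketch, not how either library computes its stopping criterion internally), the two marginal errors can be evaluated with either norm:

def marginal_error(u, v, K, a, b, ord):
    # deviation of the current scalings from the target marginals a and b.
    err_a = jnp.linalg.norm(u * (K @ v) - a, ord=ord)
    err_b = jnp.linalg.norm(v * (K.T @ u) - b, ord=ord)
    return err_a + err_b

# ord=1 mirrors OTT's default criterion; ord=2 (norm_error=2) mirrors POT's.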
Common API for OTT and POT#
We will compare in our experiments OTT vs. POT in their more stable setups (lse_mode=True and method='sinkhorn_log', respectively). We define a common API that takes as inputs the measures' information, the targeted \(\varepsilon\) value and the threshold used to terminate the algorithm. We recover dual potential vectors \(f\) and \(g\), and the dual objective of these dual vectors (without the regularization term, as done in POT). We set a maximum of \(1,000\) iterations for both.
def solve_ot(a, b, x, y, ε, threshold):
    # you can also try "sinkhorn_stabilized", a bit faster but less stable for small ε
    method = "sinkhorn_log"
    _, log = ot.sinkhorn(
        a,
        b,
        ot.dist(x, y),
        ε,
        stopThr=threshold,
        method=method,
        log=True,
        numItermax=1000,
    )
    # dealing with POT quirks: log entry names depend on the method used
    logu = "log_u" if method == "sinkhorn_log" else "logu"
    logv = "log_v" if method == "sinkhorn_log" else "logv"
    n_iter_key = "niter" if method == "sinkhorn_log" else "n_iter"
    # recover dual potentials from log-scalings and center them
    f, g = ε * log[logu], ε * log[logv]
    f, g = f - np.mean(f), g + np.mean(f)
    converged = log["err"][-1] < threshold
    reg_ot = np.sum(f * a) + np.sum(g * b) if converged else jnp.nan
    return f, g, reg_ot, log[n_iter_key]
@jax.jit
def solve_ott(a, b, x, y, π, threshold):
n = x.shape[0]
geom = pointcloud.PointCloud(x, y, epsilon=π)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
solver = sinkhorn.Sinkhorn(
threshold=threshold,
max_iterations=1000,
norm_error=2,
lse_mode=True,
)
out = solver(prob)
# center dual variables to facilitate comparison
f, g = out.f, out.g
f, g = f - np.mean(f), g + np.mean(f)
reg_ot = jnp.where(out.converged, jnp.sum(f * a) + jnp.sum(g * b), jnp.nan)
return f, g, reg_ot, out.n_iters
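As a quick sanity check of this common API (a sketch; it simply reuses the two functions defined above on a tiny problem), both solvers should return nearly identical dual objectives:

rng_check = jax.random.key(0)
k1, k2 = jax.random.split(rng_check)
x_s = jax.random.normal(k1, (16, 2))
y_s = jax.random.normal(k2, (16, 2)) + 0.5
a_s = jnp.ones(16) / 16
b_s = jnp.ones(16) / 16
# POT expects NumPy inputs, OTT works directly with JAX arrays.
out_pot = solve_ot(*map(np.array, (a_s, b_s, x_s, y_s)), 0.1, 1e-3)
out_ott = solve_ott(a_s, b_s, x_s, y_s, 0.1, 1e-3)
print(out_pot[2], out_ott[2])  # the two regularized OT objectives should be close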
To test both solvers, we run simulations using a random seed to generate random point clouds of size \(n\). Random generation is carried out with jax.random.key(), to ensure reproducibility. A solver specification solver_spec provides three items: the function, using our common API, its numerical environment and its name. We next display information on the GPU used.
!nvidia-smi --query-gpu=gpu_name --format=csv
name
NVIDIA GeForce RTX 2080 Ti
NVIDIA GeForce RTX 2080 Ti
def sample_points(rng, n, dim):
rng, *rngs = jax.random.split(rng, 5)
x = jax.random.uniform(rngs[0], (n, dim))
y = (jax.random.normal(rngs[1], (n, dim)) + 0.5) / 5
a = jax.random.uniform(rngs[2], (n,)) + 0.1
b = jax.random.uniform(rngs[3], (n,)) + 0.1
a, b = a / jnp.sum(a), b / jnp.sum(b)
return a, b, x, y
def run_simulation(a, b, x, y, ε, threshold, solver_spec):
    # extract specificities of solver.
    solver, env, name = solver_spec
    # run solver once (this also triggers JIT compilation, when relevant)
    out = solver(a, b, x, y, ε, threshold)
    print(" n_iters:", out[-1], end=" |")
    # record timings
    timeit_res = %timeit -o solver(a, b, x, y, ε, threshold)
    # if the solver did not converge, its objective (out[-2]) is nan: skip timing
    exec_time = np.nan if np.isnan(out[-2]) else timeit_res.average
    return exec_time, out
We define the three solvers used in this experiment: POT, POT with a JAX backbone, and OTT.
POT = (solve_ot, "np", "POT")
POT_jax = (solve_ot, "jax", "POT-jax-backbone")
OTT = (solve_ott, "jax", "OTT")
Run simulations with varying \(n\) and \(\varepsilon\)#
We run simulations for three values of the regularization strength \(\varepsilon\), set below as fractions of the mean cost.
We consider \(n\) ranging between \(2^{9}= 512\) and \(2^{14}= 16384\). We do not go higher, because POT runs into out-of-memory errors for larger sizes. OTT can avoid these by setting the flag batch_size to, e.g., 1024, so that the cost matrix is evaluated in blocks, as also handled by the GeomLoss toolbox (see the sketch below). We leave the comparison with GeomLoss to a future notebook.
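For reference, here is roughly how that online, batched evaluation could be switched on (a sketch: we assume the installed OTT version exposes the batch_size argument of PointCloud, and that a, b, x, y, ε and threshold are defined as in solve_ott above):

# Evaluate cost/kernel applications in blocks of 1024 points instead of
# materializing the full n x n cost matrix on the device.
geom = pointcloud.PointCloud(x, y, epsilon=ε, batch_size=1024)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = sinkhorn.Sinkhorn(threshold=threshold, lse_mode=True)(prob)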
When %timeit outputs execution times, notice the warning message highlighting the fact that, for OTT, at least one run took significantly longer. That run is the one doing the JIT pre-compilation of the procedure for that particular problem size \(n\). Once pre-compiled, subsequent runs are orders of magnitude faster, thanks to the jit() decorator added to solve_ott.
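This is also why run_simulation above calls the solver once before handing it to %timeit. An equivalent way to exclude compilation from a measurement is to run a warm-up call and wait for its result (a sketch; assumes a, b, x, y, ε and threshold_n have been generated as in the loop below):

# Warm-up: the first call compiles solve_ott for this particular problem size.
f, g, reg_ot, n_iter = solve_ott(a, b, x, y, ε, threshold_n)
f.block_until_ready()  # wait until the device computation has actually finished
# Any %timeit measurement taken after this point excludes compilation time.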
solvers = (POT, POT_jax, OTT)
n_range = 2 ** np.arange(9, 15)
In this notebook we set the epsilon regularization using multiples of the mean of the cost matrix. The 3 scales selected below can be seen as very-low, medium and high regularization regimes. Note that the default setting in OTT-JAX uses a slightly different rule (one twentieth of the standard deviation of the entries in the cost matrix).
ε_scales = [0.01, 0.025, 0.05]
dim = 6
exec_time = {}
reg_ot_costs = {}
n_iters = {}
# setting global variables helps avoid a timeit bug.
global a, b, x, y, solver
for solver_spec in solvers:
    solver, env, name = solver_spec
    print("----- ", name)
    exec_time[name] = np.ones((len(n_range), len(ε_scales))) * np.nan
    reg_ot_costs[name] = np.ones((len(n_range), len(ε_scales))) * np.nan
    n_iters[name] = np.ones((len(n_range), len(ε_scales))) * np.nan
    for j, ε_scale in enumerate(ε_scales):
        for i, n in enumerate(n_range):
            rng = jax.random.key(i)
            a, b, x, y = sample_points(rng, n, dim)
            # compute a relevant scale for ε: the mean of the cost matrix
            epsilon_base = pointcloud.PointCloud(x, y).mean_cost_matrix
            ε = epsilon_base * ε_scale
            # map to numpy if needed
            if env == "np":
                a, b, x, y = map(np.array, (a, b, x, y))
            # check dtype consistency across experiments
            assert x.dtype == "float32"
            # Set a threshold that scales with n
            threshold_n = 0.01 / (n**0.33)
            print(
                "n:",
                n,
                ", ε_scale:",
                ε_scale,
                f", ε: {ε:.5f}",
                f", thr.: {threshold_n:.5f}",
                end=" ",
            )
            t, out = run_simulation(a, b, x, y, ε, threshold_n, solver_spec)
            _, _, reg_ot_cost, n_it = out
            exec_time[name][i, j] = t
            reg_ot_costs[name][i, j] = reg_ot_cost
            # Check convergence.
            assert not jnp.isnan(reg_ot_cost)
            n_iters[name][i, j] = n_it
----- POT
n: 512 , ε_scale: 0.01 , ε: 0.01720 , thr.: 0.00128 n_iters: 40 |234 ms ± 5.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 1024 , ε_scale: 0.01 , ε: 0.01741 , thr.: 0.00102 n_iters: 40 |555 ms ± 3.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 2048 , ε_scale: 0.01 , ε: 0.01690 , thr.: 0.00081 n_iters: 40 |2.74 s ± 4.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 4096 , ε_scale: 0.01 , ε: 0.01703 , thr.: 0.00064 n_iters: 40 |13.6 s ± 9.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 8192 , ε_scale: 0.01 , ε: 0.01704 , thr.: 0.00051 n_iters: 40 |52.3 s ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 16384 , ε_scale: 0.01 , ε: 0.01695 , thr.: 0.00041 n_iters: 30 |2min 42s ± 223 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 512 , ε_scale: 0.025 , ε: 0.04299 , thr.: 0.00128 n_iters: 20 |128 ms ± 615 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 1024 , ε_scale: 0.025 , ε: 0.04352 , thr.: 0.00102 n_iters: 20 |238 ms ± 1.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 2048 , ε_scale: 0.025 , ε: 0.04226 , thr.: 0.00081 n_iters: 20 |1.44 s ± 8.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 4096 , ε_scale: 0.025 , ε: 0.04257 , thr.: 0.00064 n_iters: 20 |6.34 s ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 8192 , ε_scale: 0.025 , ε: 0.04260 , thr.: 0.00051 n_iters: 20 |24.3 s ± 31.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 16384 , ε_scale: 0.025 , ε: 0.04238 , thr.: 0.00041 n_iters: 20 |1min 36s ± 225 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 512 , ε_scale: 0.05 , ε: 0.08598 , thr.: 0.00128 n_iters: 10 |82.9 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 1024 , ε_scale: 0.05 , ε: 0.08704 , thr.: 0.00102 n_iters: 10 |160 ms ± 706 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 2048 , ε_scale: 0.05 , ε: 0.08451 , thr.: 0.00081 n_iters: 10 |815 ms ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 4096 , ε_scale: 0.05 , ε: 0.08513 , thr.: 0.00064 n_iters: 10 |3.47 s ± 2.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 8192 , ε_scale: 0.05 , ε: 0.08520 , thr.: 0.00051 n_iters: 10 |13.3 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 16384 , ε_scale: 0.05 , ε: 0.08476 , thr.: 0.00041 n_iters: 10 |52.2 s ± 165 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----- POT-jax-backbone
n: 512 , ε_scale: 0.01 , ε: 0.01720 , thr.: 0.00128 n_iters: 40 |222 ms ± 3.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 1024 , ε_scale: 0.01 , ε: 0.01741 , thr.: 0.00102 n_iters: 40 |220 ms ± 4.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 2048 , ε_scale: 0.01 , ε: 0.01690 , thr.: 0.00081 n_iters: 40 |215 ms ± 4.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 4096 , ε_scale: 0.01 , ε: 0.01703 , thr.: 0.00064 n_iters: 40 |212 ms ± 3.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 8192 , ε_scale: 0.01 , ε: 0.01704 , thr.: 0.00051 n_iters: 40 |374 ms ± 813 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 16384 , ε_scale: 0.01 , ε: 0.01695 , thr.: 0.00041 n_iters: 30 |1.12 s ± 302 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 512 , ε_scale: 0.025 , ε: 0.04299 , thr.: 0.00128 n_iters: 20 |110 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 1024 , ε_scale: 0.025 , ε: 0.04352 , thr.: 0.00102 n_iters: 20 |132 ms ± 9.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 2048 , ε_scale: 0.025 , ε: 0.04226 , thr.: 0.00081 n_iters: 20 |98.5 ms ± 84.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 4096 , ε_scale: 0.025 , ε: 0.04257 , thr.: 0.00064 n_iters: 20 |115 ms ± 926 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 8192 , ε_scale: 0.025 , ε: 0.04260 , thr.: 0.00051 n_iters: 20 |199 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 16384 , ε_scale: 0.025 , ε: 0.04238 , thr.: 0.00041 n_iters: 20 |773 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 512 , ε_scale: 0.05 , ε: 0.08598 , thr.: 0.00128 n_iters: 10 |60.5 ms ± 50.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 1024 , ε_scale: 0.05 , ε: 0.08704 , thr.: 0.00102 n_iters: 10 |64.1 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 2048 , ε_scale: 0.05 , ε: 0.08451 , thr.: 0.00081 n_iters: 10 |65.1 ms ± 823 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 4096 , ε_scale: 0.05 , ε: 0.08513 , thr.: 0.00064 n_iters: 10 |56.7 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 8192 , ε_scale: 0.05 , ε: 0.08520 , thr.: 0.00051 n_iters: 10 |112 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 16384 , ε_scale: 0.05 , ε: 0.08476 , thr.: 0.00041 n_iters: 10 |432 ms ± 284 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
----- OTT
n: 512 , ε_scale: 0.01 , ε: 0.01720 , thr.: 0.00128 n_iters: 40 |3 ms ± 6.35 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 1024 , ε_scale: 0.01 , ε: 0.01741 , thr.: 0.00102 n_iters: 40 |4.38 ms ± 8.52 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 2048 , ε_scale: 0.01 , ε: 0.01690 , thr.: 0.00081 n_iters: 40 |12.2 ms ± 49.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 4096 , ε_scale: 0.01 , ε: 0.01703 , thr.: 0.00064 n_iters: 40 |39.4 ms ± 32.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 8192 , ε_scale: 0.01 , ε: 0.01704 , thr.: 0.00051 n_iters: 40 |151 ms ± 63.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 16384 , ε_scale: 0.01 , ε: 0.01695 , thr.: 0.00041 n_iters: 40 |594 ms ± 78.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 512 , ε_scale: 0.025 , ε: 0.04299 , thr.: 0.00128 n_iters: 20 |1.69 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
n: 1024 , ε_scale: 0.025 , ε: 0.04352 , thr.: 0.00102 n_iters: 20 |2.25 ms ± 5.76 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 2048 , ε_scale: 0.025 , ε: 0.04226 , thr.: 0.00081 n_iters: 20 |6.15 ms ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 4096 , ε_scale: 0.025 , ε: 0.04257 , thr.: 0.00064 n_iters: 20 |19.7 ms ± 4.54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 8192 , ε_scale: 0.025 , ε: 0.04260 , thr.: 0.00051 n_iters: 20 |75.7 ms ± 96.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 16384 , ε_scale: 0.025 , ε: 0.04238 , thr.: 0.00041 n_iters: 20 |299 ms ± 438 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
n: 512 , ε_scale: 0.05 , ε: 0.08598 , thr.: 0.00128 n_iters: 10 |1.1 ms ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
n: 1024 , ε_scale: 0.05 , ε: 0.08704 , thr.: 0.00102 n_iters: 10 |1.33 ms ± 5.89 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
n: 2048 , ε_scale: 0.05 , ε: 0.08451 , thr.: 0.00081 n_iters: 10 |3.16 ms ± 6.73 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 4096 , ε_scale: 0.05 , ε: 0.08513 , thr.: 0.00064 n_iters: 10 |10.1 ms ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n: 8192 , ε_scale: 0.05 , ε: 0.08520 , thr.: 0.00051 n_iters: 10 |37.9 ms ± 39.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
n: 16384 , ε_scale: 0.05 , ε: 0.08476 , thr.: 0.00041 n_iters: 10 |150 ms ± 265 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Plot results: time and objective#
We plot below all 3 runs for each of the 3 solvers. Compared with POT running on a JAX backbone, the speed-up OTT obtains through JIT compilation is most apparent for smaller scale problems. Indeed, for larger scale problems, most of the compute effort is spent on kernel-vector products, which, in this case, are comparably implemented across platforms.
list_legend = []
fig = plt.figure(figsize=(14, 8))
metric = exec_time
metric_name = "Execution time (s)"
for solver_spec, marker, col in zip(
    solvers, ("p", "o", "d"), ("blue", "red", "green")
):
    solver, env, name = solver_spec
    p = plt.plot(
        metric[name],
        marker=marker,
        color=col,
        markersize=16,
        markeredgecolor="k",
        lw=3,
    )
    p[0].set_linestyle("-")
    p[1].set_linestyle("--")
    p[2].set_linestyle(":")
    list_legend += [name + r" $\varepsilon$-scale=" + f"{ε:.2g}" for ε in ε_scales]

plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
plt.legend(list_legend)
plt.yscale("log")
plt.xlabel("number of points $n$")
plt.ylabel(metric_name)
plt.title(
    r"Execution time vs number of points, for OTT and POT, at three $\varepsilon$ scales"
)
plt.show()
meth = "OTT"
def plot_bsl(bsl):
    fig = plt.figure(figsize=(12, 6))
    ax = plt.gca()
    im = ax.imshow(reg_ot_costs[meth].T - reg_ot_costs[bsl].T)
    plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)
    plt.yticks(ticks=np.arange(len(ε_scales)), labels=ε_scales)
    plt.xlabel("number of points $n$")
    plt.ylabel(r"regularization $\varepsilon$ scale")
    title = f"Gap in objective {bsl} / {meth}, >0 when {meth} is better"
    plt.title(title)
    divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    plt.colorbar(im, cax=cax)
    plt.show()
For good measure, we also show the differences in objectives between the two solvers. We subtract the objective returned by POT (and by POT with a JAX backbone) from that returned by OTT. Since the problem is evaluated in its dual form, a higher objective is better, and therefore a positive difference denotes better performance for OTT.
plot_bsl("POT")
plot_bsl("POT-jax-backbone")