Tracking progress of ott.solvers#
This tutorial shows how to track the progress and errors of the following solvers: Sinkhorn, low-rank Sinkhorn (LRSinkhorn), and GromovWasserstein. We’ll see that we simply need to provide a callback function to the solvers.
import sys
if "google.colab" in sys.modules:
%pip install -q git+https://github.com/ott-jax/ott@main
%pip install -q tqdm
import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
How to track progress#
ott offers a simple and flexible mechanism that works well with jit() and applies to both the functional interface and the class interface. The solvers Sinkhorn, low-rank Sinkhorn, and GromovWasserstein only report progress if we pass a callback function with a specific signature. The callback is then called at each iteration using io_callback().
Callback function signature#
The required signature of the callback function is: (status: Tuple[ndarray, ndarray, ndarray, NamedTuple]) -> None. The arguments are:
status: a tuple of:
the current iteration index (0-based),
the number of inner iterations after which the error is computed,
the total number of iterations, and
the current solver state: SinkhornState, LRSinkhornState, or GWState. For technical reasons, the type of this argument in the signature is simply NamedTuple (the common super-type).
A minimal custom callback with this signature is sketched below.
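For illustration, here is a minimal sketch of such a callback, in the spirit of the default one shipped in ott.utils. The name print_progress is ours, and the sketch assumes that the solver state exposes an errors array holding the error values computed so far:
from typing import NamedTuple, Tuple

import numpy as np


def print_progress(
    status: Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple],
) -> None:
    """Sketch of a progress callback: print the most recent error."""
    iteration, inner_iterations, total_iter, state = status
    # The values arrive as arrays because the callback runs on the host via
    # io_callback(); convert them to plain Python ints.
    iteration = int(iteration) + 1  # the iteration index is 0-based
    inner_iterations = int(inner_iterations)
    total_iter = int(total_iter)
    # The error is only (re)computed every `inner_iterations` iterations,
    # so only report when a fresh value is available.
    if iteration % inner_iterations == 0:
        errors = np.asarray(state.errors).ravel()
        error = errors[iteration // inner_iterations - 1]
        print(f"{iteration} / {total_iter} -- {error}")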
Tracking progress of Sinkhorn solvers#
Let’s start with the Sinkhorn solver without any tracking (the default behavior):
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5
geom = pointcloud.PointCloud(x, y)
This problem is very simple, so the Sinkhorn solver converges after only 7 iterations.
solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom)
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 7, cost: 1.2429015636444092
For small problems such as this one, it’s fine not to track progress (the default behavior). However, when tackling larger problems, we might want to track various values that the Sinkhorn solver updates at each iteration.
Tracking progress of Sinkhorn via the functional interface#
Here are a few examples of how to track progress for Sinkhorn and low-rank Sinkhorn.
With the default callback function#
ott.utils provides a default_progress_fn(), which returns a callback function that simply prints the current iteration and the error. Let’s pass this basic callback as a static argument to solve():
solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
progress_fn = utils.default_progress_fn()
ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
10 / 2000 -- 0.049124784767627716
20 / 2000 -- 0.019962385296821594
30 / 2000 -- 0.00910455733537674
40 / 2000 -- 0.004339158535003662
50 / 2000 -- 0.002111591398715973
60 / 2000 -- 0.001037590205669403
70 / 2000 -- 0.0005124583840370178
Converged: True, #iters: 7, cost: 1.2429015636444092
This shows that the solver reports its metrics every 10 inner iterations (the default value).
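The reporting interval is not fixed: it follows the number of inner iterations after which the solver recomputes its error. As a sketch, assuming this is controlled by the inner_iterations parameter of the Sinkhorn class (whose interface is shown below), reporting every 5 iterations instead would look like:
# Sketch: report every 5 iterations by lowering `inner_iterations`,
# assumed to be the interval at which Sinkhorn recomputes its error.
solver = sinkhorn.Sinkhorn(
    inner_iterations=5, progress_fn=utils.default_progress_fn()
)
ot = jax.jit(solver)(linear_problem.LinearProblem(geom))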
With tqdm#
ott.utils also implements a tqdm_progress_fn(), which returns a callback that updates a tqdm progress bar instead of just printing to the console.
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)
4%|██████▌ | 7/200 [00:00<00:08, 23.28it/s, error: 5.124584e-04]
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 7, cost: 1.2429015636444092
Tracking progress of Sinkhorn via the class interface#
Alternatively, we can provide the callback function to the Sinkhorn class and display the progress with tqdm. Let’s define a LinearProblem and run the solver:
prob = linear_problem.LinearProblem(geom)
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
4%|██████▌ | 7/200 [00:00<00:08, 23.53it/s, error: 5.124584e-04]
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 7, cost: 1.2429015636444092
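The callback does not have to print or draw a progress bar: since it is executed on the host, it can just as well record values for later analysis. Here is a small sketch, with a hypothetical record_errors callback that appends each freshly computed error to a Python list:
errors = []  # filled in by the callback while the solver runs


def record_errors(status):
    """Sketch: append each freshly computed error to a Python list."""
    iteration, inner_iterations, total_iter, state = status
    iteration, inner_iterations = int(iteration) + 1, int(inner_iterations)
    if iteration % inner_iterations == 0:
        all_errors = np.asarray(state.errors).ravel()
        errors.append(float(all_errors[iteration // inner_iterations - 1]))


solver = sinkhorn.Sinkhorn(progress_fn=record_errors)
ot = jax.jit(solver)(prob)
print(errors)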
Tracking progress of low-rank Sinkhorn via the class interface#
We can also track progress of the low-rank Sinkhorn solver. Because it currently doesn’t have a functional interface, we can only use the LRSinkhorn class interface:
prob = linear_problem.LinearProblem(geom)
rank = 2
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
8%|██████████████▉ | 16/200 [00:00<00:07, 23.11it/s, error: 3.191899e-04]
print(f"Converged: {ot.converged}, cost: {ot.reg_ot_cost}")
Converged: True, cost: 1.7340879440307617
Tracking progress of the Gromov-Wasserstein solver#
We can track progress of the GromovWasserstein solver in the same way as with the Sinkhorn solvers. Let’s define a small QuadraticProblem, the same as in the Gromov-Wasserstein notebook:
# Samples spiral
def sample_spiral(
    n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0
):
    radius = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    data = []
    noise = jax.random.normal(key, (2, n)) * noise
    for i in range(n):
        x = (radius[i] + noise[0, i]) * jnp.cos(angles[i])
        y = (radius[i] + noise[1, i]) * jnp.sin(angles[i])
        data.append([x, y])
    data = jnp.array(data)
    return data
# Samples Swiss roll
def sample_swiss_roll(
    n, min_radius, max_radius, length, key, min_angle=0, max_angle=10, noise=0.1
):
    spiral = sample_spiral(
        n, min_radius, max_radius, key[0], min_angle, max_angle, noise
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))
    return swiss_roll
# Data parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
min_radius = 3
max_radius = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3
# Seed
seed = 14
key = jax.random.PRNGKey(seed)
key, *subkey = jax.random.split(key, 4)
spiral = sample_spiral(
    n_spiral,
    min_radius,
    max_radius,
    key=subkey[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    min_radius,
    max_radius,
    key=subkey[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)
We can now track the progress while the GromovWasserstein solver iterates:
geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)
solver = gromov_wasserstein.GromovWasserstein(
    epsilon=100.0,
    max_iterations=20,
    store_inner_errors=True,  # needed for reporting errors
    progress_fn=utils.default_progress_fn(),  # callback function
)
out = solver(prob)
n_outer_iterations = jnp.sum(out.costs != -1)
has_converged = bool(out.linear_convergence[n_outer_iterations - 1])
print(f"\n{n_outer_iterations} outer iterations were needed")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")
1 / 20 -- -1.0
2 / 20 -- 0.1304362416267395
3 / 20 -- 0.0898154005408287
4 / 20 -- 0.06759566068649292
5 / 20 -- 0.05465700849890709
5 outer iterations were needed
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.613
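Since store_inner_errors=True was set, the output also keeps the inner Sinkhorn errors of each outer iteration, so convergence can be inspected after the run. A small sketch, assuming these errors are stored in out.errors with -1 used as padding:
inner_errors = out.errors  # one row of inner Sinkhorn errors per outer iteration
for i, row in enumerate(inner_errors[: int(n_outer_iterations)]):
    row = row[row > -1]  # drop the -1 padding
    if row.size > 0:
        print(f"outer iteration {i + 1}: last inner error {row[-1]:.2e}")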
That’s it, this is how to track progress of the Sinkhorn, low-rank Sinkhorn, and Gromov-Wasserstein solvers!