# Tracking progress of ott.solvers

This tutorial shows how to track progress and errors of the following solvers:

• Sinkhorn
• LRSinkhorn
• GromovWasserstein

We’ll see that we simply need to provide a callback function to the solvers.

```python
import sys

%pip install -q git+https://github.com/ott-jax/ott@main
%pip install -q tqdm

import tqdm

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
```


## How to track progress

ott offers a simple and flexible mechanism that works well with jit(), and applies to both the functional interface and the class interface.

The solvers Sinkhorn, low-rank Sinkhorn, and GromovWasserstein only report progress if we pass a callback function with a specific signature. The callback is then called at each iteration using io_callback().
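The mechanism itself can be illustrated without ott. The sketch below is not ott's implementation. It runs a jitted jax.lax.fori_loop and calls a host-side Python function at every step through jax.experimental.io_callback. The function record, the list calls and the halving loop are hypothetical, chosen only to make the calls visible.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

# Host-side list filled by the callback (hypothetical example).
calls = []


def record(status):
    # `status` arrives on the host as a tuple of NumPy arrays.
    iteration, value = status
    calls.append((int(iteration), float(value)))


@jax.jit
def run(x):
    def body(i, x):
        x = x / 2.0
        # result_shape_dtypes=None: the callback returns nothing.
        io_callback(record, None, (i, x))
        return x

    return jax.lax.fori_loop(0, 5, body, x)


out = run(jnp.float32(32.0))
print(sorted(calls))
```

The loop stays fully compiled. Only the callback leaves the device, which is why ott can report progress from inside jit().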

### Callback function signature

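The required signature of the callback function is: (status: Tuple[ndarray, ndarray, ndarray, NamedTuple]) -> None. The arguments are:

• status: a tuple of:
  • the current iteration index (0-based),
  • the number of inner iterations after which the error is computed,
  • the total number of iterations, and
  • the current solver state (SinkhornState, LRSinkhornState or GWState). For technical reasons, its type in the signature is simply NamedTuple, the common supertype.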
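As a minimal sketch of that signature, the callback below unpacks status and prints a line every inner_iterations iterations, in the spirit of the default callback used later. The outputs in this tutorial suggest that the state exposes an errors array with one entry per block of inner_iterations, and the sketch assumes this. The names format_status, my_progress_fn and FakeState are hypothetical. The stand-in state only exists to exercise the function without running a solver.

```python
from typing import NamedTuple

import numpy as np


def format_status(status):
    """Build the message for one callback call, or None when nothing is due."""
    # Assumed layout: (iteration, inner_iterations, max_iterations, state).
    iteration, inner_iterations, max_iterations, state = status
    iteration = int(iteration) + 1  # 1-based iteration number
    inner_iterations = int(inner_iterations)
    # Errors are assumed to be stored once per block of `inner_iterations`.
    if iteration % inner_iterations != 0:
        return None
    errors = np.asarray(state.errors).ravel()
    error = errors[iteration // inner_iterations - 1]
    return f"iter {iteration} of {int(max_iterations)}: error={error:.3e}"


def my_progress_fn(status, *args):
    """Callback following the signature described above; returns None."""
    message = format_status(status)
    if message is not None:
        print(message)


# Hypothetical stand-in for a solver state, only to try the callback out.
class FakeState(NamedTuple):
    errors: np.ndarray


state = FakeState(errors=np.array([0.05, 0.02, 0.009]))
for it in range(30):
    my_progress_fn((np.array(it), np.array(10), np.array(2000), state))
```

Splitting the formatting from the printing keeps the callback itself trivial and makes the reporting logic easy to check in isolation.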

## Tracking progress of Sinkhorn solvers

Let’s start with the Sinkhorn solver without any tracking (the default behavior):

```python
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
# Use one key per point cloud.
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5
geom = pointcloud.PointCloud(x, y)
```


This problem is very simple, so the Sinkhorn solver converges after only 7 iterations.

```python
solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom)

print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
```

```
Converged: True, #iters: 7, cost: 1.2429015636444092
```


For small problems such as this one, it’s fine not to track progress (the default behavior). However, when tackling larger problems, we might want to track various values that the Sinkhorn solver updates at each iteration.
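For example, a callback can accumulate the reported errors instead of printing them, so they can be plotted after the solver returns (e.g. with the matplotlib import above). The factory below is a hypothetical sketch. It assumes the status layout (iteration, inner_iterations, max_iterations, state) and a state with an errors array refreshed once per block of inner_iterations. The history lives in a plain Python list captured by the closure.

```python
import numpy as np


def make_history_fn():
    """Return (callback, history); `history` collects (iteration, error) pairs."""
    history = []

    def progress_fn(status, *args):
        # Assumed layout: (iteration, inner_iterations, max_iterations, state).
        iteration, inner_iterations, _, state = status
        iteration = int(iteration) + 1
        inner_iterations = int(inner_iterations)
        # Errors are assumed to be refreshed once per block of `inner_iterations`.
        if iteration % inner_iterations == 0:
            errors = np.asarray(state.errors).ravel()
            error = float(errors[iteration // inner_iterations - 1])
            history.append((iteration, error))

    return progress_fn, history
```

Usage should mirror the functional examples that follow: create progress_fn, history = make_history_fn(), pass progress_fn to a solve() jitted with static_argnames=["progress_fn"], then plot with plt.semilogy(*zip(*history)).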

### Tracking progress of Sinkhorn via the functional interface

Here are a few examples of how to track the progress of Sinkhorn through its functional interface, solve().

#### With the default callback function

ott.utils provides a default_progress_fn(), which returns a callback function that simply prints the current iteration and the error. Let’s pass this basic callback as a static argument to solve():

```python
solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
progress_fn = utils.default_progress_fn()
ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)

print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
```

```
10 / 2000 -- 0.049124784767627716
20 / 2000 -- 0.019962385296821594
30 / 2000 -- 0.00910455733537674
40 / 2000 -- 0.004339158535003662
50 / 2000 -- 0.002111591398715973
60 / 2000 -- 0.001037590205669403
70 / 2000 -- 0.0005124583840370178
Converged: True, #iters: 7, cost: 1.2429015636444092
```


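This shows that the default callback reports the error every 10 iterations (the default value of the solver’s inner_iterations). Note that the last report is at iteration 70 while ot.n_iters is 7. Judging from these outputs, n_iters counts error evaluations (one per block of 10 iterations) rather than individual Sinkhorn updates.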
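The numbers in these logs follow from simple arithmetic. The check below uses max_iterations=2000 and inner_iterations=10, values read off the outputs in this notebook. It assumes, as those outputs suggest, that ot.n_iters counts reporting steps rather than single updates.

```python
max_iterations = 2000  # the "/ 2000" in the log above
inner_iterations = 10  # reporting period (default value)

# 1-based iteration numbers at which the default callback prints.
reported = [
    i for i in range(1, max_iterations + 1) if i % inner_iterations == 0
]
# Number of reporting steps; matches the "/200" totals of the tqdm bars.
n_reports = max_iterations // inner_iterations

# Assumption: ot.n_iters (7 above) counts reporting steps, not updates.
n_iters = 7
last_reported_iteration = n_iters * inner_iterations

print(reported[:3], n_reports, last_reported_iteration)
```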

#### With tqdm

ott.utils also implements a tqdm_progress_fn(), which returns a callback that updates a tqdm progress bar instead of printing to the console.
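To see what such a callback has to do, here is a hand-written sketch. It is not ott’s implementation of tqdm_progress_fn. It advances a tqdm-like bar by one step per reported error and shows the error in the postfix. It relies only on tqdm’s update(n), set_postfix_str(s) and the total attribute. The factory name make_bar_fn, the status layout (iteration, inner_iterations, max_iterations, state) and the errors field are assumptions.

```python
import numpy as np


def make_bar_fn(pbar):
    """Return a progress callback that drives a tqdm-like `pbar` (sketch)."""

    def progress_fn(status, *args):
        # Assumed layout: (iteration, inner_iterations, max_iterations, state).
        iteration, inner_iterations, max_iterations, state = status
        iteration = int(iteration) + 1
        inner_iterations = int(inner_iterations)
        if iteration % inner_iterations != 0:
            return
        # One bar step per reported error, e.g. 2000 // 10 = 200 steps.
        pbar.total = int(max_iterations) // inner_iterations
        errors = np.asarray(state.errors).ravel()
        error = errors[iteration // inner_iterations - 1]
        pbar.set_postfix_str(f"error: {error:0.6e}")
        pbar.update(1)

    return progress_fn
```

Because the bar is only touched through those three members, the same callback works with any object that provides them. In a notebook, the factory would be called inside a with tqdm.tqdm() as pbar: block, as in the example that follows.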

```python
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)
```

```
  4%|█▌                                        | 7/200 [00:00<00:08, 23.28it/s, error: 5.124584e-04]
```

```python
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
```

```
Converged: True, #iters: 7, cost: 1.2429015636444092
```


### Tracking progress of Sinkhorn via the class interface

Alternatively, we can provide the callback function to the Sinkhorn class and display the progress with tqdm. Let’s define a LinearProblem and run the solver:

```python
prob = linear_problem.LinearProblem(geom)

with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
```

```
  4%|█▌                                        | 7/200 [00:00<00:08, 23.53it/s, error: 5.124584e-04]
```

```python
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
```

```
Converged: True, #iters: 7, cost: 1.2429015636444092
```


### Tracking progress of low-rank Sinkhorn via the class interface

We can also track progress of the low-rank Sinkhorn solver. Because it currently doesn’t have a functional interface, we can only use the LRSinkhorn class interface:

```python
prob = linear_problem.LinearProblem(geom)
rank = 2

with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
```

```
  8%|███▎                                      | 16/200 [00:00<00:07, 23.11it/s, error: 3.191899e-04]
```

```python
print(f"Converged: {ot.converged}, cost: {ot.reg_ot_cost}")
```

```
Converged: True, cost: 1.7340879440307617
```


## Tracking progress of the Gromov-Wasserstein solver

We can track progress of the GromovWasserstein solver in the same way as with the Sinkhorn solvers. Let’s define a small QuadraticProblem, the same one as in the Gromov-Wasserstein notebook, starting with the spiral and Swiss roll data:

```python
# Samples spiral
# NOTE: signatures reconstructed from the call sites below; the radius range
# (min_radius, max_radius) and the default values are assumptions.
def sample_spiral(
    n, key, min_angle=0, max_angle=10, noise=1.0, min_radius=3, max_radius=10
):
    radius = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    data = []
    noise = jax.random.normal(key, (2, n)) * noise
    for i in range(n):
        x = (radius[i] + noise[0, i]) * jnp.cos(angles[i])
        y = (radius[i] + noise[1, i]) * jnp.sin(angles[i])
        data.append([x, y])
    data = jnp.array(data)
    return data


# Samples Swiss roll
def sample_swiss_roll(n, key, length, min_angle=0, max_angle=10, noise=0.1):
    # `key` holds two PRNG keys: one for the spiral, one for the third axis.
    spiral = sample_spiral(
        n, key[0], min_angle=min_angle, max_angle=max_angle, noise=noise
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))
    return swiss_roll


# Data parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3

# Seed
seed = 14
key = jax.random.PRNGKey(seed)
key, *subkey = jax.random.split(key, 4)

spiral = sample_spiral(
    n_spiral,
    key=subkey[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    key=subkey[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)
```


We can now track the progress while the GromovWasserstein solver iterates:

```python
geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

solver = gromov_wasserstein.GromovWasserstein(
    epsilon=100.0,
    max_iterations=20,
    store_inner_errors=True,  # needed for reporting errors
    progress_fn=utils.default_progress_fn(),  # callback function
)
out = solver(prob)

n_outer_iterations = jnp.sum(out.costs != -1)
has_converged = bool(out.linear_convergence[n_outer_iterations - 1])

print(f"\n{n_outer_iterations} outer iterations were needed")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")
```

```
1 / 20 -- -1.0
2 / 20 -- 0.1304362416267395
3 / 20 -- 0.0898154005408287
4 / 20 -- 0.06759566068649292
5 / 20 -- 0.05465700849890709

5 outer iterations were needed
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.613
```
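The n_outer_iterations line above relies on unused entries of out.costs being filled with -1. Counting the entries that differ from -1 therefore gives the number of outer iterations actually run, and n_outer_iterations - 1 indexes the last one, as done for out.linear_convergence. The NumPy snippet below reproduces that counting on made-up toy values (not solver output) for an array of length max_iterations=20.

```python
import numpy as np

max_iterations = 20
# Hypothetical cost trace: 5 outer iterations ran, the rest is -1 padding.
costs = np.full(max_iterations, -1.0)
costs[:5] = [5.0, 4.0, 3.5, 3.2, 3.1]

n_outer_iterations = int(np.sum(costs != -1))
last_cost = costs[n_outer_iterations - 1]  # cost of the final outer iteration
print(n_outer_iterations, last_cost)
```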


That’s it! This is how to track the progress of the Sinkhorn, low-rank Sinkhorn, and Gromov-Wasserstein solvers.