Tracking progress of ott.solvers

This tutorial shows how to track progress and errors of the following solvers:

  • Sinkhorn,

  • LRSinkhorn (low-rank Sinkhorn), and

  • GromovWasserstein.

We’ll see that we simply need to provide a callback function to the solvers.

import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main
    %pip install -q tqdm
import tqdm

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein

How to track progress

ott offers a simple and flexible mechanism that works well with jit() and applies to both the functional and the class interfaces.

The solvers Sinkhorn, low-rank Sinkhorn, and GromovWasserstein only report progress if we pass a callback function with a specific signature. The callback is then called at each iteration using callback().
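To make the mechanism concrete, here is a toy sketch in plain JAX. It is not OTT's implementation. It assumes that the callback() mentioned above refers to jax.debug.callback, and the names toy_solver and make_recorder are hypothetical. The callback is passed to the jitted function as a static argument and is invoked from inside a compiled loop. On the host, it receives concrete NumPy values rather than tracers:

```python
from typing import Callable, List, Tuple

import jax
import jax.numpy as jnp


def make_recorder(history: List[int]) -> Callable[[Tuple], None]:
    """Return a host-side callback that records each iteration index."""

    def progress_fn(status: Tuple) -> None:
        iteration, _state = status  # concrete NumPy values, not tracers
        history.append(int(iteration))

    return progress_fn


def toy_solver(x, n_iters, progress_fn):
    """Hypothetical iterative solver that halves `x` at every iteration."""

    def body(i, x):
        x = 0.5 * x
        # Call back into Python from inside the compiled loop.
        jax.debug.callback(progress_fn, (i, x))
        return x

    return jax.lax.fori_loop(0, n_iters, body, x)


# The callback must be a static argument of jit; n_iters is made static too
# so that the loop length is fixed at compile time.
solve_fn = jax.jit(toy_solver, static_argnames=["n_iters", "progress_fn"])

history: List[int] = []
out = solve_fn(jnp.ones(3), n_iters=5, progress_fn=make_recorder(history))
out.block_until_ready()
jax.effects_barrier()  # wait until all pending callbacks have run
```

Because the callback runs on the host, it can use ordinary Python side effects, such as appending to a list or printing.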

Callback function signature

The required signature of the callback function is: (status: Tuple[ndarray, ndarray, ndarray, NamedTuple]) -> None. It takes a single argument:

  • status: a tuple of:

    • the current iteration index (0-based),

    • the number of inner iterations after which the error is computed,

    • the total number of iterations, and

    • the current solver state: SinkhornState, or LRSinkhornState, or GWState. For technical reasons, the type of this argument in the signature is simply NamedTuple (the common super-type).
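Any function with this signature can serve as a callback. The sketch below is a hypothetical example: history_progress_fn and ToyState are illustrative names, not part of OTT. Instead of printing, it records (iteration, error) pairs. It assumes, as the outputs later in this tutorial suggest, that state.errors holds one entry per error evaluation, that is, one every inner_iterations iterations. Here, a fake solver loop calls the callback directly:

```python
from typing import Callable, List, NamedTuple, Tuple

import numpy as np

Status = Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple]


class ToyState(NamedTuple):
    """Hypothetical stand-in for a solver state; only `errors` is used here."""

    errors: np.ndarray


def history_progress_fn(
    history: List[Tuple[int, float]],
) -> Callable[[Status], None]:
    """Return a callback that appends (iterations done, error) to `history`."""

    def progress_fn(status: Status) -> None:
        iteration, inner_iterations, _total, state = status
        count = int(iteration) + 1  # 0-based index -> iterations done so far
        inner_iterations = int(inner_iterations)
        # An error value only exists every `inner_iterations` iterations.
        if count % inner_iterations == 0:
            errors = np.asarray(state.errors).ravel()
            history.append((count, float(errors[count // inner_iterations - 1])))

    return progress_fn


# Simulate a solver loop: 30 iterations, error computed every 10 of them.
history: List[Tuple[int, float]] = []
fn = history_progress_fn(history)
state = ToyState(errors=np.array([0.5, 0.25, 0.125]))  # made-up errors
for it in range(30):
    fn((np.array(it), np.array(10), np.array(30), state))
```

With a real solver, one would pass history_progress_fn(history) as progress_fn, for example sinkhorn.Sinkhorn(progress_fn=...), in the same way as the built-in callbacks used in this tutorial.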

Tracking progress of Sinkhorn solvers

Let’s start with the Sinkhorn solver without any tracking (the default behavior):

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5
geom = pointcloud.PointCloud(x, y)

This problem is very simple, so the Sinkhorn solver converges after only 70 iterations.

solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom)

print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 70, cost: 1.2429015636444092

For small problems such as this one, it’s fine not to track progress (the default behavior). However, when tackling larger problems, we might want to track various values that the Sinkhorn solver updates at each iteration.

Tracking progress of Sinkhorn via the functional interface

Here are a few examples of how to track the progress of Sinkhorn via its functional interface.

With the default callback function

ott.utils provides a default_progress_fn(), which returns a callback function that simply prints the current iteration and the error. Let’s pass this basic callback as a static argument to solve():

solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
progress_fn = utils.default_progress_fn()
ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)

print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
10 / 2000 -- 0.04912472516298294
20 / 2000 -- 0.019962534308433533
30 / 2000 -- 0.009104534983634949
40 / 2000 -- 0.004339255392551422
50 / 2000 -- 0.0021116361021995544
60 / 2000 -- 0.001037605106830597
70 / 2000 -- 0.0005124807357788086
Converged: True, #iters: 70, cost: 1.2429015636444092

This shows that the solver reports its error every 10 iterations, which is the default value of inner_iterations.
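Why must progress_fn be passed as a static argument? jax.jit traces its arguments as arrays, and a Python function is not a valid JAX type, so a callback has to be marked static. Here is a minimal illustration, using a hypothetical helper apply that is unrelated to OTT:

```python
from typing import Callable

import jax
import jax.numpy as jnp


def apply(x: jnp.ndarray, fn: Callable) -> jnp.ndarray:
    """Hypothetical helper that applies a Python function `fn` to `x`."""
    return fn(x)


def square(v):
    return v * v


x = jnp.arange(3.0)

# With `fn` marked static, jit treats the function as a compile-time constant.
static_out = jax.jit(apply, static_argnames=["fn"])(x, fn=square)

# Without it, jit tries to trace `fn` like an array and raises a TypeError.
try:
    jax.jit(apply)(x, square)
    failed = False
except TypeError:
    failed = True
```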

With tqdm

ott.utils also implements a tqdm_progress_fn() which returns a callback that can update a tqdm progress bar instead of just printing to the console.

with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)
  4%|▎         | 7/200 [00:00<00:22,  8.57it/s, error: 5.124807e-04]
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 70, cost: 1.2429015636444092
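The built-in tqdm_progress_fn() is the simplest option, but a callback of the same shape can drive any progress-bar-like object. Below is a hypothetical sketch. bar_progress_fn, RecordingBar, and ToyState are illustrative names, not OTT API, and this is not OTT's implementation. The sketch only assumes that the bar exposes tqdm's total, update(), and set_postfix_str(), and it uses a stand-in bar so it runs without tqdm. Advancing the bar once per error evaluation gives a total of 2000 // 10 = 200 steps, which is consistent with the 7/200 shown in the output above:

```python
from typing import Any, Callable, NamedTuple, Optional, Tuple

import numpy as np


def bar_progress_fn(pbar: Any, fmt: str = "error: {error:0.6e}") -> Callable[[Tuple], None]:
    """Hypothetical tqdm-style callback: advance `pbar` once per error evaluation."""

    def progress_fn(status: Tuple) -> None:
        iteration, inner_iterations, total_iter, state = status
        count = int(iteration) + 1  # 0-based index -> iterations done so far
        inner_iterations = int(inner_iterations)
        if count % inner_iterations == 0:
            errors = np.asarray(state.errors).ravel()
            error = errors[count // inner_iterations - 1]
            # The bar counts error evaluations, not individual iterations.
            pbar.total = int(total_iter) // inner_iterations
            pbar.set_postfix_str(fmt.format(error=error))
            pbar.update()

    return progress_fn


class RecordingBar:
    """Stand-in for tqdm.tqdm exposing only `total`, `update` and `set_postfix_str`."""

    def __init__(self) -> None:
        self.n = 0
        self.total: Optional[int] = None
        self.postfix = ""

    def update(self, n: int = 1) -> None:
        self.n += n

    def set_postfix_str(self, s: str) -> None:
        self.postfix = s


class ToyState(NamedTuple):
    errors: np.ndarray  # hypothetical: one error per evaluation


# Simulate 70 iterations with inner_iterations=10 and max_iterations=2000.
bar = RecordingBar()
fn = bar_progress_fn(bar)
state = ToyState(errors=np.full(200, 5.124807e-04))  # made-up error values
for it in range(70):
    fn((np.array(it), np.array(10), np.array(2000), state))
```

With tqdm installed, a real tqdm.tqdm() bar could be passed to bar_progress_fn in place of RecordingBar.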

Tracking progress of Sinkhorn via the class interface

Alternatively, we can provide the callback function to the Sinkhorn class and display the progress with tqdm. Let’s define a LinearProblem and run the solver:

prob = linear_problem.LinearProblem(geom)

with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
  4%|▎         | 7/200 [00:00<00:23,  8.27it/s, error: 5.124807e-04]
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 70, cost: 1.2429015636444092

Tracking progress of low-rank Sinkhorn via the class interface

We can also track progress of the low-rank Sinkhorn solver. Because it currently doesn’t have a functional interface, we can only use the LRSinkhorn class interface:

prob = linear_problem.LinearProblem(geom)
rank = 2

with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
  8%|▊         | 16/200 [00:02<00:25,  7.10it/s, error: 3.223309e-04]
print(f"Converged: {ot.converged}, cost: {ot.reg_ot_cost}")
Converged: True, cost: 1.7340872287750244

Tracking progress of the Gromov-Wasserstein solver

We can track progress of the GromovWasserstein solver in the same way as with the Sinkhorn solvers. Let’s define a small QuadraticProblem, same as in the Gromov-Wasserstein notebook:

# Samples spiral
def sample_spiral(
    n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0
):
    radius = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    noise = jax.random.normal(key, (2, n)) * noise
    # Compute all n points at once instead of looping in Python.
    x = (radius + noise[0]) * jnp.cos(angles)
    y = (radius + noise[1]) * jnp.sin(angles)
    return jnp.stack([x, y], axis=1)  # shape (n, 2)


# Samples Swiss roll
def sample_swiss_roll(
    n, min_radius, max_radius, length, key, min_angle=0, max_angle=10, noise=0.1
):
    spiral = sample_spiral(
        n, min_radius, max_radius, key[0], min_angle, max_angle, noise
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))
    return swiss_roll


# Data parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
min_radius = 3
max_radius = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3

# Seed
seed = 14
key = jax.random.PRNGKey(seed)
key, *subkey = jax.random.split(key, 4)
spiral = sample_spiral(
    n_spiral,
    min_radius,
    max_radius,
    key=subkey[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    min_radius,
    max_radius,
    key=subkey[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)

We can now track the progress while the GromovWasserstein solver iterates:

geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

solver = gromov_wasserstein.GromovWasserstein(
    epsilon=100.0,
    max_iterations=20,
    store_inner_errors=True,  # needed for reporting errors
    progress_fn=utils.default_progress_fn(),  # callback function
)
out = solver(prob)

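# Entries of `out.costs` for outer iterations that did not run are -1.
n_outer_iterations = jnp.sum(out.costs != -1)
# Whether the inner linear solver converged at the last outer iteration.
has_converged = bool(out.linear_convergence[n_outer_iterations - 1])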

print(f"\n{n_outer_iterations} outer iterations were needed")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")
1 / 20 -- -1.0
2 / 20 -- 0.13043621182441711
3 / 20 -- 0.08981533348560333
4 / 20 -- 0.06759564578533173
5 / 20 -- 0.0546572208404541

5 outer iterations were needed
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.617
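The count n_outer_iterations above relies on two assumptions: the unused entries of out.costs are set to -1, and the filled entries come first. The same counting can be applied to a hypothetical padded array. The values below are made up for illustration and were not produced by the solver:

```python
import numpy as np

# Made-up costs for max_iterations=8 where only 5 outer iterations ran;
# the remaining slots hold the sentinel value -1.
costs = np.array([1250.0, 1210.0, 1195.0, 1188.0, 1184.0, -1.0, -1.0, -1.0])

n_outer_iterations = int(np.sum(costs != -1))  # number of filled slots
last_cost = costs[n_outer_iterations - 1]      # cost at the final outer iteration
```

This approach assumes that a genuine cost is never exactly -1.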

That’s it: this is how to track the progress of the Sinkhorn, low-rank Sinkhorn, and Gromov-Wasserstein solvers!