Tracking progress of ott.solvers
This tutorial shows how to track the progress and errors of the following solvers: Sinkhorn, low-rank Sinkhorn (LRSinkhorn), and GromovWasserstein.
We’ll see that we simply need to provide a callback function to the solvers.
import tqdm
import jax
import jax.numpy as jnp
from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers import linear
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
How to track progress
ott offers a simple and flexible mechanism that works well with jit() and applies to both the functional interface and the class interface.
The solvers Sinkhorn, low-rank Sinkhorn, and GromovWasserstein only report progress if we pass a callback function with a specific signature. This callback is then called at each iteration using callback().
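The standalone sketch below assumes that callback() here refers to jax.debug.callback. The text only names callback(), so this is an assumption. The sketch shows why such a mechanism is compatible with jit(): a plain Python function receives concrete values at every iteration of a compiled jax.lax.fori_loop. The names record, halve_three_times, and seen are hypothetical and are not part of ott.

```python
import jax
import jax.numpy as jnp

# Host-side log that the callback appends to (hypothetical helper state).
seen = []


def record(i, x):
    # Runs on the host with concrete values, even though the loop
    # below is traced and compiled by jax.jit.
    seen.append((int(i), float(x)))


@jax.jit
def halve_three_times(x0):
    def body(i, x):
        x = 0.5 * x
        # ordered=True asks JAX to run the callbacks in program order.
        jax.debug.callback(record, i, x, ordered=True)
        return x

    return jax.lax.fori_loop(0, 3, body, x0)


result = halve_three_times(jnp.float32(8.0))
jax.effects_barrier()  # block until all pending callbacks have run
print(seen)  # [(0, 4.0), (1, 2.0), (2, 1.0)]
```

The traced loop body never calls record during compilation. The three calls happen when the compiled function runs, which is why a progress callback can observe every iteration of a jitted solver.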
Callback function signature
The required signature of the callback function is (status: Tuple[ndarray, ndarray, ndarray, NamedTuple]) -> None. Its only argument, status, is a tuple of:
the current iteration index (0-based),
the number of inner iterations after which the error is computed,
the total (maximum) number of iterations, and
the current solver state: SinkhornState, LRSinkhornState, or GWState. For technical reasons, the type of this element in the signature is simply NamedTuple (the common super-type).
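Here is a minimal sketch of this signature. The callback factory below prints a line in the same "iteration / max -- error" format as the outputs shown in this tutorial. Everything in it is hypothetical:
- make_print_progress_fn is not an ott function.
- FakeState stands in for a real solver state.
- Indexing state.errors with one entry per block of inner iterations is an assumption about the state layout, not documented behavior.

The loop at the end simulates the status tuples a solver would pass.

```python
import io
from typing import NamedTuple, Optional, TextIO, Tuple

import numpy as np


class FakeState(NamedTuple):
    # Hypothetical stand-in for a solver state: only an `errors` field,
    # assumed to hold one error value per block of inner iterations.
    errors: np.ndarray


def make_print_progress_fn(stream: Optional[TextIO] = None):
    """Return a callback with the documented signature (sketch)."""

    def progress_fn(
        status: Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple]
    ) -> None:
        iteration, inner_iterations, total_iter, state = status
        it = int(iteration) + 1  # 0-based index -> 1-based count
        inner = int(inner_iterations)
        # Errors are only computed every `inner` iterations; report only then.
        if it % inner == 0:
            error = np.asarray(state.errors).ravel()[it // inner - 1]
            print(f"{it} / {int(total_iter)} -- {error}", file=stream)

    return progress_fn


# Simulate a solver calling the callback for 30 iterations,
# with an error computed every 10 iterations.
buf = io.StringIO()
progress_fn = make_print_progress_fn(stream=buf)
state = FakeState(errors=np.array([0.5, 0.25, 0.125]))
for i in range(30):
    progress_fn((np.array(i), np.array(10), np.array(2000), state))
print(buf.getvalue(), end="")
```

The callback is invoked at every iteration but only prints three lines (at iterations 10, 20, and 30), one per block of inner iterations.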
Tracking progress of Sinkhorn solvers
Let’s start with the Sinkhorn solver without any tracking (the default behavior):
rngs = jax.random.split(jax.random.key(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5
geom = pointcloud.PointCloud(x, y)
This problem is very simple, so the Sinkhorn solver converges after only 80 iterations.
solve_fn = jax.jit(linear.solve)
ot = solve_fn(geom)
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 80, cost: 1.2124879360198975
For small problems such as this one, it’s fine not to track progress (the default behavior). However, when tackling larger problems, we might want to track various values that the Sinkhorn solver updates at each iteration.
Tracking progress of Sinkhorn via the functional interface
Here are a few examples of how to track progress for Sinkhorn and low-rank Sinkhorn.
With the default callback function
ott.utils provides default_progress_fn(), which returns a callback function that prints the current iteration, the maximum number of iterations, and the current error. Let’s pass this basic callback as a static argument to solve():
solve_fn = jax.jit(linear.solve, static_argnames=["progress_fn"])
progress_fn = utils.default_progress_fn()
ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
10 / 2000 -- 0.058479130268096924
20 / 2000 -- 0.023402303457260132
30 / 2000 -- 0.012083128094673157
40 / 2000 -- 0.006561979651451111
50 / 2000 -- 0.003637164831161499
60 / 2000 -- 0.0020373016595840454
70 / 2000 -- 0.0011475533246994019
80 / 2000 -- 0.0006486847996711731
Converged: True, #iters: 80, cost: 1.2124879360198975
This shows that the solver reports its error every 10 iterations. This is the default number of inner iterations after which the error is computed.
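A solver takes a single progress_fn. One way to both print and record progress is therefore to fan the status tuple out to several callbacks. The helper combine_progress_fns below is a hypothetical sketch and is not part of ott.utils. It is exercised here with toy callbacks and a simulated FakeState rather than a real solver.

```python
from typing import Callable, List, NamedTuple, Tuple

import numpy as np

Status = Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple]


def combine_progress_fns(*fns: Callable[[Status], None]) -> Callable[[Status], None]:
    """Hypothetical helper: forward every status tuple to several callbacks."""

    def progress_fn(status: Status) -> None:
        for fn in fns:
            fn(status)

    return progress_fn


class FakeState(NamedTuple):
    errors: np.ndarray  # hypothetical stand-in for a solver state


# Two toy callbacks: one counts the calls, one records the iteration indices.
calls: List[int] = []
iterations: List[int] = []
combined = combine_progress_fns(
    lambda status: calls.append(1),
    lambda status: iterations.append(int(status[0])),
)

state = FakeState(errors=np.array([0.1]))
for i in range(3):
    combined((np.array(i), np.array(10), np.array(2000), state))
print(len(calls), iterations)  # 3 [0, 1, 2]
```

The combined callback is an ordinary Python function, so it is hashable. It could therefore presumably be passed as the static progress_fn argument in the same way as utils.default_progress_fn() above, for example together with a recording callback of your own.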
With tqdm
ott.utils also implements tqdm_progress_fn(), which returns a callback that updates a tqdm progress bar instead of printing to the console.
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solve_fn = jax.jit(linear.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)
0it [00:00, ?it/s]
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 80, cost: 1.2124879360198975
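This sketch illustrates what a progress-bar callback has to do. It advances a tqdm-like bar by one tick per block of inner iterations and shows the latest error as a postfix. It is not ott's tqdm_progress_fn. make_bar_progress_fn, FakeBar, and FakeState are hypothetical, and indexing state.errors with one entry per block of inner iterations is an assumption about the state layout. FakeBar only mimics the update() and set_postfix_str() methods and the total attribute of tqdm.tqdm, so a real tqdm.tqdm() bar should also be accepted.

```python
from typing import NamedTuple

import numpy as np


class FakeState(NamedTuple):
    errors: np.ndarray  # hypothetical: one error per block of inner iterations


class FakeBar:
    """Stand-in exposing the tqdm.tqdm members the sketch uses."""

    def __init__(self):
        self.n, self.total, self.postfix = 0, None, ""

    def update(self, n=1):
        self.n += n

    def set_postfix_str(self, s):
        self.postfix = s


def make_bar_progress_fn(pbar):
    """Sketch: advance a tqdm-like bar by one tick per block of inner iterations."""

    def progress_fn(status) -> None:
        iteration, inner_iterations, total_iter, state = status
        it, inner = int(iteration) + 1, int(inner_iterations)
        if it % inner == 0:
            error = float(np.asarray(state.errors).ravel()[it // inner - 1])
            pbar.total = int(total_iter) // inner  # maximum number of ticks
            pbar.set_postfix_str(f"error: {error:0.2e}")
            pbar.update()

    return progress_fn


bar = FakeBar()
progress_fn = make_bar_progress_fn(bar)
state = FakeState(errors=np.array([0.5, 0.25]))
for i in range(20):
    progress_fn((np.array(i), np.array(10), np.array(2000), state))
print(bar.n, bar.total, bar.postfix)  # 2 200 error: 2.50e-01
```

Counting ticks per block rather than per iteration keeps the bar in step with how often an error value is actually available.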
Tracking progress of Sinkhorn via the class interface
Alternatively, we can provide the callback function to the Sinkhorn class and display the progress with tqdm. Let’s define a LinearProblem and run the solver:
prob = linear_problem.LinearProblem(geom)
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
0it [00:00, ?it/s]
print(
    f"Converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)
Converged: True, #iters: 80, cost: 1.2124879360198975
Tracking progress of low-rank Sinkhorn via the class interface
We can also track the progress of the low-rank Sinkhorn solver. Because it currently doesn’t have a functional interface, we can only use the LRSinkhorn class interface:
prob = linear_problem.LinearProblem(geom)
rank = 2
with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)
0it [00:00, ?it/s]
print(f"Converged: {ot.converged}, cost: {ot.reg_ot_cost}")
Converged: True, cost: 1.765258550643921
Tracking progress of the Gromov-Wasserstein solver
We can track the progress of the GromovWasserstein solver in the same way as with the Sinkhorn solvers. Let’s define a small QuadraticProblem, the same one as in the Gromov-Wasserstein notebook:
# Sample n noisy 2-D points along a spiral
def sample_spiral(
    n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0
):
    radius = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    # Independent radial noise for the x and y coordinates
    eps = jax.random.normal(key, (2, n)) * noise
    x = (radius + eps[0]) * jnp.cos(angles)
    y = (radius + eps[1]) * jnp.sin(angles)
    return jnp.stack([x, y], axis=1)  # shape (n, 2)
# Sample a 3-D Swiss roll: a spiral plus a uniform third axis
def sample_swiss_roll(
    n, min_radius, max_radius, length, key, min_angle=0, max_angle=10, noise=0.1
):
    spiral = sample_spiral(
        n, min_radius, max_radius, key[0], min_angle, max_angle, noise
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))
    return swiss_roll
# Data parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
min_radius = 3
max_radius = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3
# Seed
seed = 14
key, *subkeys = jax.random.split(jax.random.key(seed), 4)
spiral = sample_spiral(
    n_spiral,
    min_radius,
    max_radius,
    key=subkeys[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    min_radius,
    max_radius,
    key=subkeys[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)
We can now track the progress while the GromovWasserstein solver iterates:
geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)
linear_solver = sinkhorn.Sinkhorn()
solver = gromov_wasserstein.GromovWasserstein(
    linear_solver,
    epsilon=100.0,
    max_iterations=20,
    store_inner_errors=True,  # needed for reporting errors
    progress_fn=utils.default_progress_fn(),  # callback function
)
out = solver(prob)
n_outer_iterations = jnp.sum(out.costs != -1)
has_converged = bool(out.linear_convergence[n_outer_iterations - 1])
print(f"\n{n_outer_iterations} outer iterations were needed")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")
1 / 20 -- -1.0
2 / 20 -- 0.13043582439422607
3 / 20 -- 0.08981513231992722
4 / 20 -- 0.06759561598300934
5 / 20 -- 0.05465685948729515
5 outer iterations were needed
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.611
That’s it: this is how to track the progress of the Sinkhorn, low-rank Sinkhorn, and Gromov-Wasserstein solvers!