ott.utils.tqdm_progress_fn


ott.utils.tqdm_progress_fn(pbar, fmt='error: {error:0.6e}')

Return a callback that updates a progress bar when solving linear problems.

It updates the progress bar only when the error is computed, that is, once every inner_iterations iterations.
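To get a rough sense of how often the bar moves, the snippet lists the iterations at which it could advance. The values of max_iterations and inner_iterations here are made up for illustration. The snippet also assumes that the error is computed after every complete block of inner_iterations iterations.

```python
# Hypothetical solver settings, chosen only for illustration.
max_iterations = 2000
inner_iterations = 10

# Assumption: the error is computed once after every block of
# `inner_iterations` iterations. Those are the only points at which
# the progress bar can advance.
update_iterations = list(range(inner_iterations, max_iterations + 1, inner_iterations))

print(len(update_iterations))  # 200
print(update_iterations[:3])   # [10, 20, 30]
```

Under these assumptions, raising inner_iterations makes the bar update less often.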

Parameters:
  • pbar (Any) – A tqdm progress bar.

  • fmt (str) – Format string used for the progress bar's postfix. It may reference the iter, max_iter and error fields.
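A custom fmt is an ordinary Python format string. You can preview it with made-up values before passing it as utils.tqdm_progress_fn(pbar, fmt=fmt). The numbers used here are purely illustrative. str.format ignores unused keyword arguments, so a format string does not need to reference all three fields. The default format, for example, only uses error.

```python
# A custom postfix format that uses all three documented fields.
fmt = "iter {iter}/{max_iter}, error: {error:0.6e}"

# Preview the postfix with illustrative values.
postfix = fmt.format(iter=50, max_iter=2000, error=0.0015)
print(postfix)  # iter 50/2000, error: 1.500000e-03

# The default format from the signature ignores `iter` and `max_iter`.
default_postfix = "error: {error:0.6e}".format(iter=50, max_iter=2000, error=0.0015)
print(default_postfix)  # error: 1.500000e-03
```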

Return type:

Callable[[Tuple[ndarray, ndarray, ndarray, NamedTuple]], None]

Returns:

A callback function that takes a single tuple with the following elements:

  • the current iteration number,

  • the number of inner iterations after which the error is computed,

  • the total number of iterations, and

  • the current SinkhornState or LRSinkhornState.
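The shape of this callback can be illustrated with a hypothetical pure-Python sketch. It is not ott's source code. make_progress_fn, StubState and StubBar are illustrative names, and the way the error is looked up in state.errors is an assumption. The sketch unpacks the 4-tuple described above, ignores calls that do not fall on an error computation, and otherwise updates the bar's total, postfix and count using tqdm-style methods (set_postfix_str, update).

```python
from typing import Any, NamedTuple, Optional, Sequence, Tuple


class StubState(NamedTuple):
  """Hypothetical stand-in for a solver state; only `errors` is used."""
  errors: Sequence[float]


class StubBar:
  """Minimal stand-in for a tqdm bar that records what the callback does."""

  def __init__(self) -> None:
    self.n = 0
    self.total: Optional[int] = None
    self.postfix = ""

  def set_postfix_str(self, s: str) -> None:
    self.postfix = s

  def update(self, n: int = 1) -> None:
    self.n += n


def make_progress_fn(pbar: Any, fmt: str = "error: {error:0.6e}"):
  """Hypothetical re-implementation for illustration only."""

  def progress_fn(status: Tuple[int, int, int, Any]) -> None:
    it, inner_it, total_it, state = status
    # Only act when an error has just been computed.
    if it == 0 or it % inner_it != 0:
      return
    # Assumption: errors[k] holds the error after (k + 1) * inner_it iterations.
    error = state.errors[it // inner_it - 1]
    pbar.total = total_it // inner_it
    pbar.set_postfix_str(fmt.format(iter=it, max_iter=total_it, error=error))
    pbar.update()

  return progress_fn


bar = StubBar()
callback = make_progress_fn(bar)
stub_state = StubState(errors=[0.5, 0.1, 0.01])
callback((5, 10, 30, stub_state))   # not a multiple of 10: ignored
callback((20, 10, 30, stub_state))  # second error computation
print(bar.n, bar.total, bar.postfix)  # 1 3 error: 1.000000e-01
```

Passing the whole status as one tuple matches the documented return type, Callable[[Tuple[...]], None].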

Examples

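```python
import tqdm

import jax

from ott import utils
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

# 100 random points in 5 dimensions; with no `y`, the cloud is compared to itself.
x = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
geom = pointcloud.PointCloud(x)

with tqdm.tqdm() as pbar:
  # Callback that reports the current error in the bar's postfix.
  progress_fn = utils.tqdm_progress_fn(pbar)
  # `progress_fn` is a Python callable, so it must be static under `jax.jit`.
  solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
  out = solve_fn(geom, progress_fn=progress_fn)
```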