ott.utils.default_progress_fn


ott.utils.default_progress_fn(fmt='{iter} / {max_iter} -- {error}', stream=None)

Return a callback that prints the progress when solving linear problems.

It prints the progress only when the error is computed, that is, once every inner_iterations iterations.
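To make the "only when the error is computed" rule concrete, here is a minimal sketch of which iterations would produce a progress line. The helper report_points is hypothetical and not part of OTT; it assumes iterations are counted from 1 and that one error value is stored per completed block of inner_iterations iterations.

```python
from typing import List, Tuple


def report_points(max_iter: int, inner_iterations: int) -> List[Tuple[int, int]]:
    """Return (iteration, error_index) pairs at which a progress line would print.

    Hypothetical helper: assumes iterations are counted from 1 and that one
    error value is stored per completed block of `inner_iterations` iterations.
    """
    points = []
    for it in range(1, max_iter + 1):
        if it % inner_iterations == 0:
            # The k-th stored error belongs to the k-th completed block (0-based k).
            points.append((it, it // inner_iterations - 1))
    return points


print(report_points(max_iter=50, inner_iterations=10))
# [(10, 0), (20, 1), (30, 2), (40, 3), (50, 4)]
```

With max_iter=50 and inner_iterations=10, only five lines would be printed rather than fifty.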

Parameters:
  • fmt (str) – Format string used to print the progress. It may reference the iter, max_iter and error fields.

  • stream (Optional[TextIOBase]) – Output I/O stream to which the progress is written.
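The sketch below only illustrates the two parameters with plain Python: how str.format fills the named fields iter, max_iter and error in a custom fmt string, and how an io.StringIO (a TextIOBase) can serve as an in-memory stream. The format string and values are hypothetical, and it is an assumption that the callback writes one formatted line per report to the stream.

```python
import io

# Hypothetical custom format: placeholders use the field names iter, max_iter, error.
fmt = "iter {iter} / {max_iter} | error = {error:.2e}"
line = fmt.format(iter=20, max_iter=100, error=0.0123456)

# In-memory text stream (io.StringIO subclasses io.TextIOBase), handy for capturing output.
buf = io.StringIO()
print(line, file=buf)
print(buf.getvalue(), end="")  # iter 20 / 100 | error = 1.23e-02
```

The same objects could then be passed as utils.default_progress_fn(fmt=fmt, stream=buf) to redirect progress into buf instead of the default output.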

Return type:

Callable[[Tuple[ndarray, ndarray, ndarray, NamedTuple]], None]

Returns:

A callback function that accepts a single tuple containing:

  • the current iteration number,

  • the number of inner iterations after which the error is computed,

  • the maximum number of iterations, and

  • the current SinkhornState or LRSinkhornState.
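A custom callback with the same shape must accept one 4-tuple and unpack it in the order listed above. The following is a minimal sketch under stated assumptions: FakeState and make_recording_callback are hypothetical names, FakeState merely stands in for a SinkhornState, and the int() conversions assume the counters may arrive as NumPy scalars or 0-d arrays.

```python
from typing import Any, List, NamedTuple, Tuple

import numpy as np


class FakeState(NamedTuple):
    """Hypothetical stand-in for SinkhornState, used only to exercise the callback."""
    errors: np.ndarray


def make_recording_callback(log: List[Tuple[int, int, int, str]]):
    """Build a callback that takes the same single 4-tuple as the default one."""

    def callback(status: Tuple[Any, Any, Any, NamedTuple]) -> None:
        # Unpack in the documented order; int() also accepts NumPy scalars/0-d arrays.
        iteration, inner_iterations, total_iter, state = status
        log.append(
            (int(iteration), int(inner_iterations), int(total_iter), type(state).__name__)
        )

    return callback


log: List[Tuple[int, int, int, str]] = []
cb = make_recording_callback(log)
cb((np.int32(10), np.int32(10), np.int32(1000), FakeState(errors=np.zeros(100))))
print(log)  # [(10, 10, 1000, 'FakeState')]
```

Under these assumptions, such a callback could be used in place of utils.default_progress_fn() in the example, to collect progress rather than print it.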

Examples

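import jax

from ott import utils
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

# Random point cloud: 100 points in 5 dimensions.
x = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
geom = pointcloud.PointCloud(x)

# Default callback: prints "{iter} / {max_iter} -- {error}" whenever the error is computed.
progress_fn = utils.default_progress_fn()
# A Python function cannot be traced, so progress_fn must be marked static when jitting.
solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
out = solve_fn(geom, progress_fn=progress_fn)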