ott.utils.tqdm_progress_fn
- ott.utils.tqdm_progress_fn(pbar, fmt='error: {error:0.6e}')
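Return a callback that updates a progress bar when solving linear problems.

It updates the progress bar only when the error is computed, that is, every
inner_iterations iterations (the number of inner iterations after which the
solver computes the error).

- Parameters:
  - pbar: the tqdm progress bar to update.
  - fmt: format string used to report the current error on the progress bar. The default, 'error: {error:0.6e}', shows the error in scientific notation with six digits after the decimal point.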
- Return type:
  Callable[[Tuple[ndarray, ndarray, ndarray, NamedTuple]], None]
- Returns:
  A callback function that takes a single tuple containing:
  - the current iteration number,
  - the number of inner iterations after which the error is computed,
  - the total number of iterations, and
  - the current SinkhornState or LRSinkhornState.
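To make the callback contract above concrete, here is a minimal sketch of a hand-written progress callback that accepts the same single tuple (iteration, inner_iterations, total_iterations, state). It is not OTT's implementation: it records text lines in a list instead of updating a tqdm bar, so it runs without OTT or tqdm installed. The names make_text_progress_fn and FakeState are hypothetical, and so is the bookkeeping rule it relies on: iterations are 0-indexed, and the error for each block of inner_iterations iterations is stored at state.errors[iteration // inner_iterations] and becomes available on the last iteration of that block.

```python
from typing import Callable, List, NamedTuple, Tuple


class FakeState(NamedTuple):
    """Hypothetical stand-in for a solver state: one error per block."""
    errors: List[float]


Status = Tuple[int, int, int, FakeState]


def make_text_progress_fn(
    lines: List[str], fmt: str = "error: {error:0.6e}"
) -> Callable[[Status], None]:
    """Return a callback that records one line per error evaluation.

    Assumption of this sketch (not OTT's documented rule): iterations are
    0-indexed and the error of the block containing ``iteration`` lives at
    ``errors[iteration // inner_iterations]``, ready on the block's last step.
    """

    def callback(status: Status) -> None:
        # The callback receives a single tuple, mirroring the documented contract.
        iteration, inner_iterations, total_iterations, state = status
        # int() also accepts 0-d arrays, which is what a solver would pass.
        iteration, inner = int(iteration), int(inner_iterations)
        total = int(total_iterations)
        if (iteration + 1) % inner != 0:
            return  # no new error for this iteration, so skip the update
        block = iteration // inner
        error = state.errors[block]
        lines.append(f"{block + 1}/{total // inner} " + fmt.format(error=error))

    return callback


# Usage: simulate 30 iterations with an error evaluation every 10 of them.
lines: List[str] = []
progress_fn = make_text_progress_fn(lines)
state = FakeState(errors=[1e-1, 1e-2, 1e-3])
for it in range(30):
    progress_fn((it, 10, 30, state))
print("\n".join(lines))
```

Gating on the block boundary is what keeps the number of updates equal to total_iterations // inner_iterations rather than total_iterations, which is the same reason the documented callback only advances the bar when an error has been computed.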
Examples
    import jax
    import tqdm

    from ott import utils
    from ott.geometry import pointcloud
    from ott.solvers.linear import sinkhorn

    # Random point cloud; with no y given, the geometry compares x to itself.
    x = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
    geom = pointcloud.PointCloud(x)

    with tqdm.tqdm() as pbar:
        progress_fn = utils.tqdm_progress_fn(pbar)
        # progress_fn is a Python callable, not an array, so jit must treat it as static.
        solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
        out = solve_fn(geom, progress_fn=progress_fn)
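The fmt parameter's default contains an {error:0.6e} placeholder, which suggests the callback fills it in with str.format(error=...). Assuming that, the snippet below shows what the default template and a hypothetical custom template (custom_fmt is illustrative, not part of OTT) produce for a sample error value, using plain Python string formatting only.

```python
# Default template from the signature of ott.utils.tqdm_progress_fn.
default_fmt = "error: {error:0.6e}"
# Hypothetical alternative that could be passed as fmt=custom_fmt.
custom_fmt = "err={error:.2e}"

sample_error = 0.000123456
print(default_fmt.format(error=sample_error))  # error: 1.234560e-04
print(custom_fmt.format(error=sample_error))   # err=1.23e-04
```

With OTT installed, the custom template would be supplied as utils.tqdm_progress_fn(pbar, fmt=custom_fmt).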