ott.utils.default_progress_fn
- ott.utils.default_progress_fn(fmt='{iter} / {max_iter} -- {error}', stream=None)
Return a callback that prints the progress when solving linear problems.
It prints the progress only when the error is computed, that is, every inner_iterations iterations.
- Parameters:
  - fmt (str) – Format string used to print each progress line. It can reference the iter, max_iter and error values.
  - stream (Optional[TextIOBase]) – Output IO stream.
- Return type:
  Callable[[Tuple[ndarray, ndarray, ndarray, NamedTuple]], None]
- Returns:
  A callback function that accepts a single tuple containing:
  - the current iteration number,
  - the number of inner iterations after which the error is computed,
  - the total number of iterations, and
  - the current SinkhornState or LRSinkhornState.
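Because the callback receives one tuple, a hand-written callback with the same shape can record errors instead of printing them. The sketch below is hypothetical: make_error_recorder and FakeState are not part of OTT, FakeState merely stands in for SinkhornState by carrying an errors array, and the indexing assumes (not verified against the OTT source) that the iteration counter is 0-based and that the error for the k-th report sits at index k - 1 of that array.

```python
from typing import Any, Callable, List, NamedTuple, Tuple

import numpy as np


class FakeState(NamedTuple):
  """Hypothetical stand-in for a solver state; only an errors array is assumed."""
  errors: np.ndarray


def make_error_recorder(
    log: List[Tuple[int, float]],
) -> Callable[[Tuple[Any, Any, Any, Any]], None]:
  """Hypothetical factory for a callback with the documented tuple signature."""

  def callback(status: Tuple[Any, Any, Any, Any]) -> None:
    iteration, inner_iterations, _max_iter, state = status
    count = int(iteration) + 1  # assumed 0-based counter, reported 1-based
    inner = int(inner_iterations)
    # Errors only exist every `inner` iterations, so skip all other calls.
    if count % inner == 0:
      errors = np.asarray(state.errors).ravel()
      log.append((count, float(errors[count // inner - 1])))

  return callback


log: List[Tuple[int, float]] = []
callback = make_error_recorder(log)
state = FakeState(errors=np.array([0.5, 0.1, 0.02]))
for it in range(30):  # simulate 30 solver iterations with inner_iterations=10
  callback((np.array(it), np.array(10), np.array(30), state))
print(log)
```

Under these assumptions, passing such a function as progress_fn would collect (iteration, error) pairs on the host rather than printing them.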
Examples
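```python
import jax
import jax.numpy as jnp
from ott import utils
from ott.geometry import pointcloud
from ott.solvers import linear

# Random point cloud of 100 points in 5 dimensions.
x = jax.random.normal(jax.random.key(0), (100, 5))
geom = pointcloud.PointCloud(x)

# The callback is a Python function, so it must be a static argument under jit.
progress_fn = utils.default_progress_fn()
solve_fn = jax.jit(linear.solve, static_argnames=["progress_fn"])
# Progress lines are printed each time the error is computed.
out = solve_fn(geom, progress_fn=progress_fn)
```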
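The example uses the default fmt and stream. The sketch below mimics only the formatting step in plain Python, assuming (not confirmed here) that the callback fills fmt via str.format with the keywords iter, max_iter and error and writes one line per report to stream. It does not call OTT; the template and the io.StringIO buffer are illustrative choices.

```python
import io

# Hypothetical custom template: only the documented fields iter, max_iter
# and error, combined with standard str.format specs.
fmt = "step {iter:>4}/{max_iter} | err={error:.2e}"

# An in-memory text stream (a TextIOBase subclass) standing in for stream.
stream = io.StringIO()
for it, err in [(20, 3.14159e-4), (40, 1.5e-6)]:
  print(fmt.format(iter=it, max_iter=100, error=err), file=stream)

print(stream.getvalue(), end="")
```

With OTT installed, the same template and buffer could be passed as utils.default_progress_fn(fmt=fmt, stream=stream); a spec such as .2e assumes error arrives as a float or NumPy scalar.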