ott.utils.default_progress_fn
- ott.utils.default_progress_fn(status, *args)
Callback function that reports the progress of Sinkhorn by printing to the console. It prints progress only when the error is computed, that is, every inner_iterations iterations.
Note
This function is called during the solver iterations via id_tap(), so the solver execution remains jittable.
- Parameters:
status (Tuple[ndarray, ndarray, ndarray, NamedTuple]) – status consisting of:
  - the current iteration number
  - the number of inner iterations after which the error is computed
  - the total number of iterations
  - the current SinkhornState
args (Any) – unused, see jax.experimental.host_callback.
- Return type:
None
- Returns:
Nothing, just prints.
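The status handling described above can be sketched in plain Python. This is a hypothetical illustration, not the library's source: make_progress_fn and ToyState are made-up names, ToyState stands in for SinkhornState with only an errors field, and the report format is invented. The logic mirrors the tqdm example below: shift the iteration counter by one and report only when it is a multiple of inner_iterations.

```python
from typing import Callable, NamedTuple

import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for SinkhornState; only `errors` is used here."""
    errors: np.ndarray


def make_progress_fn(report: Callable[[int, int, float], None]) -> Callable:
    """Hypothetical factory: builds a callback with the (status, *args)
    signature that calls `report(iteration, total_iter, error)` only when
    an error value is available for the current iteration."""
    def progress_fn(status, *args):
        iteration, inner_iterations, total_iter, state = status
        iteration = int(iteration) + 1  # shift by one, as in the tqdm example
        inner_iterations = int(inner_iterations)
        total_iter = int(total_iter)
        errors = np.asarray(state.errors).ravel()
        # Errors are only computed every `inner_iterations` iterations.
        if iteration % inner_iterations == 0:
            error_idx = max(0, iteration // inner_iterations - 1)
            report(iteration, total_iter, float(errors[error_idx]))
    return progress_fn


# Demo with fake statuses (iteration, inner_iterations, total_iter, state).
lines = []
fn = make_progress_fn(lambda it, tot, err: lines.append(f"{it} / {tot} -- {err:.3e}"))
state = ToyState(errors=np.array([0.5, 0.25, 0.125, 0.0625]))
for i in range(6):
    fn((i, 3, 12, state))
print(lines)  # ['3 / 12 -- 5.000e-01', '6 / 12 -- 2.500e-01']
```

Because the returned callback has the (status, *args) signature documented above, it should be usable as progress_fn, e.g. sinkhorn.Sinkhorn(progress_fn=make_progress_fn(...)); the demo itself only feeds it fake statuses.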
Examples
If, instead of printing, you want to report progress with a progress bar such as tqdm, provide a slightly modified version of this callback, for instance:
import jax
import numpy as np
from tqdm import tqdm

from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn


def progress_fn(status, *args):
    iteration, inner_iterations, total_iter, state = status
    iteration = int(iteration) + 1
    inner_iterations = int(inner_iterations)
    total_iter = int(total_iter)
    errors = np.asarray(state.errors).ravel()

    # Avoid reporting error on each iteration,
    # because errors are only computed every `inner_iterations`.
    if iteration % inner_iterations == 0:
        error_idx = max(0, iteration // inner_iterations - 1)
        error = errors[error_idx]

        # `pbar` is the progress bar created by the `with` block below;
        # the name is looked up when the callback runs.
        pbar.set_postfix_str(f"error: {error:0.6e}")
        pbar.total = total_iter // inner_iterations
        pbar.update()


prob = linear_problem.LinearProblem(...)
solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)

with tqdm() as pbar:
    out_sink = jax.jit(solver)(prob)
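The example runs the solver under jax.jit, and the Note above explains that the callback is invoked through id_tap() so compilation still works. The toy loop below sketches that idea under stated assumptions: toy_solver and report are hypothetical names, and it uses jax.debug.callback, a separate JAX host-callback API (JAX has deprecated jax.experimental.host_callback), in place of id_tap(). The Python callback receives concrete values on the host while the loop itself stays jit-compiled.

```python
import jax
import jax.numpy as jnp

# Host-side record of what the callback saw (illustration only).
reported = []


def report(iteration, error):
    # Runs on the host with concrete values, so int()/float() are safe here.
    reported.append((int(iteration), float(error)))


@jax.jit
def toy_solver(x):
    """Hypothetical toy iteration: halves x each step, reports |x| as the 'error'."""
    def body(i, x):
        x = 0.5 * x
        # Emits a host callback without breaking tracing or jit compilation.
        jax.debug.callback(report, i, jnp.abs(x))
        return x
    return jax.lax.fori_loop(0, 4, body, x)


out = toy_solver(jnp.float32(8.0))
jax.effects_barrier()  # wait until all pending host callbacks have run
print(float(out), sorted(reported))
```

Host callbacks may run asynchronously, so the sketch calls jax.effects_barrier() before reading the recorded values; the same caution applies to anything a progress_fn accumulates.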