ott.utils.default_progress_fn

ott.utils.default_progress_fn(status, *args)

Callback function that reports progress of Sinkhorn by printing to the console.

It prints progress only when the error is computed, that is, once every inner_iterations iterations.
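As a quick check of this schedule, the sketch below lists the iterations at which a callback following this rule would report, together with the index into the error history it would read. It is plain Python and assumes, as the `+ 1` in the tqdm example suggests, that the callback receives a 0-based iteration counter on every iteration; `report_schedule` is a hypothetical helper, not part of ott.

```python
def report_schedule(total_iter, inner_iterations):
    """List (iteration, error_idx) pairs at which progress would be reported.

    Assumes the callback is invoked once per iteration with a 0-based
    counter (hypothetical helper, not part of ott).
    """
    schedule = []
    for raw_iteration in range(total_iter):
        iteration = raw_iteration + 1  # convert to a 1-based count
        if iteration % inner_iterations == 0:
            # Following the tqdm example's indexing, errors[k] holds the
            # error computed at the (k + 1)-th check.
            schedule.append((iteration, iteration // inner_iterations - 1))
    return schedule


print(report_schedule(total_iter=50, inner_iterations=10))
# → [(10, 0), (20, 1), (30, 2), (40, 3), (50, 4)]
```

The number of reports equals total_iter // inner_iterations, which is why the tqdm example sets pbar.total to that value.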

Note

This function is called during solver iterations via id_tap(), so that the solver execution remains jittable.
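The idea behind this note can be illustrated without ott. The sketch below uses jax.debug.callback, a different host-callback API from the id_tap() mentioned above and not necessarily what ott uses, to send values out of a jitted jax.lax.fori_loop on every iteration, while the Python callback decides on the host whether to record them. `halve_repeatedly`, `progress`, `NUM_ITERS` and `INNER_ITERATIONS` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

NUM_ITERS = 8         # hypothetical total number of iterations
INNER_ITERATIONS = 4  # hypothetical reporting period

reported = []  # host-side record of (iteration, value) pairs


def progress(iteration, value):
    # Runs on the host with concrete NumPy values, so Python control flow is fine.
    iteration = int(iteration) + 1  # 0-based loop index -> 1-based count
    if iteration % INNER_ITERATIONS == 0:
        reported.append((iteration, float(value)))


@jax.jit
def halve_repeatedly(x):
    def body(i, x):
        x = x / 2.0
        # Called on every iteration; the host-side callback filters.
        jax.debug.callback(progress, i, x)
        return x

    return jax.lax.fori_loop(0, NUM_ITERS, body, x)


out = halve_repeatedly(jnp.float32(256.0))
out.block_until_ready()
jax.effects_barrier()  # wait until all pending host callbacks have run
print(sorted(reported))
```

Because the filtering happens on the host, the traced loop stays free of Python-level conditionals, which is what keeps it compatible with jax.jit.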

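Parameters:

status – tuple (iteration, inner_iterations, total_iter, state) holding the current iteration number, the number of iterations between two error computations, the total number of iterations, and the current state of the solver.
args – unused; extra arguments passed by id_tap().

Returns:

Nothing, just prints.

Return type:

None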
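To make the layout of status concrete, the following sketch calls a console callback with the same signature by hand, once per iteration, as a solver would. `FakeState` is a hypothetical stand-in for the solver state that only provides an `errors` field, and the printed format is our own choice, not necessarily what default_progress_fn prints.

```python
import contextlib
import io
from typing import Any, NamedTuple, Sequence, Tuple

import numpy as np


class FakeState(NamedTuple):
    """Hypothetical stand-in for the solver state; only `errors` is used."""
    errors: Sequence[float]


def print_progress(status: Tuple[Any, Any, Any, FakeState], *args: Any) -> None:
    """Console callback with the documented signature (output format is ours)."""
    iteration, inner_iterations, total_iter, state = status
    iteration = int(iteration) + 1  # assume a 0-based counter from the solver
    inner_iterations = int(inner_iterations)
    total_iter = int(total_iter)
    errors = np.asarray(state.errors).ravel()
    if iteration % inner_iterations == 0:
        error = errors[iteration // inner_iterations - 1]
        print(f"{iteration} / {total_iter} -- {error:0.3e}")


# Drive the callback by hand and capture what it prints.
state = FakeState(errors=[0.5, 0.05, 0.005])
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    for it in range(30):
        print_progress((it, 10, 30, state))
lines = buffer.getvalue().splitlines()
print(lines)
# → ['10 / 30 -- 5.000e-01', '20 / 30 -- 5.000e-02', '30 / 30 -- 5.000e-03']
```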

Examples

To report progress with a progress bar such as tqdm instead of printing, provide a slightly modified version of this callback, for instance:

import jax
import numpy as np
from tqdm import tqdm

from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def progress_fn(status, *args):
  # `pbar` is the tqdm progress bar created in the `with` block below.
  iteration, inner_iterations, total_iter, state = status
  iteration = int(iteration) + 1
  inner_iterations = int(inner_iterations)
  total_iter = int(total_iter)
  errors = np.asarray(state.errors).ravel()

  # Avoid reporting error on each iteration,
  # because errors are only computed every `inner_iterations`.
  if iteration % inner_iterations == 0:
    error_idx = max(0, iteration // inner_iterations - 1)
    error = errors[error_idx]

    pbar.set_postfix_str(f"error: {error:0.6e}")
    pbar.total = total_iter // inner_iterations
    pbar.update()

prob = linear_problem.LinearProblem(...)
solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)

with tqdm() as pbar:
  out_sink = jax.jit(solver)(prob)