ott.core.sinkhorn.Sinkhorn

class ott.core.sinkhorn.Sinkhorn(lse_mode=True, threshold=0.001, norm_error=1, inner_iterations=10, min_iterations=0, max_iterations=2000, momentum=None, anderson=None, parallel_dual_updates=False, use_danskin=None, implicit_diff=ImplicitDiff(solver_fun=<function cg>, ridge_kernel=0.0, ridge_identity=0.0, symmetric=False, precondition_fun=None), initializer=<ott.core.initializers.DefaultInitializer object>, jit=True)

A Sinkhorn solver for linear, entropy-regularized OT (reg-OT) problems.

The solver takes a linear OT problem object as input and returns a SinkhornOutput object containing all the information required to compute transports. See sinkhorn() for a functional wrapper.
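To make the description above concrete, here is a minimal NumPy sketch of what a Sinkhorn solver computes for a balanced problem with kernel (multiplicative) updates, i.e. the lse_mode=False route. It is not the OTT implementation: the helper sinkhorn_kernel, the fixed iteration count, the cost rescaling and the use of NumPy instead of JAX are assumptions made for illustration. The loop alternately rescales the Gibbs kernel K = exp(-C / epsilon) so that the transport matrix P = diag(u) K diag(v) matches the marginals a and b.

```python
import numpy as np


def sinkhorn_kernel(C, a, b, epsilon=0.2, num_iters=1000):
    """Hypothetical helper (not part of ott): balanced Sinkhorn, kernel mode.

    Returns the transport matrix P = diag(u) K diag(v).
    """
    K = np.exp(-C / epsilon)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iters):
        u = a / (K @ v)    # rescale rows so that P.sum(axis=1) == a
        v = b / (K.T @ u)  # rescale columns so that P.sum(axis=0) == b
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
y = rng.normal(size=(4, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
C = C / C.max()                                     # rescale costs to [0, 1]
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)

P = sinkhorn_kernel(C, a, b)
print(P.sum(axis=1), P.sum(axis=0))  # row and column marginals of P
```

Rescaling the costs keeps exp(-C / epsilon) well away from underflow in this toy example; for small epsilon relative to the costs, the log-domain updates selected by lse_mode=True are the safer choice.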

Parameters
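  • lse_mode (bool) – if True, updates are carried out in the log domain on dual potentials, using log-sum-exp computations; if False, multiplicative updates are applied to scalings using the kernel.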

  • threshold (float) – tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the current primal solution when either or both tau_a and tau_b are 1.0 (balanced or semi-balanced problem), or the relative change between two successive solutions in the unbalanced case.

  • norm_error (int) – power p used to define the p-norm of the error between a marginal and its target.

  • inner_iterations (int) – the Sinkhorn error is not recomputed at each iteration, but only every inner_iterations iterations.

  • min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.

  • max_iterations (int) – the maximum number of Sinkhorn iterations. If max_iterations is equal to min_iterations, Sinkhorn iterations are run by default using a jax.lax.scan() loop rather than a custom, unrollable jax.lax.while_loop() that monitors convergence. In that case the error is not monitored, and the converged flag is consequently False.

  • momentum (Optional[Momentum]) – a Momentum instance. See ott.core.momentum.

  • anderson (Optional[AndersonAcceleration]) – an AndersonAcceleration instance. See ott.core.anderson.

  • implicit_diff (Optional[ImplicitDiff]) – instance used to carry out implicit differentiation. If None, gradients are computed by unrolling the iterations.

  • parallel_dual_updates (bool) – updates potentials or scalings in parallel if True, sequentially (in Gauss-Seidel fashion) if False.

  • use_danskin (Optional[bool]) – when True, it is assumed that the entropy-regularized cost is evaluated using optimal potentials that are frozen, i.e. whose gradients have been stopped. This is useful when carrying out first-order differentiation, and is only valid (as with implicit differentiation) when the algorithm has converged with a low tolerance.

  • jit (bool) – if True, automatically jits the function upon first call. Should be set to False when used in a function that is jitted by the user, or when computing gradients (in which case the gradient function should be jitted by the user).

  • initializer (SinkhornInitializer) – how to compute the initial potentials/scalings.
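The stopping parameters above interact roughly as in the following sketch. It is a simplified, hypothetical log-domain loop for a balanced problem only: the function sinkhorn_lse and its signature are invented for illustration and are not part of ott, and it ignores momentum, anderson, unbalanced tau_a/tau_b, and the scan-based path taken when max_iterations equals min_iterations. It updates the potentials sequentially (the parallel_dual_updates=False case), checks the marginal error every inner_iterations steps once min_iterations have been carried out, measures that error with the norm_error-norm, and stops as soon as it falls below threshold.

```python
import numpy as np


def logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along `axis`.
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(z - m).sum(axis=axis))


def sinkhorn_lse(C, a, b, epsilon=0.1, threshold=1e-3, norm_error=1,
                 inner_iterations=10, min_iterations=0, max_iterations=2000):
    """Hypothetical sketch (not the ott implementation) of the stopping logic."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    errors = []
    for it in range(1, max_iterations + 1):
        # Sequential (Gauss-Seidel) updates of the dual potentials.
        f = epsilon * np.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        g = epsilon * np.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        # Error is only evaluated every `inner_iterations` iterations,
        # and only once `min_iterations` iterations have been carried out.
        if it % inner_iterations == 0 and it >= min_iterations:
            P = np.exp((f[:, None] + g[None, :] - C) / epsilon)
            # After the g-update the column sums equal b, so check the rows.
            err = np.linalg.norm(P.sum(axis=1) - a, ord=norm_error)
            errors.append(err)
            if err < threshold:
                return f, g, it, errors, True
    return f, g, max_iterations, errors, False


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))
y = rng.normal(size=(5, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
C = C / C.max()  # rescale costs to [0, 1]
a = np.full(6, 1 / 6)
b = np.full(5, 1 / 5)

f, g, n_iter, errors, converged = sinkhorn_lse(C, a, b, epsilon=0.2)
print(converged, n_iter, errors[-1])
```

Because the error is only checked on multiples of inner_iterations, the returned iteration count is always a multiple of 10 here: a larger inner_iterations saves error computations, at the price of possibly running a few iterations past the point where the threshold was first met.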

Methods

init_state(ot_prob, init)

Return the initial state of the loop.

kernel_step(ot_prob, state, iteration)

Sinkhorn multiplicative update.

lse_step(ot_prob, state, iteration)

Sinkhorn LSE update.

one_iteration(ot_prob, state, iteration, ...)

Carry out one Sinkhorn iteration.

output_from_state(ot_prob, state)

Create an output from a loop state.
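kernel_step and lse_step describe the same update in two parametrizations. The NumPy sketch below (an illustration, not the ott code; all variable names are assumptions) performs one row update both ways, with scalings u = exp(f / epsilon), v = exp(g / epsilon), and checks that the log-domain potential equals epsilon * log of the multiplicative scaling.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, epsilon = 4, 3, 0.5
C = rng.uniform(size=(n, m))
a = np.full(n, 1 / n)
g = rng.normal(size=m)   # current column potential
v = np.exp(g / epsilon)  # matching column scaling, v = exp(g / epsilon)
K = np.exp(-C / epsilon) # Gibbs kernel

# Kernel (multiplicative) update of the row scaling: u = a / (K v).
u = a / (K @ v)

# LSE update of the row potential:
# f_i = eps * log a_i - eps * logsumexp_j((g_j - C_ij) / eps).
z = (g[None, :] - C) / epsilon
zmax = z.max(axis=1, keepdims=True)
lse = zmax[:, 0] + np.log(np.exp(z - zmax).sum(axis=1))
f = epsilon * np.log(a) - epsilon * lse

# Both routes give the same update, since (K v)_i = exp(lse_i).
print(np.allclose(f, epsilon * np.log(u)))
```

The two are equivalent in exact arithmetic; in floating point the kernel route can underflow when epsilon is small relative to the costs, which is the usual reason to prefer lse_mode=True.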

Attributes

norm_error

Powers used to compute the p-norm of the error between marginals and their targets.

outer_iterations

Upper bound on number of times inner_iterations are carried out.
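As a sketch, assuming outer_iterations is the number of error checks needed to cover max_iterations when the error is evaluated every inner_iterations iterations, it amounts to a ceiling division. The helper below is hypothetical and only illustrates that arithmetic; a fixed bound of this kind is convenient for preallocating an error history, since array shapes must be static under jit.

```python
import math


def outer_iterations(max_iterations: int, inner_iterations: int) -> int:
    # Assumed definition: ceil(max_iterations / inner_iterations).
    return math.ceil(max_iterations / inner_iterations)


print(outer_iterations(2000, 10))  # default parameters -> 200
print(outer_iterations(1005, 10))  # partial last block rounds up -> 101
```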