class ott.core.sinkhorn.Sinkhorn(lse_mode=True, threshold=0.001, norm_error=1, inner_iterations=10, min_iterations=0, max_iterations=2000, momentum=None, anderson=None, parallel_dual_updates=False, use_danskin=None, implicit_diff=ImplicitDiff(solver_fun=<function cg>, ridge_kernel=0.0, ridge_identity=0.0, symmetric=False, precondition_fun=None), jit=True)
A Sinkhorn solver for the linear reg-OT problem, implemented as a pytree.

A Sinkhorn solver takes a linear OT problem object as input and returns a SinkhornOutput object that contains all the information required to compute transports. See the function sinkhorn for a wrapper.
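A minimal usage sketch follows. It assumes the module layout of the version this page documents (ott.geometry.pointcloud for the geometry, ott.core.linear_problems for the problem wrapper) and that the solver is called directly on the problem; treat the exact paths and attributes as assumptions rather than a definitive recipe.

```python
import jax
from ott.core import linear_problems, sinkhorn
from ott.geometry import pointcloud

# Sample two small point clouds; each carries uniform weights by default.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (30, 2))
y = jax.random.normal(key_y, (40, 2)) + 1.0

geom = pointcloud.PointCloud(x, y, epsilon=0.1)  # squared-Euclidean cost
prob = linear_problems.LinearProblem(geom)       # balanced problem by default

solver = sinkhorn.Sinkhorn(threshold=1e-3)
out = solver(prob)                               # a SinkhornOutput
print(out.converged, out.reg_ot_cost)
```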
Parameters:
- lse_mode (bool) – True for log-sum-exp computations, False for kernel multiplication.
- threshold (float) – tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the current primal solution when either or both tau_a and tau_b are 1.0 (balanced or semi-balanced problem), or the relative change between two successive solutions in the unbalanced case.
- norm_error (int) – power used to define the p-norm of the error between the marginal and its target.
- inner_iterations (int) – the Sinkhorn error is not recomputed at each iteration but every inner_iterations instead.
- min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.
- max_iterations (int) – the maximum number of Sinkhorn iterations. If max_iterations is equal to min_iterations, Sinkhorn iterations are run by default using a jax.lax.scan loop rather than a custom, unrollable jax.lax.while_loop that monitors convergence. In that case the error is not monitored and the converged flag will return False as a consequence (see the fixed-length sketch after this list).
- momentum (Optional[Momentum]) – a Momentum instance. See ott.core.momentum.
- anderson (Optional[AndersonAcceleration]) – an AndersonAcceleration instance. See ott.core.anderson.
- implicit_diff (Optional[ImplicitDiff]) – instance used to solve implicit differentiation. Unrolls iterations if None.
- parallel_dual_updates (bool) – updates potentials or scalings in parallel if True, sequentially (in Gauss-Seidel fashion) if False.
- use_danskin (Optional[bool]) – when True, it is assumed the entropy-regularized cost is evaluated using optimal potentials that are frozen, i.e. whose gradients have been stopped. This is useful when carrying out first-order differentiation, and is only valid (as with implicit_diff) when the algorithm has converged with a low tolerance.
- jit (bool) – if True, automatically jits the function upon first call. Should be set to False when used inside a function that is jitted by the user, or when computing gradients (in which case the gradient function should be jitted by the user).
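The sketch below illustrates two of the configurations described above: a fixed-length run (min_iterations == max_iterations, which selects the scan path) and an over-relaxed solver. The alias momentum_lib and the Momentum constructor arguments (start, value) are assumptions about this version of ott.core.momentum, not a confirmed API.

```python
from ott.core import momentum as momentum_lib
from ott.core import sinkhorn

# Fixed-length run: with min_iterations == max_iterations the solver
# uses a jax.lax.scan loop, skips the convergence check, and reports
# converged=False by construction. This is convenient under jit/vmap,
# where a while_loop with a data-dependent stopping rule is awkward.
fixed = sinkhorn.Sinkhorn(min_iterations=200, max_iterations=200)

# Over-relaxation: the (start, value) arguments of Momentum are an
# assumption about ott.core.momentum in this version.
fast = sinkhorn.Sinkhorn(
    threshold=1e-4,
    momentum=momentum_lib.Momentum(start=30, value=1.5),
)
```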
Methods:
- init_state(ot_prob, init) – Return the initial state of the loop.
- kernel_step(ot_prob, state, iteration) – Sinkhorn multiplicative update.
- lse_step(ot_prob, state, iteration) – Sinkhorn LSE update.
- one_iteration(ot_prob, state, iteration, ...) – Carries out one Sinkhorn iteration (sketched below).
- output_from_state(ot_prob, state) – Create an output from a loop state.
Attributes:
- outer_iterations – Upper bound on the number of times inner_iterations are carried out.
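The following sketch ties these pieces together. Both functions (one_iteration_sketch, outer_iterations_bound) are hypothetical illustrations of the behavior described above, not the library's source: lse_mode plausibly selects between the two update rules, and the iteration budget is consumed in chunks of inner_iterations.

```python
import math

# Hedged sketch: how lse_mode plausibly dispatches between the two
# update rules inside one_iteration. Not the library's actual code.
def one_iteration_sketch(solver, ot_prob, state, iteration):
    if solver.lse_mode:
        return solver.lse_step(ot_prob, state, iteration)   # stable log-space update
    return solver.kernel_step(ot_prob, state, iteration)    # multiplicative update

# Hypothetical helper mirroring what outer_iterations bounds: the loop
# runs in chunks of inner_iterations, so the chunk count caps how many
# times the error can be evaluated.
def outer_iterations_bound(max_iterations: int, inner_iterations: int) -> int:
    return math.ceil(max_iterations / inner_iterations)

assert outer_iterations_bound(2000, 10) == 200
```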