ott.solvers.linear.sinkhorn.Sinkhorn
- class ott.solvers.linear.sinkhorn.Sinkhorn(lse_mode=True, threshold=0.001, norm_error=1, inner_iterations=10, min_iterations=0, max_iterations=2000, momentum=None, anderson=None, parallel_dual_updates=False, recenter_potentials=False, use_danskin=None, implicit_diff=ImplicitDiff(solver=None, solver_kwargs=None, symmetric=False, precondition_fun=None), initializer=None, progress_fn=None)
Sinkhorn solver.
The Sinkhorn algorithm is a fixed point iteration that solves a regularized optimal transport (reg-OT) problem between two measures.
- Note on Sinkhorn updates:
The boolean flag lse_mode sets whether the algorithm is run in either:
  - log-sum-exp mode (lse_mode=True), in which case it is directly defined in terms of updates to f and g, using log-sum-exp computations. This requires access to the cost matrix \(C\), as it is stored, or possibly computed on the fly by geom.
  - kernel mode (lse_mode=False), in which case it requires access to a matrix-vector multiplication operator \(z \mapsto K z\), where \(K\) is either instantiated from \(C\) as \(\exp(-C/\varepsilon)\), or provided directly. In that case, rather than optimizing over \(f\) and \(g\), it is more convenient to optimize over their so-called scaling formulations, \(u := \exp(f / \varepsilon)\) and \(v := \exp(g / \varepsilon)\). While faster (applying matrices is faster than repeatedly applying log-sum-exp reductions over rows), this mode is also less stable numerically, notably for smaller \(\varepsilon\).
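The two modes can be compared with a minimal NumPy sketch. It assumes the unweighted entropic convention in which the plan is \(P_{ij} = \exp((f_i + g_j - C_{ij})/\varepsilon)\), and the helpers logsumexp, sinkhorn_lse and sinkhorn_kernel are hypothetical names for illustration, not OTT's implementation. Started from \(f = g = 0\) (i.e. \(u = v = 1\)), both variants follow the same iterates, so they should return the same potentials up to floating-point error.

```python
import numpy as np


def logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_lse(C, a, b, eps, n_iter=500):
    # lse_mode=True analogue: update the potentials f and g directly.
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        # g-update: column marginal of exp((f_i + g_j - C_ij) / eps) becomes b.
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        # f-update with the new g: row marginal becomes a.
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    return f, g


def sinkhorn_kernel(C, a, b, eps, n_iter=500):
    # lse_mode=False analogue: update the scalings u = exp(f / eps) and
    # v = exp(g / eps) using matrix-vector products with K = exp(-C / eps).
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    # Convert back to potential form before returning.
    return eps * np.log(u), eps * np.log(v)


rng = np.random.default_rng(0)
x, y = rng.uniform(size=(5, 2)), rng.uniform(size=(6, 2))
C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a, b = np.full(5, 1 / 5), np.full(6, 1 / 6)
eps = 0.5

f1, g1 = sinkhorn_lse(C, a, b, eps)
f2, g2 = sinkhorn_kernel(C, a, b, eps)
print(np.max(np.abs(f1 - f2)), np.max(np.abs(g1 - g2)))
```

With a smaller eps, the kernel variant is the first to break down, since entries of exp(-C / eps) underflow to zero.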
In the source code, the variables f_u and g_v can be regarded either as potentials (real vectors) or as scalings (positive vectors), depending on the choice of lse_mode made by the user. Once optimization is carried out, dual variables are only returned in potential form, i.e. f and g.
In addition to standard Sinkhorn updates, the user can also use heavy-ball type updates using a momentum parameter in \((0, 2)\). We also implement a strategy that tries to set that parameter adaptively after chg_momentum_from iterations, as a function of progress in the error, as proposed in the literature.
Another upgrade to the standard Sinkhorn updates lies in using Anderson acceleration. This is enabled by setting the anderson parameter, None by default, to an AndersonAcceleration instance (see the parameter list). When selected, the algorithm recomputes, every refresh_anderson_frequency iterations (set by default to 1), an extrapolation of the most recently computed iterates. When using that option, note that differentiation (if required) can only be carried out using implicit differentiation, and that all momentum-related parameters are ignored.
The parallel_dual_updates flag is set to False by default. In that setting, g_v is first updated using the latest values of f_u and g_v, before proceeding to update f_u using that new value of g_v. When the flag is set to True, both f_u and g_v are updated simultaneously. Setting that flag to True requires some form of averaging (e.g. momentum=0.5); on its own, parallel_dual_updates will not work.
- Differentiation:
The optimal solutions f and g and the optimal objective (reg_ot_cost) output by the Sinkhorn algorithm can be differentiated w.r.t. the relevant inputs geom, a and b. In the default setting, the algorithm uses implicit differentiation of the optimality conditions (implicit_diff not equal to None). This has two consequences:
  - The termination criterion used to stop Sinkhorn (cancellation of the gradient of the objective w.r.t. f_u and g_v) is used to differentiate the dual Kantorovich potentials f and g, given a change in the inputs. These changes are computed by solving a linear system. The optimality conditions of the entropy-regularized optimal transport problem can be analyzed as satisfying a \(z = z'\) condition, which is then differentiated. It might be beneficial (e.g., as in [Cuturi et al., 2020]) to use a preconditioning function precondition_fun to differentiate \(h(z) = h(z')\) instead.
  - The objective reg_ot_cost returned by Sinkhorn uses the so-called envelope theorem (a.k.a. Danskin's theorem). Because the gradients of the dual objective w.r.t. the dual variables f_u and g_v are assumed to be zero (reflecting the fact that they are optimal), small variations in f_u and g_v due to changes in the inputs (such as geom, a and b) have no first-order effect on the objective and are considered negligible. As a result, stop_gradient is applied to the dual variables f_u and g_v when evaluating the reg_ot_cost objective. Note that this approach is invalid when computing higher-order derivatives; in that case the use_danskin flag must be set to False.
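The envelope-theorem argument can be checked numerically. In the sketch below (NumPy, hypothetical helper names, and an unweighted entropic dual \(\langle f, a\rangle + \langle g, b\rangle - \varepsilon \sum_{ij} \exp((f_i + g_j - C_{ij})/\varepsilon)\) assumed for illustration rather than OTT's exact convention), differentiating the dual objective w.r.t. \(C\) with the potentials frozen gives the plan \(P\). Central finite differences, which re-solve the problem for every perturbed cost, should agree with it.

```python
import numpy as np


def logsumexp(x, axis):
    # Stable log(sum(exp(x))) along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def solve_potentials(C, a, b, eps, n_iter=500):
    # Plain sequential log-sum-exp Sinkhorn updates.
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    return f, g


def dual_objective(f, g, C, a, b, eps):
    # <f, a> + <g, b> - eps * sum_ij exp((f_i + g_j - C_ij) / eps)
    return f @ a + g @ b - eps * np.sum(np.exp((f[:, None] + g[None, :] - C) / eps))


def reg_ot_cost(C, a, b, eps):
    # Re-solve for the potentials, then evaluate the dual objective.
    f, g = solve_potentials(C, a, b, eps)
    return dual_objective(f, g, C, a, b, eps)


rng = np.random.default_rng(0)
x, y = rng.uniform(size=(4, 2)), rng.uniform(size=(5, 2))
C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a, b = np.full(4, 1 / 4), np.full(5, 1 / 5)
eps = 0.5

# Danskin: with f, g frozen, d(objective)/dC_ij = exp((f_i + g_j - C_ij) / eps) = P_ij.
f, g = solve_potentials(C, a, b, eps)
P = np.exp((f[:, None] + g[None, :] - C) / eps)

# Central finite differences of the optimal value, re-solving each time.
h = 1e-6
fd_grad = np.zeros_like(C)
for i in range(C.shape[0]):
    for j in range(C.shape[1]):
        E = np.zeros_like(C)
        E[i, j] = h
        fd_grad[i, j] = (reg_ot_cost(C + E, a, b, eps)
                         - reg_ot_cost(C - E, a, b, eps)) / (2 * h)

print(np.max(np.abs(fd_grad - P)))
```

The agreement relies on the potentials being well converged; with a loose threshold, the frozen-potential gradient is only an approximation, which is why use_danskin is tied to convergence at low tolerance.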
An alternative, yet more costly, way to differentiate the outputs of the Sinkhorn iterations is to use unrolling (used when implicit_diff is None), i.e. reverse-mode differentiation of the Sinkhorn loop. This is possible because Sinkhorn iterations are wrapped in a custom fixed-point iteration loop, defined in fixed_point_loop, rather than a standard while loop, which JAX cannot differentiate in reverse mode. This ensures that the end result of this fixed-point loop can also be differentiated, if needed, using standard JAX operations. To that end, the fixed_point_loop.fixpoint_iter_backprop loop checkpoints the state variables (here f_u and g_v) every inner_iterations, and backpropagates automatically, block by block, through blocks of inner_iterations at a time.
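The block structure described above can be sketched as a plain Python loop. The blockwise_fixed_point_loop function here is a hypothetical NumPy stand-in, not OTT's fixed_point_loop: it has no reverse pass, but it records the state at every block boundary (the information a block-by-block backward pass would restart from) and evaluates the error only once per block of inner_iterations updates, honoring min_iterations, max_iterations and threshold.

```python
import numpy as np


def blockwise_fixed_point_loop(body_fn, error_fn, state, *, inner_iterations=10,
                               min_iterations=0, max_iterations=2000,
                               threshold=1e-3):
    # Hypothetical block-wise fixed-point loop (a sketch, not OTT's code).
    iteration = 0
    checkpoints = [state]  # states kept at block boundaries
    err = np.inf
    while iteration < max_iterations:
        # One block of up to inner_iterations updates, with no error check.
        for _ in range(min(inner_iterations, max_iterations - iteration)):
            state = body_fn(state)
            iteration += 1
        checkpoints.append(state)
        # The error is only evaluated at the end of each block.
        err = error_fn(state)
        if iteration >= min_iterations and err < threshold:
            break
    return state, iteration, err, checkpoints


def logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


rng = np.random.default_rng(0)
C = rng.uniform(size=(4, 5))
a, b = np.full(4, 1 / 4), np.full(5, 1 / 5)
eps = 0.5


def sinkhorn_step(state):
    # One sequential (g then f) log-sum-exp Sinkhorn update.
    f, g = state
    g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
    f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    return f, g


def column_marginal_error(state):
    # After the f-update, rows match a exactly, so check the columns.
    f, g = state
    P = np.exp((f[:, None] + g[None, :] - C) / eps)
    return np.sum(np.abs(P.sum(axis=0) - b))


state, n_iter, err, checkpoints = blockwise_fixed_point_loop(
    sinkhorn_step, column_marginal_error, (np.zeros(4), np.zeros(5)),
    inner_iterations=10, threshold=1e-6)
print(n_iter, err, len(checkpoints))
```

Because the stopping test only runs at block ends, the reported iteration count is a multiple of inner_iterations (when max_iterations is too), and the number of stored states grows with the number of blocks rather than with the number of iterations.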
Note
The Sinkhorn algorithm may not converge within the maximum number of iterations for possibly several reasons:
  - the regularizer (defined as epsilon in the geometry geom object) is too small. Consider either switching to lse_mode=True (at the price of a slower execution), increasing epsilon, or, alternatively, if you are unable or unwilling to increase epsilon, increasing either max_iterations or threshold.
  - the probability weights a and b do not have the same total mass, while using a balanced (tau_a=tau_b=1.0) setup. Consider either normalizing a and b, or setting tau_a and/or tau_b < 1.0.
OOM issues may arise when storing cost or kernel matrices that are too large in geom. When the geom geometry is a PointCloud, some of these issues might be solved by setting the online flag to True, which triggers an on-the-fly re-computation of the cost/kernel matrix.
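The first two causes can be reproduced in a few lines. This is a minimal sketch with plain kernel-mode updates in NumPy (sinkhorn_scalings is a hypothetical helper, not OTT's solver): when the total masses differ, the column error of a balanced solve can never drop below the mass gap, whereas normalizing b restores convergence; and with a very small epsilon, every entry of \(\exp(-C/\varepsilon)\) underflows to zero in float64.

```python
import numpy as np


def sinkhorn_scalings(K, a, b, n_iter=200):
    # Plain kernel-mode (scaling) Sinkhorn updates.
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    # Rows match a exactly after the u-update; report the column error.
    return P, np.sum(np.abs(P.sum(axis=0) - b))


rng = np.random.default_rng(0)
C = rng.uniform(size=(4, 5))
K = np.exp(-C / 0.5)
a = np.full(4, 0.25)  # total mass 1.0
b = np.full(5, 0.3)   # total mass 1.5: no balanced coupling exists

_, err_mismatched = sinkhorn_scalings(K, a, b)
_, err_normalized = sinkhorn_scalings(K, a, b / b.sum())
print(err_mismatched, err_normalized)  # the first is at least 0.5 (the mass gap)

# Too small an epsilon: every kernel entry underflows to 0 in float64,
# so kernel-mode updates would divide by zero.
C_small = np.array([[1.0, 2.0], [2.0, 1.0]])
K_small = np.exp(-C_small / 1e-3)
print(K_small)
```

In the mismatched case the plan's total mass is 1.0 after each u-update while b sums to 1.5, so the L1 column error is bounded below by 0.5 regardless of max_iterations.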
The weight vectors a and b can be passed with coordinates that have zero weight. This is handled by relying on simple arithmetic for the inf values that will likely arise (due to \(\log 0\) when lse_mode is True, or divisions by zero when lse_mode is False). Whenever that arithmetic is likely to produce NaN values (due to -inf * 0, or -inf - -inf) in the forward pass, we use jnp.where conditional statements to carry inf rather than NaN values. In reverse-mode differentiation, the inputs corresponding to these zero weights (a location x, or a row in the corresponding cost/kernel matrix), and the weight itself, will have NaN gradient values. This reflects the fact that these gradients are undefined, since these points were not considered in the optimization and therefore have no impact on the output.
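A minimal NumPy sketch of this handling, assuming the same unweighted log-sum-exp updates as an illustration (safe_logsumexp is a hypothetical helper, and np.where stands in for jnp.where): the zero-weight point gets potential -inf and an all-zero row in the plan, without any NaN, while a naive \(\langle f, a\rangle\) hits -inf * 0. Like jnp.where, np.where evaluates both branches and only selects between them.

```python
import numpy as np


def safe_logsumexp(x, axis):
    # log(sum(exp(x))) that returns -inf, not NaN, when a slice is all -inf.
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # avoid -inf - (-inf) = NaN
    with np.errstate(divide="ignore"):  # log(0) -> -inf is intended
        return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


rng = np.random.default_rng(0)
C = rng.uniform(size=(3, 3))
a = np.array([0.5, 0.0, 0.5])  # second point has zero weight
b = np.array([0.25, 0.25, 0.5])
eps = 0.5

with np.errstate(divide="ignore"):
    log_a, log_b = np.log(a), np.log(b)  # log 0 = -inf

f, g = np.zeros(3), np.zeros(3)
for _ in range(500):
    g = eps * log_b - eps * safe_logsumexp((f[:, None] - C) / eps, axis=0)
    f = eps * log_a - eps * safe_logsumexp((g[None, :] - C) / eps, axis=1)

P = np.exp((f[:, None] + g[None, :] - C) / eps)  # row 1 is exactly 0

with np.errstate(invalid="ignore"):
    naive = np.sum(f * a)  # -inf * 0 = NaN
    safe = np.sum(np.where(a > 0, f * a, 0.0))  # masked term contributes 0
print(f, naive, safe)
```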
- Parameters:
  - lse_mode (bool) – True for log-sum-exp computations, False for kernel multiplication.
  - threshold (float) – tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the current primal solution when either or both tau_a and tau_b are \(1.0\) (balanced or semi-balanced problem), or the relative change between two successive solutions in the unbalanced case.
  - norm_error (int) – power used to define the \(p\)-norm used to quantify the magnitude of the gradients. This criterion is used to terminate the algorithm.
  - inner_iterations (int) – the Sinkhorn error is not recomputed at each iteration but every inner_iterations instead.
  - min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.
  - max_iterations (int) – the maximum number of Sinkhorn iterations. If max_iterations is equal to min_iterations, Sinkhorn iterations are run by default using a jax.lax.scan() loop rather than a custom, unroll-able jax.lax.while_loop() that monitors convergence. In that case the error is only computed at the last iteration.
  - anderson (Optional[AndersonAcceleration]) – AndersonAcceleration instance.
  - implicit_diff (Optional[ImplicitDiff]) – ImplicitDiff instance used to parameterize the linear solvers used in implicit differentiation. The algorithm uses unrolling of iterations if None.
  - parallel_dual_updates (bool) – updates potentials or scalings in parallel if True, sequentially (in Gauss-Seidel fashion) if False.
  - recenter_potentials (bool) – whether to re-center the dual potentials. If the problem is balanced, the f potential is zero-centered for numerical stability. Otherwise, use the approach of [Sejourne et al., 2022] to achieve faster convergence. Only used when lse_mode = True and tau_a < 1 and tau_b < 1.
  - use_danskin (Optional[bool]) – when True, it is assumed the entropy-regularized optimal transport cost is evaluated using dual Kantorovich potentials that are frozen, i.e. whose gradients have been stopped. This is useful when carrying out first-order differentiation, and is only valid when the algorithm has converged with a low tolerance.
  - initializer (Optional[SinkhornInitializer]) – method to compute the initial potentials/scalings. See linear for more information.
  - progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, SinkhornState]], None]]) – callback function called during the Sinkhorn iterations, so that the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.
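Two of these parameters can be illustrated in isolation. The sketch below (NumPy; marginal_error and recenter_balanced are hypothetical helpers, and OTT's exact error definition may differ) measures the marginal deviation with the p-norm selected by norm_error and compares it with threshold. It then zero-centers f, as recenter_potentials does for a balanced problem, shifting g by the opposite amount so that \(f_i + g_j\), and hence the transport plan, is unchanged.

```python
import numpy as np


def marginal_error(P, a, b, norm_error=1):
    # p-norm deviation of the plan's marginals from the targets a and b.
    err_a = np.linalg.norm(P.sum(axis=1) - a, ord=norm_error)
    err_b = np.linalg.norm(P.sum(axis=0) - b, ord=norm_error)
    return err_a + err_b


def recenter_balanced(f, g):
    # Zero-center f; shifting g by +mean keeps f_i + g_j (hence the plan) fixed.
    mean = np.mean(f)
    return f - mean, g + mean


def plan(f, g, C, eps):
    # Entropic plan P_ij = exp((f_i + g_j - C_ij) / eps).
    return np.exp((f[:, None] + g[None, :] - C) / eps)


P = np.array([[0.3, 0.2], [0.1, 0.4]])
a = np.array([0.5, 0.5])  # matches the row sums
b = np.array([0.5, 0.5])  # column sums are [0.4, 0.6]
threshold = 1e-3
for p in (1, 2):
    err = marginal_error(P, a, b, norm_error=p)
    # approximately 0.2 for p=1 and sqrt(0.02) for p=2: not converged
    print(p, err, err < threshold)

rng = np.random.default_rng(0)
C = rng.uniform(size=(3, 4))
f, g, eps = rng.normal(size=3), rng.normal(size=4), 0.5
f_c, g_c = recenter_balanced(f, g)
print(np.mean(f_c), np.max(np.abs(plan(f, g, C, eps) - plan(f_c, g_c, C, eps))))
```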
Methods
  - init_state(ot_prob, init) – Return the initial state of the loop.
  - kernel_step(ot_prob, state, iteration) – Sinkhorn multiplicative update.
  - lse_step(ot_prob, state, iteration) – Sinkhorn LSE update.
  - one_iteration(ot_prob, state, iteration, ...) – Carries out one Sinkhorn iteration.
  - output_from_state(ot_prob, state) – Create an output from a loop state.
Attributes
Powers used to compute the p-norm between marginal/target.
Upper bound on number of times inner_iterations are carried out.