ott.solvers.linear.sinkhorn.Sinkhorn
- class ott.solvers.linear.sinkhorn.Sinkhorn(lse_mode=True, threshold=0.001, norm_error=1, inner_iterations=10, min_iterations=0, max_iterations=2000, momentum=None, anderson=None, parallel_dual_updates=False, recenter_potentials=False, use_danskin=None, implicit_diff=ImplicitDiff(solver=None, solver_kwargs=None, symmetric=False, precondition_fun=None), initializer='default', progress_fn=None, kwargs_init=None)
Sinkhorn solver.
The Sinkhorn algorithm is a fixed point iteration that solves a regularized optimal transport (reg-OT) problem between two measures. The optimization variables are a pair of vectors (called potentials, or scalings when parameterized as exponentials of the former). Calling this function therefore returns a pair of optimal vectors. In addition, it returns the objective value achieved by these optimal vectors; a vector of size max_iterations/inner_iterations that records the values of the convergence criterion throughout the execution of the algorithm (padded with -1 if convergence happens earlier); and a boolean signifying whether the algorithm has converged within the number of iterations specified by the user.
The reg-OT problem is specified by two measures, of respective sizes n and m. From the viewpoint of the sinkhorn function, these two measures are only seen through a triplet (geom, a, b), where geom is a Geometry object, and a and b are weight vectors of respective sizes n and m. Starting from two initial values for those potentials or scalings (both can be defined by the user by passing values in init_dual_a or init_dual_b), the Sinkhorn algorithm will use elementary operations that are carried out by the geom object.
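As a concrete illustration, here is a minimal sketch of such a call, assuming a recent ott-jax layout in which geometries live in ott.geometry and problems in ott.problems.linear:

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Two point clouds of sizes n and m, with uniform weight vectors a and b.
x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (17, 2)) + 1.0
a = jnp.ones(13) / 13
b = jnp.ones(17) / 17

geom = pointcloud.PointCloud(x, y, epsilon=1e-1)      # the Geometry object
prob = linear_problem.LinearProblem(geom, a=a, b=b)   # the (geom, a, b) triplet

solver = sinkhorn.Sinkhorn(threshold=1e-3, max_iterations=2000)
out = solver(prob)

out.f, out.g       # pair of optimal potentials, of sizes n and m
out.reg_ot_cost    # objective value achieved by these potentials
out.errors         # convergence criterion values, padded with -1
out.converged      # whether the threshold was reached in time
```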
- Math:
Given a geometry geom, which provides a cost matrix \(C\) with its regularization parameter \(\varepsilon\) (or a kernel matrix \(K\)), the reg-OT problem consists in finding two vectors f, g of sizes n, m that maximize the following criterion:
\[\arg\max_{f, g} - \langle a, \phi_a^{*}(-f) \rangle - \langle b, \phi_b^{*}(-g) \rangle - \varepsilon \left\langle e^{f/\varepsilon}, e^{-C/\varepsilon} e^{g/\varepsilon} \right\rangle\]
where \(\phi_a(z) = \rho_a z(\log z - 1)\) is a scaled entropy, and \(\phi_a^{*}(z) = \rho_a e^{z/\rho_a}\) its Legendre transform.
That problem can also be written, instead, using positive scaling vectors u, v of sizes n, m, handled with the kernel \(K := e^{-C/\varepsilon}\):
\[\arg\max_{u, v > 0} - \langle a, \phi_a^{*}(-\varepsilon\log u) \rangle - \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle - \langle u, K v \rangle\]
Both of these problems correspond, in their primal formulation, to solving the unbalanced optimal transport problem with a variable matrix \(P\) of size n x m:
\[\arg\min_{P>0} \langle P, C \rangle + \varepsilon \text{KL}(P | ab^T) + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n | b)\]
where \(KL\) is the generalized Kullback-Leibler divergence.
The very same primal problem can also be written using a kernel \(K\) instead of a cost \(C\):
\[\arg\min_{P} \varepsilon \text{KL}(P|K) + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n | b)\]
The original OT problem taught in linear programming courses is recovered by using the formulation above relying on the cost \(C\), and letting \(\varepsilon \rightarrow 0\) and \(\rho_a, \rho_b \rightarrow \infty\). In that case the entropy disappears, whereas the \(KL\) regularizations above become constraints on the marginals of \(P\): this results in a standard min cost flow problem. This problem is not handled for now in this toolbox, which focuses exclusively on the case \(\varepsilon > 0\).
The balanced regularized OT problem is recovered for finite \(\varepsilon > 0\) by letting \(\rho_a, \rho_b \rightarrow \infty\). This problem can be shown to be equivalent to a matrix scaling problem, which can be solved using the Sinkhorn fixed-point algorithm. To handle the case \(\rho_a, \rho_b \rightarrow \infty\), the sinkhorn function uses parameters tau_a and tau_b, equal respectively to \(\rho_a /(\varepsilon + \rho_a)\) and \(\rho_b / (\varepsilon + \rho_b)\), instead. Setting either of these parameters to 1 corresponds to setting the corresponding \(\rho_a, \rho_b\) to \(\infty\).
The Sinkhorn algorithm solves the reg-OT problem by seeking optimal \(f\), \(g\) potentials (or alternatively their parameterization as positive scaling vectors \(u\), \(v\)), rather than solving the primal problem in \(P\). This is mostly for efficiency (potentials and scalings have an n + m memory footprint, rather than the n m required to store P). This is also because both problems are, in fact, equivalent, since the optimal transport plan \(P^{\star}\) can be recovered from optimal potentials \(f^{\star}\), \(g^{\star}\) or scalings \(u^{\star}\), \(v^{\star}\), using the geometry's cost or kernel matrix respectively:
\[P^{\star} = \exp\left(\frac{f^{\star}\mathbf{1}_m^T + \mathbf{1}_n g^{\star T} - C}{\varepsilon}\right) \text{ or } P^{\star} = \text{diag}(u^{\star}) K \text{diag}(v^{\star})\]
By default, the Sinkhorn algorithm solves this dual problem in \(f, g\) or \(u, v\) using block coordinate ascent, i.e. devising an update for each \(f\) and \(g\) (resp. \(u\) and \(v\)) that cancels their respective gradients, one at a time. These two iterations are repeated inner_iterations times, after which the norm of these gradients is evaluated and compared with the threshold value. The iterations are then repeated as long as that error exceeds threshold.
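Continuing the sketch above, the coupling can be materialized from the returned potentials, either through the output object or written out explicitly from the formula above:

```python
# Reusing ``geom`` and ``out`` from the first sketch.
P = out.matrix  # n x m coupling, reconstructed from the potentials and the geometry

# The same quantity, spelled out from the display formula above; only advisable
# for small problems, since it materializes the full cost matrix.
P_manual = jnp.exp(
    (out.f[:, None] + out.g[None, :] - geom.cost_matrix) / geom.epsilon
)
```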
- Note on Sinkhorn updates:
The boolean flag lse_mode sets whether the algorithm is run in either:
- log-sum-exp mode (lse_mode=True), in which case it is directly defined in terms of updates to f and g, using log-sum-exp computations. This requires access to the cost matrix \(C\), as it is stored, or possibly computed on the fly by geom.
- kernel mode (lse_mode=False), in which case it will require access to a matrix-vector multiplication operator \(z \rightarrow K z\), where \(K\) is either instantiated from \(C\) as \(\exp(-C/\varepsilon)\), or provided directly. In that case, rather than optimizing on \(f\) and \(g\), it is more convenient to optimize on their so-called scaling formulations, \(u := \exp(f / \varepsilon)\) and \(v := \exp(g / \varepsilon)\). While faster (applying matrices is faster than applying lse repeatedly over lines), this mode is also less stable numerically, notably for smaller \(\varepsilon\), as illustrated in the sketch below.
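A sketch of the two modes, reusing prob from the first sketch; both return potentials f and g, kernel mode converting its scalings back at the end:

```python
# Log-sum-exp updates on potentials f, g: slower but numerically robust.
out_lse = sinkhorn.Sinkhorn(lse_mode=True)(prob)

# Multiplicative updates on scalings u, v: faster, but can overflow or
# underflow when epsilon is small relative to the cost values.
out_kernel = sinkhorn.Sinkhorn(lse_mode=False)(prob)

out_lse.f, out_kernel.f  # both outputs expose potentials, whatever the mode
```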
In the source code, the variables f_u or g_v can be regarded either as potentials (real) or scalings (positive) vectors, depending on the choice of lse_mode by the user. Once optimization is carried out, we only return dual variables in potential form, i.e. f and g.
In addition to standard Sinkhorn updates, the user can also use heavy-ball type updates using a momentum parameter in ]0,2[. We also implement a strategy that tries to set that parameter adaptively, from iteration chg_momentum_from onwards, as a function of progress in the error, as proposed in the literature.
Another upgrade to the standard Sinkhorn updates provided to the user lies in using Anderson acceleration. This can be parameterized by setting the otherwise None anderson parameter to an AndersonAcceleration instance. When selected, the algorithm will recompute, every refresh_anderson_frequency (set by default to 1) iterations, an extrapolation of the most recently computed anderson iterates. When using that option, note that differentiation (if required) can only be carried out using implicit differentiation, and that all momentum related parameters are ignored.
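A sketch of how these options can be passed, assuming the Momentum and AndersonAcceleration helpers from ott.solvers.linear.acceleration (field names may differ across versions):

```python
from ott.solvers.linear import acceleration, sinkhorn

# Heavy-ball / over-relaxation with a fixed momentum value in ]0, 2[.
solver_momentum = sinkhorn.Sinkhorn(momentum=acceleration.Momentum(value=1.3))

# Anderson acceleration, extrapolating over a window of past iterates.
solver_anderson = sinkhorn.Sinkhorn(
    anderson=acceleration.AndersonAcceleration(memory=5)
)
```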
The parallel_dual_updates flag is set to False by default. In that setting, g_v is first updated using the latest values for f_u and g_v, before proceeding to update f_u using that new value for g_v. When the flag is set to True, both f_u and g_v are updated simultaneously. Note that setting that choice to True requires using some form of averaging (e.g. momentum=0.5); without this, parallel_dual_updates on its own won't work.
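For example, a sketch of a parallel-update configuration, paired with the averaging it requires (again assuming the Momentum helper mentioned above):

```python
from ott.solvers.linear import acceleration, sinkhorn

solver_parallel = sinkhorn.Sinkhorn(
    parallel_dual_updates=True,
    # Jacobi-style simultaneous updates need some averaging to converge,
    # e.g. a momentum value of 0.5.
    momentum=acceleration.Momentum(value=0.5),
)
```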
- Differentiation:
The optimal solutions f and g and the optimal objective (reg_ot_cost) output by the Sinkhorn algorithm can be differentiated w.r.t. relevant inputs geom, a and b. In the default setting, implicit differentiation of the optimality conditions (implicit_diff not equal to None), this has two consequences, treating f and g differently from reg_ot_cost.
The termination criterion used to stop Sinkhorn (cancellation of the gradient of the objective w.r.t. f_u and g_v) is used to differentiate f and g, given a change in the inputs. These changes are computed by solving a linear system. The arguments starting with implicit_solver_* allow defining the linear solver that is used, and controlling two types of regularization (we have observed that, depending on the architecture, linear solves may require higher ridge parameters to remain stable). The optimality conditions in Sinkhorn can be analyzed as satisfying a z=z' condition, which is then differentiated. It might be beneficial (e.g., as in [Cuturi et al., 2020]) to use a preconditioning function precondition_fun to differentiate instead h(z) = h(z').
The objective reg_ot_cost returned by Sinkhorn uses the so-called envelope (or Danskin's) theorem. In that case, because it is assumed that the gradients of the dual variables f_u and g_v w.r.t. the dual objective are zero (reflecting the fact that they are optimal), small variations in f_u and g_v due to changes in inputs (such as geom, a and b) are considered negligible. As a result, stop_gradient is applied on the dual variables f_u and g_v when evaluating the reg_ot_cost objective. Note that this approach is invalid when computing higher order derivatives. In that case the use_danskin flag must be set to False.
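As an illustration of this default behavior (implicit differentiation combined with Danskin's theorem), a self-contained sketch differentiating the regularized OT cost with respect to an input point cloud:

```python
import jax

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def reg_ot_cost(x, y):
  geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
  prob = linear_problem.LinearProblem(geom)
  # With the defaults (an ImplicitDiff instance and Danskin's theorem), this
  # gradient is obtained by solving a linear system at the optimum rather
  # than by unrolling the Sinkhorn loop.
  return sinkhorn.Sinkhorn()(prob).reg_ot_cost

x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (17, 2))
grad_x = jax.grad(reg_ot_cost)(x, y)   # shape (13, 2)
```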
An alternative, yet more costly, way to differentiate the outputs of the Sinkhorn iterations is to use unrolling, i.e. reverse mode differentiation of the Sinkhorn loop. This is possible because Sinkhorn iterations are wrapped in a custom fixed point iteration loop, defined in fixed_point_loop, rather than a standard while loop. This is to ensure the end result of this fixed point loop can also be differentiated, if needed, using standard JAX operations. To ensure differentiability, the fixed_point_loop.fixpoint_iter_backprop loop does checkpointing of state variables (here f_u and g_v) every inner_iterations, and backpropagates automatically, block by block, through blocks of inner_iterations at a time.
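To opt into unrolling, the implicit differentiation instance can simply be dropped; a sketch:

```python
from ott.solvers.linear import sinkhorn

# Dropping the ImplicitDiff instance switches gradients of f, g and
# reg_ot_cost to unrolled (reverse-mode) differentiation through the
# checkpointed fixed point loop.
solver_unrolled = sinkhorn.Sinkhorn(implicit_diff=None)
```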
Note
The Sinkhorn algorithm may not converge within the maximum number of iterations for possibly several reasons:
- the regularizer (defined as epsilon in the geometry geom object) is too small. Consider either switching to lse_mode=True (at the price of a slower execution), increasing epsilon, or, alternatively, if you are unable or unwilling to increase epsilon, either increase max_iterations or threshold.
- the probability weights a and b do not have the same total mass, while using a balanced (tau_a=tau_b=1.0) setup. Consider either normalizing a and b, or setting either tau_a and/or tau_b<1.0.
- OOM issues may arise when storing either cost or kernel matrices that are too large in geom. In the case where the geom geometry is a PointCloud, some of these issues might be solved by setting the online flag to True. This will trigger a re-computation on the fly of the cost/kernel matrix, as sketched below.
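A sketch of that last point; depending on the ott version, on-the-fly computation is requested either through an online flag or through a batch_size argument on PointCloud, the latter being shown here:

```python
from ott.geometry import pointcloud

# Reusing x, y from the first sketch: cost values are recomputed on the fly
# in blocks instead of materializing the full n x m matrix.
geom_online = pointcloud.PointCloud(x, y, epsilon=1e-1, batch_size=1024)
```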
The weight vectors a and b can be passed on with coordinates that have zero weight. This is then handled by relying on simple arithmetic for inf values that will likely arise (due to \(\log 0\) when lse_mode is True, or divisions by zero when lse_mode is False). Whenever that arithmetic is likely to produce NaN values (due to -inf * 0, or -inf - -inf) in the forward pass, we use jnp.where conditional statements to carry inf rather than NaN values. In the reverse mode differentiation, the inputs corresponding to these 0 weights (a location x, or a row in the corresponding cost/kernel matrix), and the weight itself, will have NaN gradient values. This reflects the fact that these gradients are undefined, since these points were not considered in the optimization and therefore have no impact on the output.
- Parameters
- lse_mode (bool) – True for log-sum-exp computations, False for kernel multiplication.
- threshold (float) – tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the current primal solution when either or both tau_a and tau_b are 1.0 (balanced or semi-balanced problem), or the relative change between two successive solutions in the unbalanced case.
- norm_error (int) – power used to define the p-norm of the error for marginal/target.
- inner_iterations (int) – the Sinkhorn error is not recomputed at each iteration but every inner_iterations instead.
- min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.
- max_iterations (int) – the maximum number of Sinkhorn iterations. If max_iterations is equal to min_iterations, Sinkhorn iterations are run by default using a jax.lax.scan() loop rather than a custom, unroll-able jax.lax.while_loop() that monitors convergence. In that case the error is not monitored and the converged flag will return False as a consequence.
- anderson (Optional[AndersonAcceleration]) – AndersonAcceleration instance.
- implicit_diff (Optional[ImplicitDiff]) – instance used to solve implicit differentiation. Unrolls iterations if None.
- parallel_dual_updates (bool) – updates potentials or scalings in parallel if True, sequentially (in Gauss-Seidel fashion) if False.
- recenter_potentials (bool) – whether to re-center the dual potentials. If the problem is balanced, the f potential is zero-centered for numerical stability. Otherwise, use the approach of [Sejourne et al., 2022] to achieve faster convergence. Only used when lse_mode = True and tau_a < 1 and tau_b < 1.
- use_danskin (Optional[bool]) – when True, it is assumed the entropy regularized cost is evaluated using optimal potentials that are frozen, i.e. whose gradients have been stopped. This is useful when carrying out first order differentiation, and is only valid (as with implicit_differentiation) when the algorithm has converged with a low tolerance.
- initializer (Union[Literal['default', 'gaussian', 'sorting', 'subsample'], SinkhornInitializer]) – how to compute the initial potentials/scalings. This refers to a few possible classes implemented following the template in SinkhornInitializer.
- progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, SinkhornState]], None]]) – callback function which gets called during the Sinkhorn iterations, so the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.
- kwargs_init (Optional[Mapping[str, Any]]) – keyword arguments when creating the initializer.
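As an illustration of how several of these parameters combine, a sketch with a hand-rolled callback (print_progress is hypothetical; the tuple layout follows the progress_fn type above, and the library also ships a default_progress_fn):

```python
from ott.solvers.linear import sinkhorn

def print_progress(status):
  # status is a (current iteration, inner_iterations, max_iterations, state)
  # tuple; the error trace lives in ``state.errors``.
  iteration, inner_iterations, max_iterations, state = status
  print(f"iteration {int(iteration)} / {int(max_iterations)}")

solver = sinkhorn.Sinkhorn(
    threshold=1e-4,
    inner_iterations=10,
    max_iterations=5_000,
    initializer="gaussian",
    progress_fn=print_progress,
)
```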
Methods
- create_initializer() – rtype: SinkhornInitializer
- init_state(ot_prob, init) – Return the initial state of the loop.
- kernel_step(ot_prob, state, iteration) – Sinkhorn multiplicative update.
- lse_step(ot_prob, state, iteration) – Sinkhorn LSE update.
- one_iteration(ot_prob, state, iteration, ...) – Carries out one Sinkhorn iteration.
- output_from_state(ot_prob, state) – Create an output from a loop state.
Attributes
- norm_error – Powers used to compute the p-norm between marginal/target.
- outer_iterations – Upper bound on the number of times inner_iterations are carried out.