ott.solvers.linear.sinkhorn.Sinkhorn

class ott.solvers.linear.sinkhorn.Sinkhorn(lse_mode=True, threshold=0.001, norm_error=1, inner_iterations=10, min_iterations=0, max_iterations=2000, momentum=None, anderson=None, parallel_dual_updates=False, recenter_potentials=False, use_danskin=None, implicit_diff=ImplicitDiff(solver=None, solver_kwargs=None, symmetric=False, precondition_fun=None), initializer='default', progress_fn=None, kwargs_init=None)

Sinkhorn solver.

The Sinkhorn algorithm is a fixed point iteration that solves a regularized optimal transport (reg-OT) problem between two measures. The optimization variables are a pair of vectors (called potentials, or scalings when parameterized as exponentials of the former). Calling this function therefore returns a pair of optimal vectors. In addition, it also returns the objective value achieved by these optimal vectors; a vector of size max_iterations/inner_iterations that records the error values computed to monitor convergence throughout the execution of the algorithm (padded with -1 if convergence happens earlier); and a boolean signifying whether the algorithm has converged within the number of iterations specified by the user.

The reg-OT problem is specified by two measures, of respective sizes n and m. From the viewpoint of the sinkhorn function, these two measures are only seen through a triplet (geom, a, b), where geom is a Geometry object, and a and b are weight vectors of respective sizes n and m. Starting from two initial values for those potentials or scalings (both can be defined by the user by passing values as init_dual_a or init_dual_b), the Sinkhorn algorithm will use elementary operations that are carried out by the geom object.
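For concreteness, here is a minimal usage sketch; the module paths and the output attributes f, g, reg_ot_cost and converged are assumed from a recent ott-jax release:

    import jax
    import jax.numpy as jnp
    from ott.geometry import pointcloud
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    # Two point clouds of sizes n and m, with uniform weight vectors a and b.
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (30, 2))
    y = jax.random.normal(key_y, (50, 2))
    a = jnp.ones(30) / 30
    b = jnp.ones(50) / 50

    geom = pointcloud.PointCloud(x, y, epsilon=1e-1)      # the geom object
    prob = linear_problem.LinearProblem(geom, a=a, b=b)   # the (geom, a, b) triplet

    solver = sinkhorn.Sinkhorn(threshold=1e-3)
    out = solver(prob)  # initial potentials/scalings can also be passed via init

    print(out.f.shape, out.g.shape)        # optimal potentials, of sizes n and m
    print(out.reg_ot_cost, out.converged)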

Math:

Given a geometry geom, which provides a cost matrix \(C\) with its regularization parameter \(\varepsilon\) (or a kernel matrix \(K\)), the reg-OT problem consists in finding two vectors f, g of size n, m that maximize the following criterion:

\[\arg\max_{f, g} - \langle a, \phi_a^{*}(-f) \rangle - \langle b, \phi_b^{*}(-g) \rangle - \varepsilon \langle e^{f/\varepsilon}, e^{-C/\varepsilon} e^{g/\varepsilon} \rangle\]

where \(\phi_a(z) = \rho_a z(\log z - 1)\) is a scaled entropy, and \(\phi_a^{*}(z) = \rho_a e^{z/\rho_a}\) its Legendre transform.
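For reference, this expression follows from a one-line computation: setting the derivative of \(wz - \rho_a z(\log z - 1)\) with respect to \(z\) to zero gives \(z^{\star} = e^{w/\rho_a}\), and substituting back yields

\[\phi_a^{*}(w) = \sup_{z>0}\, wz - \rho_a z(\log z - 1) = w e^{w/\rho_a} - \rho_a e^{w/\rho_a}\left(\frac{w}{\rho_a} - 1\right) = \rho_a e^{w/\rho_a}.\]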

That problem can also be written, instead, using positive scaling vectors u, v of size n, m, handled with the kernel \(K := e^{-C/\varepsilon}\),

\[\arg\max_{u, v > 0} - \langle a, \phi_a^{*}(-\varepsilon\log u) \rangle - \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle - \varepsilon \langle u, K v \rangle\]

Both of these problems correspond, in their primal formulation, to solving the unbalanced optimal transport problem with a variable matrix \(P\) of size n x m:

\[\arg\min_{P>0} \langle P,C \rangle + \varepsilon \text{KL}(P | ab^T) + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n | b)\]

where \(KL\) is the generalized Kullback-Leibler divergence.

The very same primal problem can also be written using a kernel \(K\) instead of a cost \(C\) as well:

\[\arg\min_{P} \varepsilon \text{KL}(P|K) + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n | b)\]

The original OT problem taught in linear programming courses is recovered by using the formulation above relying on the cost \(C\), letting \(\varepsilon \rightarrow 0\) and \(\rho_a, \rho_b \rightarrow \infty\). In that case the entropy term disappears, whereas the \(KL\) regularizations above become constraints on the marginals of \(P\): this results in a standard minimum-cost flow problem. That problem is not handled for now in this toolbox, which focuses exclusively on the case \(\varepsilon > 0\).

The balanced regularized OT problem is recovered for a finite \(\varepsilon > 0\) by letting \(\rho_a, \rho_b \rightarrow \infty\). This problem can be shown to be equivalent to a matrix scaling problem, which can be solved using the Sinkhorn fixed-point algorithm. To handle the case \(\rho_a, \rho_b \rightarrow \infty\), the sinkhorn function uses parameters tau_a and tau_b equal respectively to \(\rho_a /(\varepsilon + \rho_a)\) and \(\rho_b / (\varepsilon + \rho_b)\) instead. Setting either of these parameters to 1 corresponds to setting the corresponding \(\rho_a, \rho_b\) to \(\infty\).
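The conversion between the two parameterizations is elementary; the helper names below are illustrative, not part of the API:

    import jax.numpy as jnp

    def rho_to_tau(rho: float, epsilon: float) -> float:
      """Map an unbalanced strength rho to the tau parameter expected by the solver."""
      return rho / (epsilon + rho)

    def tau_to_rho(tau: float, epsilon: float) -> float:
      """Inverse map; tau = 1.0 corresponds to rho = infinity, i.e. a hard marginal."""
      return jnp.inf if tau == 1.0 else epsilon * tau / (1.0 - tau)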

The Sinkhorn algorithm solves the reg-OT problem by seeking optimal \(f\), \(g\) potentials (or alternatively their parameterization as positive scaling vectors \(u\), \(v\)), rather than solving the primal problem in \(P\). This is mostly for efficiency (potentials and scalings have an n + m memory footprint, rather than the n m required to store P), but also because both problems are in fact equivalent, since the optimal transport plan \(P^{\star}\) can be recovered from optimal potentials \(f^{\star}\), \(g^{\star}\) or scalings \(u^{\star}\), \(v^{\star}\), using the geometry’s cost or kernel matrix respectively:

\[P^{\star} = \exp\left(\frac{f^{\star}\mathbf{1}_m^T + \mathbf{1}_n (g^{\star})^T - C}{\varepsilon}\right) \text{ or } P^{\star} = \text{diag}(u^{\star}) K \text{diag}(v^{\star})\]
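In code, reusing geom and out from the earlier sketch (and assuming the geometry is small enough to materialize its cost matrix; recent ott-jax releases also expose the plan directly as out.matrix):

    import jax.numpy as jnp

    C = geom.cost_matrix            # n x m cost matrix
    eps = geom.epsilon
    f, g = out.f, out.g             # optimal potentials returned by the solver

    # Potential form: P = exp((f 1_m^T + 1_n g^T - C) / epsilon).
    P = jnp.exp((f[:, None] + g[None, :] - C) / eps)

    # Scaling form: P = diag(u) K diag(v), with u = exp(f/eps), v = exp(g/eps).
    u, v = jnp.exp(f / eps), jnp.exp(g / eps)
    K = jnp.exp(-C / eps)
    P_alt = u[:, None] * K * v[None, :]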

By default, the Sinkhorn algorithm solves this dual problem in \(f, g\) or \(u, v\) using block coordinate ascent, i.e. devising an update for each \(f\) and \(g\) (resp. \(u\) and \(v\)) that cancels their respective gradients, one at a time. These two iterations are repeated inner_iterations times, after which the norm of these gradients will be evaluated and compared with the threshold value. The iterations are then repeated as long as that error exceeds threshold.
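For intuition, in the balanced case (tau_a = tau_b = 1) these block-coordinate updates reduce to the following sketch, written directly on a dense cost matrix and ignoring the batching and numerical safeguards the solver implements:

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def lse_updates(C, a, b, eps, num_iters=100):
      """Plain log-sum-exp Sinkhorn updates on the potentials f, g."""
      f = jnp.zeros(C.shape[0])
      g = jnp.zeros(C.shape[1])
      for _ in range(num_iters):
        # Each update cancels the gradient of the dual objective in f (resp. g).
        f = eps * jnp.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
        g = eps * jnp.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
      return f, g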

Note on Sinkhorn updates:

The boolean flag lse_mode sets whether the algorithm is run in either:

  • log-sum-exp mode (lse_mode=True), in which case it is directly defined in terms of updates to f and g, using log-sum-exp computations. This requires access to the cost matrix \(C\), as it is stored, or possibly computed on the fly by geom.

  • kernel mode (lse_mode=False), in which case it will require access to a matrix-vector multiplication operator \(z \rightarrow K z\), where \(K\) is either instantiated from \(C\) as \(\exp(-C/\varepsilon)\), or provided directly. In that case, rather than optimizing on \(f\) and \(g\), it is more convenient to optimize on their so-called scaling formulations, \(u := \exp(f / \varepsilon)\) and \(v := \exp(g / \varepsilon)\) (a bare-bones sketch of these multiplicative updates is given below). While faster (applying a matrix-vector product is faster than applying log-sum-exp repeatedly over rows), this mode is also less stable numerically, notably for smaller \(\varepsilon\).

In the source code, the variables f_u and g_v can be regarded either as potential (real-valued) or scaling (positive) vectors, depending on the user’s choice of lse_mode. Once optimization is carried out, dual variables are only returned in potential form, i.e. f and g.
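The kernel-mode counterpart of the updates sketched earlier is the classical matrix-scaling iteration on u and v, again written here for the balanced case only:

    import jax.numpy as jnp

    def kernel_updates(C, a, b, eps, num_iters=100):
      """Multiplicative Sinkhorn updates on the scalings u, v, with K = exp(-C/eps)."""
      K = jnp.exp(-C / eps)
      u = jnp.ones_like(a)
      v = jnp.ones_like(b)
      for _ in range(num_iters):
        u = a / (K @ v)       # enforce the row marginals
        v = b / (K.T @ u)     # enforce the column marginals
      # Potentials can be recovered as f = eps * log(u) and g = eps * log(v).
      return u, v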

In addition to standard Sinkhorn updates, the user can also use heavy-ball type updates with a momentum parameter in ]0,2[. We also implement a strategy that sets that parameter adaptively from a given iteration onwards (configured on the Momentum instance passed as momentum), as a function of progress in the error, as proposed in the literature.

Another upgrade to the standard Sinkhorn updates provided to the user lies in using Anderson acceleration. This can be parameterized by passing an AndersonAcceleration instance as the otherwise None anderson argument. When selected, the algorithm will recompute, at a refresh frequency configured on that instance (every iteration by default), an extrapolation of the most recently computed iterates. When using that option, notice that differentiation (if required) can only be carried out using implicit differentiation, and that all momentum-related parameters are ignored.

The parallel_dual_updates flag is set to False by default. In that setting, g_v is first updated using the latest values for f_u and g_v, before proceeding to update f_u using that new value for g_v. When the flag is set to True, both f_u and g_v are updated simultaneously. Note that setting it to True requires using some form of averaging (e.g. momentum=0.5): on its own, parallel_dual_updates will not converge.
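As a configuration sketch, assuming the Momentum and AndersonAcceleration helpers exposed by ott.solvers.linear.acceleration in recent ott-jax releases (constructor arguments may differ across versions):

    from ott.solvers.linear import acceleration, sinkhorn

    # Heavy-ball style updates with a fixed momentum value in ]0, 2[.
    solver_momentum = sinkhorn.Sinkhorn(momentum=acceleration.Momentum(value=1.3))

    # Parallel dual updates require some form of averaging, e.g. momentum = 0.5.
    solver_parallel = sinkhorn.Sinkhorn(
        parallel_dual_updates=True,
        momentum=acceleration.Momentum(value=0.5),
    )

    # Anderson acceleration; momentum-related parameters are then ignored.
    solver_anderson = sinkhorn.Sinkhorn(
        anderson=acceleration.AndersonAcceleration(memory=5),
    )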

Differentiation:

The optimal solutions f and g and the optimal objective (reg_ot_cost) output by the Sinkhorn algorithm can be differentiated w.r.t. relevant inputs geom, a and b. In the default setting, which differentiates the optimality conditions implicitly (implicit_diff not equal to None), this has two consequences, treating f and g differently from reg_ot_cost (a short differentiation sketch is given at the end of this section).

  • The termination criterion used to stop Sinkhorn (cancellation of the gradient of the objective w.r.t. f_u and g_v) is used to differentiate f and g, given a change in the inputs. These changes are computed by solving a linear system. The arguments starting with implicit_solver_* make it possible to define the linear solver that is used, and to control two types of regularization (we have observed that, depending on the architecture, linear solves may require higher ridge parameters to remain stable). The optimality conditions in Sinkhorn can be analyzed as satisfying a z=z' condition, which is then differentiated. It might be beneficial (e.g., as in [Cuturi et al., 2020]) to use a preconditioning function precondition_fun and differentiate instead h(z) = h(z').

  • The objective reg_ot_cost returned by Sinkhorn uses the so-called envelope (or Danskin’s) theorem. In that case, because it is assumed that the gradients of the dual objective w.r.t. the dual variables f_u and g_v are zero (reflecting the fact that they are optimal), small variations in f_u and g_v due to changes in inputs (such as geom, a and b) are considered negligible. As a result, stop_gradient is applied to the dual variables f_u and g_v when evaluating the reg_ot_cost objective. Note that this approach is invalid when computing higher-order derivatives; in that case the use_danskin flag must be set to False.

An alternative, more costly way to differentiate the outputs of the Sinkhorn iterations is unrolling, i.e. reverse-mode differentiation of the Sinkhorn loop. This is possible because Sinkhorn iterations are wrapped in a custom fixed-point iteration loop, defined in fixed_point_loop, rather than a standard while loop, so that the end result of this fixed-point loop can also be differentiated, if needed, using standard JAX operations. To ensure differentiability, the fixed_point_loop.fixpoint_iter_backprop loop checkpoints the state variables (here f_u and g_v) every inner_iterations steps, and backpropagates automatically, block by block, through blocks of inner_iterations at a time.
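A minimal differentiation sketch, assuming the same module paths as above: reg_ot maps point positions to the regularized OT cost, and jax.grad differentiates it through the optimality conditions (the default implicit_diff), or through the unrolled loop when implicit_diff=None:

    import jax
    import jax.numpy as jnp
    from ott.geometry import pointcloud
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    def reg_ot(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
      geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
      prob = linear_problem.LinearProblem(geom)          # uniform weights
      out = sinkhorn.Sinkhorn()(prob)                    # implicit diff by default
      return out.reg_ot_cost

    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (31, 2))
    y = jax.random.normal(key_y, (44, 2))
    grad_x = jax.grad(reg_ot)(x, y)                      # same shape as x

    # Unrolling instead of implicit differentiation:
    unrolled_solver = sinkhorn.Sinkhorn(implicit_diff=None)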

Note

  • The Sinkhorn algorithm may not converge within the maximum number of iterations for several possible reasons (a few remedies are sketched in code after this note):

    1. the regularizer (defined as epsilon in the geometry geom object) is too small. Consider switching to lse_mode=True (at the price of slower execution), increasing epsilon, or, if you are unable or unwilling to increase epsilon, increasing max_iterations or threshold.

    2. the probability weights a and b do not have the same total mass, while using a balanced (tau_a=tau_b=1.0) setup. Consider either normalizing a and b, or setting tau_a and/or tau_b to a value smaller than 1.0.

    3. OOM issues may arise when storing either cost or kernel matrices that are too large in geom. When the geom geometry is a PointCloud, some of these issues might be solved by setting the online flag to True, which triggers re-computation of the cost/kernel matrix on the fly.

  • The weight vectors a and b can be passed on with coordinates that have zero weight. This is then handled by relying on simple arithmetic for inf values that will likely arise (due to \(\log 0\) when lse_mode is True, or divisions by zero when lse_mode is False). Whenever that arithmetic is likely to produce NaN values (due to -inf * 0, or -inf - -inf) in the forward pass, we use jnp.where conditional statements to carry inf rather than NaN values. In reverse-mode differentiation, the inputs corresponding to these zero weights (a location x, or a row in the corresponding cost/kernel matrix), as well as the weight itself, will have NaN gradient values. This reflects the fact that these gradients are undefined, since these points were not considered in the optimization and therefore have no impact on the output.
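A few of the remedies above, in code; this sketch reuses the x, y, a and b arrays from the earlier examples:

    import jax.numpy as jnp
    from ott.geometry import pointcloud
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    # 1. If convergence stalls, increase epsilon (or max_iterations / threshold).
    geom = pointcloud.PointCloud(x, y, epsilon=5e-1)

    # 2. In the balanced setting, make sure a and b have the same total mass ...
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    prob = linear_problem.LinearProblem(geom, a=a, b=b)

    # ... or relax the marginal constraints (unbalanced problem).
    prob_unbalanced = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=0.9, tau_b=0.95)

    out = sinkhorn.Sinkhorn(max_iterations=5_000, threshold=1e-3)(prob)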

Parameters:
  • lse_mode (bool) – True for log-sum-exp computations, False for kernel multiplication.

  • threshold (float) – tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the current primal solution when either or both tau_a and tau_b are 1.0 (balanced or semi-balanced problem), or the relative change between two successive solutions in the unbalanced case.

  • norm_error (int) – power used to define p-norm of error for marginal/target.

  • inner_iterations (int) – the Sinkhorn error is not recomputed at each iteration but only every inner_iterations iterations.

  • min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.

  • max_iterations (int) – the maximum number of Sinkhorn iterations. If max_iterations equals min_iterations, Sinkhorn iterations are run with a jax.lax.scan() loop rather than the custom, unroll-able jax.lax.while_loop() that monitors convergence. In that case the error is not monitored and the converged flag returns False as a consequence (see the configuration sketch after this parameter list).

  • momentum (Optional[Momentum]) – Momentum instance.

  • anderson (Optional[AndersonAcceleration]) – AndersonAcceleration instance.

  • implicit_diff (Optional[ImplicitDiff]) – instance used to solve implicit differentiation. Unrolls iterations if None.

  • parallel_dual_updates (bool) – updates potentials or scalings in parallel if True, sequentially (in Gauss-Seidel fashion) if False.

  • recenter_potentials (bool) – Whether to re-center the dual potentials. If the problem is balanced, the f potential is zero-centered for numerical stability. Otherwise, use the approach of [Sejourne et al., 2022] to achieve faster convergence. Only used when lse_mode = True and tau_a < 1 and tau_b < 1.

  • use_danskin (Optional[bool]) – when True, it is assumed the entropy regularized cost is evaluated using optimal potentials that are frozen, i.e. whose gradients have been stopped. This is useful when carrying out first-order differentiation, and is only valid (as with implicit differentiation) when the algorithm has converged with a low tolerance.

  • initializer (Union[Literal['default', 'gaussian', 'sorting', 'subsample'], SinkhornInitializer]) – how to compute the initial potentials/scalings. This refers to a few possible classes implemented following the template in SinkhornInitializer.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, SinkhornState]], None]]) – callback function which gets called during the Sinkhorn iterations, so the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs_init (Optional[Mapping[str, Any]]) – keyword arguments when creating the initializer.
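Putting several of these parameters together (a configuration sketch, not a recommendation; initializer names are those listed above):

    from ott.solvers.linear import sinkhorn

    # Fixed-length run: min_iterations == max_iterations triggers the scan-based
    # loop described above, so convergence is not monitored.
    solver_scan = sinkhorn.Sinkhorn(min_iterations=500, max_iterations=500)

    # Adaptive run with a Gaussian initializer and error checks every 20 iterations.
    solver = sinkhorn.Sinkhorn(
        lse_mode=True,
        threshold=1e-4,
        inner_iterations=20,
        max_iterations=10_000,
        initializer="gaussian",
    )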

Methods

create_initializer()

Return type: SinkhornInitializer

init_state(ot_prob, init)

Return the initial state of the loop.

kernel_step(ot_prob, state, iteration)

Sinkhorn multiplicative update.

lse_step(ot_prob, state, iteration)

Sinkhorn LSE update.

one_iteration(ot_prob, state, iteration, ...)

Carries out one Sinkhorn iteration.

output_from_state(ot_prob, state)

Create an output from a loop state.

Attributes

norm_error

Powers used to compute the p-norm between marginal/target.

outer_iterations

Upper bound on number of times inner_iterations are carried out.