ott.core.sinkhorn.sinkhorn(geom, a=None, b=None, tau_a=1.0, tau_b=1.0, init_dual_a=None, init_dual_b=None, **kwargs)

JAX version of Sinkhorn’s algorithm.

Solves the regularized OT problem using Sinkhorn iterations.

The Sinkhorn algorithm is a fixed-point iteration that solves a regularized optimal transport (reg-OT) problem between two measures. The optimization variables are a pair of vectors (called potentials, or scalings when parameterized as exponentials of the former). Calling this function therefore returns a pair of optimal vectors. In addition to these, sinkhorn also returns the objective value achieved by these optimal vectors; a vector of size max_iterations/inner_iterations that records the errors used to monitor convergence throughout the execution of the algorithm (padded with -1 if convergence happens earlier); and a boolean indicating whether the algorithm converged within the number of iterations specified by the user.

The reg-OT problem is specified by two measures, of respective sizes n and m. From the viewpoint of the sinkhorn function, these two measures are only seen through a triplet (geom, a, b), where geom is a Geometry object, and a and b are weight vectors of respective sizes n and m. Starting from two initial values for those potentials or scalings (both can be set by the user by passing values in init_dual_a or init_dual_b), the Sinkhorn algorithm uses elementary operations that are carried out by the geom object.

Some maths:

Given a geometry geom, which provides a cost matrix \(C\) with its regularization parameter \(\epsilon\) (resp. a kernel matrix \(K\)), the reg-OT problem consists of finding two vectors f, g of sizes n, m that maximize the following criterion:

\[\arg\max_{f, g} - <a, \phi_a^{*}(-f)> - <b, \phi_b^{*}(-g)> - \epsilon <e^{f/\epsilon}, e^{-C/\epsilon} e^{g/\epsilon}>\]

where \(\phi_a(z) = \rho_a z(\log z - 1)\) is a scaled entropy, and \(\phi_a^{*}(z) = \rho_a e^{z/\rho_a}\) its Legendre transform; \(\phi_b\) and \(\phi_b^{*}\) are defined in the same way with \(\rho_b\).
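The Legendre transform of \(\phi_a\) can be checked with a short computation (the case of \(\phi_b\) is identical with \(\rho_b\)): the supremum below is attained where the derivative \(y - \rho_a \log z\) vanishes, i.e. at \(z = e^{y/\rho_a}\), so that

\[\phi_a^{*}(y) = \sup_{z>0} \left( yz - \rho_a z(\log z - 1) \right) = y e^{y/\rho_a} - \rho_a e^{y/\rho_a}\left(\frac{y}{\rho_a} - 1\right) = \rho_a e^{y/\rho_a}.\]

In particular, \(\phi_a^{*}(-f) = \rho_a e^{-f/\rho_a} = \rho_a - f + O(1/\rho_a)\), so when \(\rho_a \rightarrow \infty\) the term \(- <a, \phi_a^{*}(-f)>\) reduces to \(<a, f>\) up to an additive constant.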

That problem can also be written using positive scaling vectors u, v of sizes n, m, handled with the kernel \(K:=e^{-C/\epsilon}\):

\[\arg\max_{u, v >0} - <a,\phi_a^{*}(-\epsilon\log u)> - <b, \phi_b^{*}(-\epsilon\log v)> - \epsilon <u, K v>\]

Both of these problems correspond, in their primal formulation, to solving the unbalanced optimal transport problem with a variable matrix P of size n x m:

\[\arg\min_{P>0} <P,C> + \epsilon \text{KL}(P | ab^T) + \rho_a \text{KL}(P1 | a) + \rho_b \text{KL}(P^T1 | b)\]

where \(KL\) is the generalized Kullback-Leibler divergence.

The very same primal problem can also be written using a kernel \(K\) instead of a cost \(C\):

\[\arg\min_{P} \epsilon \text{KL}(P | K) + \rho_a \text{KL}(P1 | a) + \rho_b \text{KL}(P^T1 | b)\]

The original OT problem taught in linear programming courses is recovered from the cost-based formulation above by letting \(\epsilon \rightarrow 0\) and \(\rho_a, \rho_b \rightarrow \infty\). In that case the entropy term disappears, whereas the \(KL\) regularizations above become constraints on the marginals of \(P\); this results in a standard min-cost flow problem. That problem is not currently handled in this toolbox, which focuses exclusively on the case \(\epsilon > 0\).

The balanced regularized OT problem is recovered for finite \(\epsilon > 0\) while letting \(\rho_a, \rho_b \rightarrow \infty\). This problem can be shown to be equivalent to a matrix scaling problem, which can be solved using the Sinkhorn fixed-point algorithm. To handle the case \(\rho_a, \rho_b \rightarrow \infty\), the sinkhorn function uses the parameters tau_a := \(\rho_a / (\epsilon + \rho_a)\) and tau_b := \(\rho_b / (\epsilon + \rho_b)\) instead. Setting either of these parameters to 1 corresponds to setting the corresponding \(\rho_a\) or \(\rho_b\) to \(\infty\).
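The mapping between \(\rho\) and tau can be sketched with two small JAX helpers. rho_to_tau and tau_to_rho are hypothetical names used for illustration only (they are not part of the sinkhorn API); they implement tau = rho / (epsilon + rho), its inverse rho = epsilon * tau / (1 - tau), and the convention that rho = inf corresponds to tau = 1:

```python
import jax.numpy as jnp


def rho_to_tau(rho, epsilon):
    """Hypothetical helper: tau := rho / (epsilon + rho), with rho = inf -> tau = 1."""
    rho = jnp.asarray(rho, dtype=jnp.float32)
    # jnp.where picks 1.0 for infinite rho (the inf / inf = nan branch is discarded).
    return jnp.where(jnp.isinf(rho), 1.0, rho / (epsilon + rho))


def tau_to_rho(tau, epsilon):
    """Hypothetical helper: inverse map rho = epsilon * tau / (1 - tau), tau = 1 -> inf."""
    tau = jnp.asarray(tau, dtype=jnp.float32)
    return jnp.where(tau == 1.0, jnp.inf, epsilon * tau / (1.0 - tau))


epsilon = 0.1
print(rho_to_tau(0.9, epsilon))      # close to 0.9, since 0.9 / (0.1 + 0.9) = 0.9
print(rho_to_tau(jnp.inf, epsilon))  # 1.0: the balanced (hard constraint) case
print(tau_to_rho(0.5, epsilon))      # close to 0.1: tau = 0.5 means rho = epsilon
```

Working with tau rather than \(\rho\) keeps the parameter in the bounded interval (0, 1], so the balanced limit is an ordinary value instead of an infinity.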

The Sinkhorn algorithm solves the reg-OT problem by seeking optimal potentials f, g (or alternatively their parametrization as positive scalings u, v), rather than solving the primal problem in \(P\). This is mostly for efficiency: potentials and scalings have an n + m memory footprint, rather than the n m entries required to store P. It is also possible because both problems are, in fact, equivalent, since the optimal transport \(P^*\) can be recovered from the optimal potentials \(f^*\), \(g^*\) or scalings \(u^*\), \(v^*\), using the geometry’s cost or kernel matrix respectively:

\[P^* = \exp\left(\frac{f^*\mathbf{1}_m^T + \mathbf{1}_n g^{*T} - C}{\epsilon}\right) \text{ or } P^* = \text{diag}(u^*) K \text{diag}(v^*)\]
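Both formulas describe the same matrix once \(u = \exp(f/\epsilon)\), \(v = \exp(g/\epsilon)\) and \(K = \exp(-C/\epsilon)\). The JAX sketch below checks this identity on arbitrary (not necessarily optimal) potentials; plan_from_potentials and plan_from_scalings are illustrative names, not library functions:

```python
import jax
import jax.numpy as jnp


def plan_from_potentials(f, g, C, epsilon):
    # P_ij = exp((f_i + g_j - C_ij) / epsilon)
    return jnp.exp((f[:, None] + g[None, :] - C) / epsilon)


def plan_from_scalings(u, v, K):
    # diag(u) K diag(v), computed by broadcasting instead of building diagonal matrices
    return u[:, None] * K * v[None, :]


key_c, key_f, key_g = jax.random.split(jax.random.PRNGKey(0), 3)
n, m, epsilon = 4, 3, 0.5
C = jax.random.uniform(key_c, (n, m))
f = jax.random.normal(key_f, (n,))  # arbitrary potentials: the identity holds for any f, g
g = jax.random.normal(key_g, (m,))

# Scalings are exponentials of potentials; the kernel is exp(-C / epsilon).
u, v, K = jnp.exp(f / epsilon), jnp.exp(g / epsilon), jnp.exp(-C / epsilon)
P_pot = plan_from_potentials(f, g, C, epsilon)
P_scal = plan_from_scalings(u, v, K)
```

Only f and g (n + m numbers) need to be kept; the n x m plan can be materialized on demand from either form.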

By default, the Sinkhorn algorithm solves this dual problem in f, g or u, v using block coordinate ascent, i.e. devising an update for each f and g (resp. u and v) that cancels their respective gradients, one at a time. These two iterations are repeated inner_iterations times, after which the norm of these gradients will be evaluated and compared with the threshold value. The iterations are then repeated as long as that error exceeds threshold.
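The loop described above can be sketched from scratch for the balanced case (tau_a = tau_b = 1) with a dense cost matrix. sinkhorn_log is a hypothetical function, not the library's API; its keyword arguments only mirror the names threshold, inner_iterations, max_iterations and norm_error used in this page. Each update cancels the gradient of the dual objective w.r.t. one block of variables, and the marginal error is checked once every inner_iterations iterations:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log(C, a, b, epsilon, threshold=1e-3, inner_iterations=10,
                 max_iterations=2000, norm_error=1):
    """Hypothetical balanced (tau_a = tau_b = 1) log-domain Sinkhorn sketch."""
    log_a, log_b = jnp.log(a), jnp.log(b)

    def update_g(f):
        # Cancels the gradient w.r.t. g: column sums of P become exactly b.
        return epsilon * log_b - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)

    def update_f(g):
        # Cancels the gradient w.r.t. f: row sums of P become exactly a.
        return epsilon * log_a - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)

    def one_iteration(_, fg):
        f, _ = fg
        g = update_g(f)  # g first, using the latest f ...
        f = update_f(g)  # ... then f, using the new g (block coordinate ascent)
        return f, g

    def marginal_error(f, g):
        # Rows are exact after the f-update, so monitor the p-norm column deviation.
        P = jnp.exp((f[:, None] + g[None, :] - C) / epsilon)
        return jnp.linalg.norm(P.sum(axis=0) - b, ord=norm_error)

    def cond(state):
        it, _, _, err = state
        return jnp.logical_and(it < max_iterations, err > threshold)

    def body(state):
        it, f, g, _ = state
        f, g = jax.lax.fori_loop(0, inner_iterations, one_iteration, (f, g))
        return it + inner_iterations, f, g, marginal_error(f, g)

    init = (jnp.array(0), jnp.zeros_like(a), jnp.zeros_like(b),
            jnp.array(jnp.inf, dtype=a.dtype))
    _, f, g, err = jax.lax.while_loop(cond, body, init)
    return f, g, err, err <= threshold


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (6, 2))
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean cost
a, b = jnp.ones(5) / 5, jnp.ones(6) / 6
f, g, err, converged = sinkhorn_log(C, a, b, epsilon=0.1)
```

Checking the error only every inner_iterations steps avoids materializing the plan at each iteration, at the price of possibly running up to inner_iterations - 1 extra updates after convergence.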

Note on Sinkhorn updates:

The boolean flag lse_mode sets whether the algorithm is run in either:

  • log-sum-exp mode (lse_mode=True), in which case it is directly defined in terms of updates to f and g, using log-sum-exp computations. This requires access to the cost matrix \(C\), either as it is stored or as computed on the fly by geom.

  • kernel mode (lse_mode=False), in which case it requires access to a matrix-vector multiplication operator \(z \rightarrow K z\), where \(K\) is either instantiated from \(C\) as \(\exp(-C/\epsilon)\), or provided directly. In that case, rather than optimizing over \(f\) and \(g\), it is more convenient to optimize over their so-called scaling formulations, \(u := \exp(f / \epsilon)\) and \(v := \exp(g / \epsilon)\). While faster (applying matrices is faster than repeatedly applying lse over rows), this mode is also less stable numerically, notably for smaller \(\epsilon\).
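The two modes compute the same update in different parameterizations. The sketch below (hypothetical helper names, balanced g-update only) checks that exp(g / epsilon) from the log-sum-exp update equals the scaling v from the kernel update, then shows the kernel mode breaking down in float32 when \(K = \exp(-C/\epsilon)\) underflows for a tiny \(\epsilon\):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def g_update_lse(f, C, b, epsilon):
    # lse_mode=True analogue: update the potential g directly from the cost matrix C.
    return epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)


def v_update_kernel(u, K, b):
    # lse_mode=False analogue: update the scaling v with one product K^T u.
    return b / (K.T @ u)


key_c, key_f = jax.random.split(jax.random.PRNGKey(0))
C = jax.random.uniform(key_c, (5, 4))
f = jax.random.normal(key_f, (5,))
b = jnp.full((4,), 0.25)

# Moderate epsilon: both updates agree, since v = exp(g / epsilon).
eps = 0.5
g = g_update_lse(f, C, b, eps)
v = v_update_kernel(jnp.exp(f / eps), jnp.exp(-C / eps), b)

# Tiny epsilon with unit costs: K = exp(-1000) underflows to 0 in float32,
# so the kernel update divides by zero while the lse update stays finite.
eps_small = 1e-3
C_unit, f_zero = jnp.ones((5, 4)), jnp.zeros(5)
g_small = g_update_lse(f_zero, C_unit, b, eps_small)
v_small = v_update_kernel(jnp.exp(f_zero / eps_small), jnp.exp(-C_unit / eps_small), b)
```

logsumexp subtracts the maximum before exponentiating, which is why g_small stays finite where the explicit kernel product collapses to zero.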

In the source code, the variables f_u and g_v can be regarded either as potentials (real vectors) or as scalings (positive vectors), depending on the user's choice of lse_mode. Once optimization is carried out, dual variables are only returned in potential form, i.e. f and g.

In addition to standard Sinkhorn updates, the user can also use heavy-ball type updates using a momentum parameter in ]0,2[. We also implement a strategy that tries to set that parameter adaptively after chg_momentum_from iterations, as a function of progress in the error, as proposed in the literature.
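A minimal sketch of such momentum, assuming it acts as an over-relaxation weight \(\omega\) on each log-domain potential update, \(f \leftarrow (1-\omega) f + \omega\, T(f)\), so that momentum=1 gives the plain update. The adaptive chg_momentum_from rule is not reproduced, and relax and relaxed_sinkhorn are hypothetical names:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def relax(old, new, momentum):
    # momentum = 1 recovers the plain Sinkhorn step; values in ]1, 2[ extrapolate
    # beyond it (over-relaxation), values in ]0, 1[ damp it.
    return (1.0 - momentum) * old + momentum * new


def relaxed_sinkhorn(C, a, b, epsilon, momentum, num_iterations):
    def step(_, fg):
        f, g = fg
        g_new = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        g = relax(g, g_new, momentum)
        f_new = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        f = relax(f, f_new, momentum)
        return f, g
    return jax.lax.fori_loop(0, num_iterations, step, (jnp.zeros_like(a), jnp.zeros_like(b)))


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (6, 2))
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a, b = jnp.ones(5) / 5, jnp.ones(6) / 6

f, g = relaxed_sinkhorn(C, a, b, epsilon=0.5, momentum=1.3, num_iterations=300)
P = jnp.exp((f[:, None] + g[None, :] - C) / 0.5)
```

Because the relaxed f is a blend of old and new values, row marginals are no longer exact after each step; both marginals only match a and b at convergence.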

Another upgrade to the standard Sinkhorn updates provided to users lies in Anderson acceleration. It is enabled by setting anderson_acceleration (0 by default, i.e. disabled) to a positive integer. When selected, the algorithm recomputes, every refresh_anderson_frequency iterations (1 by default), an extrapolation of the most recently computed anderson_acceleration iterates. When using that option, note that differentiation (if required) can only be carried out using implicit differentiation, and that all momentum-related parameters are ignored.
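The extrapolation step itself can be sketched independently of Sinkhorn. The code below is a generic type-II Anderson acceleration for a fixed-point map T, with a sliding memory and a small relative ridge for stability (both are assumptions, not necessarily what the library does). It is exercised on a toy contractive affine map with a known fixed point rather than on the Sinkhorn potential update, and all names are hypothetical:

```python
import jax
import jax.numpy as jnp


def anderson_extrapolate(xs, txs, ridge=1e-4):
    """One type-II Anderson step from stacked iterates xs and images txs, both [M, d]."""
    R = txs - xs                            # residuals r_k = T(x_k) - x_k
    gram = R @ R.T
    # Ridge relative to the residual scale (plus a tiny floor) keeps the solve stable.
    reg = ridge * jnp.trace(gram) + 1e-20
    alpha = jnp.linalg.solve(gram + reg * jnp.eye(R.shape[0]), jnp.ones(R.shape[0]))
    alpha = alpha / alpha.sum()             # weights constrained to sum to one
    return alpha @ txs                      # combination of the images T(x_k)


def anderson_fixed_point(T, x0, memory=5, num_steps=50):
    xs, txs = [x0], [T(x0)]
    x = txs[-1]
    for _ in range(num_steps):
        xs.append(x)
        txs.append(T(x))
        xs, txs = xs[-memory:], txs[-memory:]  # sliding window of size `memory`
        x = anderson_extrapolate(jnp.stack(xs), jnp.stack(txs))
    return x


key_s, key_c = jax.random.split(jax.random.PRNGKey(0))
S = jax.random.normal(key_s, (4, 4))
S = S + S.T
A = 0.9 * S / jnp.linalg.norm(S, ord=2)       # symmetric, spectral radius 0.9
c = jax.random.normal(key_c, (4,))
T = lambda x: A @ x + c                       # contractive affine map
x_star = jnp.linalg.solve(jnp.eye(4) - A, c)  # its unique fixed point
x_acc = anderson_fixed_point(T, jnp.zeros(4))
```

In a Sinkhorn setting, T would map f to the f obtained after one g-update and one f-update; memory plays the role of anderson_acceleration, and extrapolating at every step corresponds to refresh_anderson_frequency=1.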

The parallel_dual_updates flag is set to False by default. In that setting, g_v is first updated using the latest values of f_u and g_v, before f_u is updated using that new value of g_v. When the flag is set to True, both f_u and g_v are updated simultaneously. Note that setting that flag to True requires some form of averaging (e.g. momentum=0.5); without it, parallel_dual_updates on its own will not work.
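The parallel schedule can be sketched in log-domain for the balanced case: f and g are both recomputed from the previous pair, then averaged with it through a momentum weight. parallel_sinkhorn and marginal_errors are hypothetical names chosen for this illustration:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def parallel_sinkhorn(C, a, b, epsilon, momentum, num_iterations):
    def step(_, fg):
        f, g = fg
        # Jacobi-style: both candidates are computed from the same (f, g) pair.
        f_new = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        g_new = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        # Averaging with the previous values (momentum = 1 would mean no averaging).
        f = (1.0 - momentum) * f + momentum * f_new
        g = (1.0 - momentum) * g + momentum * g_new
        return f, g
    return jax.lax.fori_loop(0, num_iterations, step, (jnp.zeros_like(a), jnp.zeros_like(b)))


def marginal_errors(f, g, C, a, b, epsilon):
    P = jnp.exp((f[:, None] + g[None, :] - C) / epsilon)
    return jnp.abs(P.sum(axis=1) - a).sum(), jnp.abs(P.sum(axis=0) - b).sum()


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (6, 2))
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a, b = jnp.ones(5) / 5, jnp.ones(6) / 6

f, g = parallel_sinkhorn(C, a, b, epsilon=0.5, momentum=0.5, num_iterations=300)
err_a, err_b = marginal_errors(f, g, C, a, b, epsilon=0.5)
```

With momentum=1.0, the new f depends only on the previous g and vice versa, so the iterates split into two interleaved sequences, each following the sequential scheme from a different starting point; the returned pair mixes the two, which is one way to see why averaging is needed.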


The optimal solutions f and g and the optimal objective (reg_ot_cost) output by the Sinkhorn algorithm can be differentiated w.r.t. relevant inputs geom, a and b using, by default, implicit differentiation of the optimality conditions (implicit_differentiation set to True). This choice has two consequences.

  • The termination criterion used to stop Sinkhorn (cancellation of the gradient of the objective w.r.t. f_u and g_v) is used to differentiate f and g, given a change in the inputs. These changes are computed by solving a linear system. The arguments starting with implicit_solver_* allow one to define the linear solver that is used, and to control for two types of regularization (we have observed that, depending on the architecture, linear solves may require higher ridge parameters to remain stable). The optimality conditions in Sinkhorn can be analyzed as satisfying a z=z' condition, which is then differentiated. It might be beneficial, as suggested in the literature, to use a preconditioning function precondition_fun and differentiate h(z)=h(z') instead.

  • The objective reg_ot_cost returned by Sinkhorn is differentiated using the so-called envelope (or Danskin’s) theorem. In that case, because the gradients of the dual objective w.r.t. the dual variables f_u and g_v are assumed to be zero (reflecting the fact that they are optimal), small variations in f_u and g_v due to changes in inputs (such as geom, a and b) are considered negligible. As a result, stop_gradient is applied to the dual variables f_u and g_v when evaluating the reg_ot_cost objective. Note that this approach is invalid when computing higher-order derivatives; in that case the use_danskin flag must be set to False.
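To make the envelope argument concrete, here is a sketch for the balanced case, where the dual objective reduces to \(<a, f> + <b, g> - \epsilon <e^{f/\epsilon}, K e^{g/\epsilon}>\). Wrapping the solved potentials in jax.lax.stop_gradient makes jax.grad return only the partial derivatives: w.r.t. \(C\) this is the plan \(\exp((f + g - C)/\epsilon)\), and w.r.t. a it is the potential f. The helper names and the fixed iteration count are illustrative assumptions, not the library's API:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def solve_potentials(C, a, b, epsilon, num_iterations=300):
    # Plain balanced log-domain Sinkhorn (g-update, then f-update).
    def step(_, fg):
        f, _ = fg
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        return f, g
    return jax.lax.fori_loop(0, num_iterations, step, (jnp.zeros_like(a), jnp.zeros_like(b)))


def reg_ot_cost_danskin(C, a, b, epsilon=0.5):
    f, g = solve_potentials(C, a, b, epsilon)
    # Envelope (Danskin) theorem: treat the optimal potentials as constants.
    f, g = jax.lax.stop_gradient(f), jax.lax.stop_gradient(g)
    kernel_term = jnp.sum(jnp.exp((f[:, None] + g[None, :] - C) / epsilon))
    return jnp.dot(a, f) + jnp.dot(b, g) - epsilon * kernel_term


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (6, 2))
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a, b = jnp.ones(5) / 5, jnp.ones(6) / 6

grad_C, grad_a = jax.grad(reg_ot_cost_danskin, argnums=(0, 1))(C, a, b)
f, g = solve_potentials(C, a, b, 0.5)
P = jnp.exp((f[:, None] + g[None, :] - C) / 0.5)  # expected to match grad_C
```

This first-order shortcut is only justified when the potentials are (close to) optimal, which is why it requires convergence with a low tolerance and does not extend to higher-order derivatives.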

An alternative, yet more costly, way to differentiate the outputs of the Sinkhorn iterations is to use unrolling, i.e. reverse-mode differentiation of the Sinkhorn loop. This is possible because Sinkhorn iterations are wrapped in a custom fixed-point iteration loop, defined in fixed_point_loop, rather than a standard while loop, which ensures that the end result of this fixed-point loop can also be differentiated, if needed, using standard JAX operations. To make backpropagation possible, the fixed_point_loop.fixpoint_iter_backprop loop checkpoints the state variables (here f_u and g_v) every inner_iterations iterations, and backpropagates automatically, block by block, through inner_iterations iterations at a time.
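The two strategies can be contrasted on a small balanced problem. The sketch below unrolls a fixed number of log-domain Sinkhorn iterations in blocks of inner_iterations, wraps each block in jax.checkpoint so that only block inputs are stored (mirroring, not reproducing, the checkpointing described above), and compares the unrolled gradient of the dual objective w.r.t. the cost matrix with the frozen-potential (Danskin) gradient. All function names are hypothetical; this is not the fixed_point_loop implementation:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_step(fg, C, log_a, log_b, epsilon):
    f, _ = fg
    g = epsilon * log_b - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
    f = epsilon * log_a - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
    return f, g


def reg_ot_cost(C, a, b, epsilon=0.5, num_blocks=30, inner_iterations=10, danskin=False):
    log_a, log_b = jnp.log(a), jnp.log(b)

    @jax.checkpoint  # store only each block's input; its steps are recomputed on the backward pass
    def block(fg, C, log_a, log_b):
        body = lambda _, s: sinkhorn_step(s, C, log_a, log_b, epsilon)
        return jax.lax.fori_loop(0, inner_iterations, body, fg)

    fg0 = (jnp.zeros_like(a), jnp.zeros_like(b))
    (f, g), _ = jax.lax.scan(lambda s, _: (block(s, C, log_a, log_b), None),
                             fg0, None, length=num_blocks)
    if danskin:  # envelope theorem: freeze the potentials
        f, g = jax.lax.stop_gradient(f), jax.lax.stop_gradient(g)
    kernel_term = jnp.sum(jnp.exp((f[:, None] + g[None, :] - C) / epsilon))
    return jnp.dot(a, f) + jnp.dot(b, g) - epsilon * kernel_term


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (6, 2))
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a, b = jnp.ones(5) / 5, jnp.ones(6) / 6

grad_unrolled = jax.grad(reg_ot_cost)(C, a, b)               # backprop through every iteration
grad_danskin = jax.grad(reg_ot_cost)(C, a, b, danskin=True)  # frozen potentials
```

Once the iterations have converged, the two gradients should agree up to a small error, since the gradient of the objective w.r.t. the potentials vanishes at the optimum; the unrolled version pays for this with a backward pass through every iteration.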


  • The Sinkhorn algorithm may not converge within the maximum number of iterations for several possible reasons:

    1. the regularizer (defined as epsilon in the geometry geom object) is too small. Consider either switching to lse_mode=True (at the price of a slower execution), increasing epsilon, or, alternatively, if you are unable or unwilling to increase epsilon, increasing max_iterations or threshold.

    2. the probability weights a and b do not have the same total mass, while using a balanced (tau_a=tau_b=1.0) setup. Consider either normalizing a and b, or setting tau_a and/or tau_b < 1.0.

    3. OOM issues may arise when storing cost or kernel matrices that are too large in geom. When geom is a PointCloud, some of these issues might be solved by setting the online flag to True, which triggers an on-the-fly recomputation of the cost/kernel matrix.

  • The weight vectors a and b can be passed with coordinates that have zero weight. This is then handled by relying on simple arithmetic for the inf values that will likely arise (due to \(\log(0)\) when lse_mode is True, or to divisions by zero when lse_mode is False). Whenever that arithmetic is likely to produce NaN values (due to -inf * 0, or -inf - -inf) in the forward pass, we use jnp.where conditional statements to carry inf rather than NaN values. In reverse-mode differentiation, the inputs corresponding to these 0 weights (a location x, or a row in the corresponding cost/kernel matrix), and the weight itself, will have NaN gradient values. This reflects that these gradients are undefined, since these points were not considered in the optimization and therefore have no impact on the output.
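The arithmetic pitfalls mentioned above are easy to reproduce. The sketch below is a standalone illustration (not the library's internal code) of how a zero weight yields -inf in log-domain, how 0 * -inf and -inf - -inf produce NaN, how a jnp.where guard keeps a forward value finite, and why the corresponding gradient is still NaN:

```python
import jax
import jax.numpy as jnp

a = jnp.array([0.5, 0.5, 0.0])  # the last point carries zero weight
f = jnp.log(a)                   # log(0) = -inf for that point

naive = jnp.sum(a * f)                          # 0 * -inf = nan poisons the whole sum
safe = jnp.sum(jnp.where(a > 0, a * f, 0.0))    # zero-weight terms contribute 0
gap = f[2] - f[2]                               # -inf - -inf = nan as well

# Reverse mode: the forward pass is guarded, but the zero-weight coordinate
# still receives a NaN gradient (its derivative is undefined there).
grad = jax.grad(lambda w: jnp.sum(jnp.where(w > 0, w * jnp.log(w), 0.0)))(a)
```

The NaN in grad comes from the discarded branch: jnp.where routes a zero cotangent into w * log(w), and 0 * -inf is NaN, matching the behavior described for zero-weight inputs.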

Parameters:

  • geom (Geometry) – a Geometry object.

  • a (Optional[ndarray]) – [num_a,] or [batch, num_a] weights.

  • b (Optional[ndarray]) – [num_b,] or [batch, num_b] weights.

  • tau_a (float) – ratio rho_a/(rho_a+eps) between the KL divergence regularizer on the first marginal and the sum of that regularizer and the epsilon regularizer, used in the unbalanced formulation.

  • tau_b (float) – ratio rho_b/(rho_b+eps) between the KL divergence regularizer on the second marginal and the sum of that regularizer and the epsilon regularizer, used in the unbalanced formulation.

  • init_dual_a (Optional[ndarray]) – optional initialization for potentials/scalings w.r.t. first marginal (a) of reg-OT problem.

  • init_dual_b (Optional[ndarray]) – optional initialization for potentials/scalings w.r.t. second marginal (b) of reg-OT problem.

  • threshold – tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the current primal solution when either or both tau_a and tau_b are 1.0 (balanced or semi-balanced problem), or the relative change between two successive solutions in the unbalanced case.

  • norm_error – power used to define p-norm of error for marginal/target.

  • inner_iterations – the Sinkhorn error is not recomputed at each iteration but only every inner_iterations iterations.

  • min_iterations – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.

  • max_iterations – the maximum number of Sinkhorn iterations. If max_iterations is equal to min_iterations, sinkhorn iterations are run by default using a jax.lax.scan loop rather than a custom, unroll-able jax.lax.while_loop that monitors convergence. In that case the error is not monitored and the converged flag will return False as a consequence.

  • momentum – a float in [0,2].

  • chg_momentum_from – if positive, momentum is recomputed after that number of iterations, using an adaptive rule proposed in the literature.

  • anderson_acceleration – int, if 0 (default), no acceleration. If positive, use Anderson acceleration on the dual Sinkhorn updates (in log/potential form), as described and advocated in the literature, with a memory of size equal to anderson_acceleration. In that case, differentiation is necessarily handled implicitly (implicit_differentiation is set to True) and all momentum-related parameters are ignored.

  • refresh_anderson_frequency – int, when using anderson_acceleration, recompute the extrapolation direction every refresh_anderson_frequency Sinkhorn iterations.

  • lse_mode – True for log-sum-exp computations, False for kernel multiplication.

  • implicit_differentiation – True if using implicit differentiation, False if unrolling Sinkhorn iterations.

  • linear_solve_kwargs – parametrization of the linear solver when using implicit differentiation. Arguments currently accepted appear in the optional arguments of apply_inv_hessian, namely linear_solver_fun, a Callable that specifies the linear solver, as well as ridge_kernel and ridge_identity, to be added to enforce the stability of the linear solve.

  • parallel_dual_updates – updates potentials or scalings in parallel if True, sequentially (in Gauss-Seidel fashion) if False.

  • use_danskin – when True, the entropy-regularized cost is assumed to be evaluated using optimal potentials that are frozen, i.e. whose gradients have been stopped. This is useful when carrying out first-order differentiation, and is only valid (as with implicit_differentiation) when the algorithm has converged with a low tolerance.

  • jit – if True, automatically jits the function upon first call. Should be set to False when used in a function that is jitted by the user, or when computing gradients (in which case the gradient function should be jitted by the user).

  • kwargs (Any) – Additional keyword arguments (see above).


Returns:

a SinkhornOutput named tuple. The tuple contains two optimal potential vectors f and g, the objective reg_ot_cost evaluated at those solutions, an array of errors to monitor convergence every inner_iterations, and a flag converged that is True if the algorithm has converged within the number of iterations that was predefined by the user.