ImplicitDiff.solve(gr, ot_prob, f, g, lse_mode)

Apply minus the inverse of the Hessian of reg_ot_cost w.r.t. f and g.

This function is used to carry out implicit differentiation of Sinkhorn outputs, notably the optimal potentials f and g. That differentiation requires solving a linear system that uses (and inverts) the Jacobian of the (preconditioned) first-order conditions of the regularized OT problem.
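The principle behind "apply minus the inverse" is the implicit function theorem: if a solution \(x^\star(\theta)\) satisfies \(F(x^\star, \theta) = 0\), then \(\partial x^\star / \partial \theta = -(\partial_x F)^{-1} \partial_\theta F\), so a vector-Jacobian product only needs one linear solve with \(\partial_x F\). The sketch below illustrates this on a toy scalar-wise problem, \(F(x, \theta) = x^3 + x - \theta\), chosen only for illustration; it is not the library's code, and the names F, solve_x and implicit_vjp are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def F(x, theta):
    # Toy optimality conditions, one per coordinate: x**3 + x - theta = 0.
    return x**3 + x - theta


def solve_x(theta, iters=30):
    # Newton's method; x**3 + x is strictly increasing, so the root is unique.
    x = jnp.zeros_like(theta)
    for _ in range(iters):
        x = x - F(x, theta) / (3 * x**2 + 1)
    return x


def implicit_vjp(gr, x_star, theta):
    # Jacobian of the conditions w.r.t. the solution, evaluated at x_star.
    J_x = jax.jacobian(F, argnums=0)(x_star, theta)
    # "Apply minus inverse": solve J_x^T w = -gr (transposed, since this is a VJP).
    w = jnp.linalg.solve(J_x.T, -gr)
    # Pull w back through dF/dtheta to obtain the VJP w.r.t. theta.
    _, pullback = jax.vjp(lambda t: F(x_star, t), theta)
    return pullback(w)[0]


theta = jnp.array([0.5, 1.0, 2.0])
x_star = solve_x(theta)
gr = jnp.array([1.0, -2.0, 0.5])
g_implicit = implicit_vjp(gr, x_star, theta)

# Reference: differentiate through the unrolled Newton iterations instead.
_, pullback_unrolled = jax.vjp(solve_x, theta)
g_unrolled = pullback_unrolled(gr)[0]
```

Both routes give the same gradient here, but the implicit one never backpropagates through the solver's iterations, so its memory cost does not grow with the number of iterations.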

Given a precondition_fun, written here for short as \(h\), the first-order conditions for the dual energy

\[E(K, \epsilon, a, b, f, g) := -\langle a, \phi_a^{*}(-f)\rangle - \langle b, \phi_b^{*}(-g)\rangle - \epsilon \langle e^{f/\epsilon}, K e^{g/\epsilon}\rangle\]

form the basis of the Sinkhorn algorithm. To differentiate optimal solutions to that problem, we exploit the fact that the first-order conditions \(\nabla_{f,g} E = 0\), preconditioned by \(h\), hold at the optimum, and differentiate that identity to recover variations (Jacobians) of the optimal solutions \(f^\star, g^\star\) as a function of changes in the inputs. The Jacobian of these preconditioned conditions is a linear operator which, if it were instantiated as a matrix, would be of size \((n+m) \times (n+m)\). When \(h\) is the identity, that matrix is the Hessian of \(E\); it is symmetric and negative semi-definite (\(E\) is concave) and is structured as \([A, B; B^T, D]\). More generally, for other functions \(h\), the Jacobian of these preconditioned first-order conditions is no longer symmetric (except when a == b) and has the structure \([A, B; C, D]\). That system can still be inverted using more generic solvers. By default, \(h = \epsilon \log\), as proposed in [Cuturi et al., 2020].
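A small numerical check of these statements in the balanced case, where \(\phi_a^*\) and \(\phi_b^*\) are the identity so that \(E\) reduces to \(\langle a, f\rangle + \langle b, g\rangle - \epsilon\langle e^{f/\epsilon}, K e^{g/\epsilon}\rangle\) with \(K = e^{-C/\epsilon}\). The sketch uses jax.grad and jax.hessian on a tiny random problem (sizes, cost and potentials are arbitrary assumptions); a dense Hessian is only viable at this scale. It shows that the gradient is the marginal violation, and that the blocks carry a factor \(-1/\epsilon\) relative to the marginals and the OT matrix.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n, m, eps = 4, 3, 0.5
key_c, key_z = jax.random.split(jax.random.PRNGKey(0))
cost = jax.random.uniform(key_c, (n, m))
a = jnp.full(n, 1.0 / n)
b = jnp.full(m, 1.0 / m)
z = 0.1 * jax.random.normal(key_z, (n + m,))  # stacked potentials [f; g]


def transport(z):
    # P_ij = exp((f_i + g_j - C_ij) / eps), built from f = z[:n] and g = z[n:].
    return jnp.exp((z[:n, None] + z[None, n:] - cost) / eps)


def energy(z):
    # Balanced case: -<a, phi_a^*(-f)> - <b, phi_b^*(-g)> reduces to <a, f> + <b, g>.
    return a @ z[:n] + b @ z[n:] - eps * jnp.sum(transport(z))


P = transport(z)
grad = jax.grad(energy)(z)
# First-order conditions are marginal constraints: grad_f = a - P 1, grad_g = b - P^T 1.
grad_f, grad_g = grad[:n], grad[n:]

H = jax.hessian(energy)(z)  # dense (n+m) x (n+m) matrix
A, B = H[:n, :n], H[:n, n:]
C, D = H[n:, :n], H[n:, n:]
# Expected: A = -diag(P 1)/eps, D = -diag(P^T 1)/eps, B = -P/eps, C = B^T.
# The exponential term depends on (f, g) only through f_i + g_j, so the
# direction [1_n; -1_m] lies in the kernel of H: H is only semi-definite.
kernel_dir = jnp.concatenate([jnp.ones(n), -jnp.ones(m)])
eigvals = jnp.linalg.eigvalsh(H)
```

The null direction is why, in the balanced case, the linear system is singular and a solver may need some form of regularization.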

In both cases, \(A\) and \(D\) are diagonal matrices, equal to the row and column marginals respectively, multiplied by the derivatives of \(h\) evaluated at those marginals, and corrected (in the unbalanced case) by the second derivative of the part of the objective that ties the potentials to the marginals (the terms in \(\phi^*\)). When \(h\) is the identity, \(B\) and \(B^T\) are respectively the OT matrix and its transpose, i.e. \(n \times m\) and \(m \times n\) matrices. When \(h\) is not the identity, \(B\) (resp. \(C\)) is the OT matrix (resp. its transpose), rescaled on the left by \(h'\) applied elementwise to the row (resp. column) marginals of the transport.
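These block formulas can be checked numerically on a tiny balanced problem. The sketch below assumes that the preconditioned first-order conditions take the form \(h(P\mathbf{1}) - h(a)\) and \(h(P^T\mathbf{1}) - h(b)\) with the default \(h = \epsilon \log\); the function name preconditioned_foc is hypothetical and this is not the library's implementation. Differentiating \(P\) brings an extra factor \(1/\epsilon\), which the derivative \(h'(x) = \epsilon / x\) cancels.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n, m, eps = 4, 3, 0.5
key_c, key_z = jax.random.split(jax.random.PRNGKey(1))
cost = jax.random.uniform(key_c, (n, m))
a = jnp.full(n, 1.0 / n)
b = jnp.full(m, 1.0 / m)
z = 0.1 * jax.random.normal(key_z, (n + m,))  # stacked potentials [f; g]


def h(x):
    # Default preconditioner h = eps * log, so h'(x) = eps / x.
    return eps * jnp.log(x)


def transport(z):
    return jnp.exp((z[:n, None] + z[None, n:] - cost) / eps)


def preconditioned_foc(z):
    # Assumed form: h of the marginals of P minus h of the target marginals.
    P = transport(z)
    return jnp.concatenate([h(P.sum(axis=1)) - h(a), h(P.sum(axis=0)) - h(b)])


J = jax.jacobian(preconditioned_foc)(z)
A, B = J[:n, :n], J[:n, n:]
C, D = J[n:, :n], J[n:, n:]

P = transport(z)
row, col = P.sum(axis=1), P.sum(axis=0)
# A = diag(h'(row) * row) / eps and B = diag(h'(row)) P / eps. With h = eps * log
# this gives A = I, B = diag(1/row) P, and likewise D = I, C = diag(1/col) P^T,
# so B and C^T differ in general and J is not symmetric.
```

With \(h = \epsilon \log\) both diagonal blocks become identities, which is one way to see why this choice acts as a preconditioner.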

Note that these transport matrices are never instantiated: the computation relies instead on calls to the app_transport method of the Geometry object geom, which uses either potentials or scalings, depending on lse_mode.
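The two modes can be sketched as follows. apply_transport is a hypothetical helper written for this illustration (it is not the Geometry API): it computes \(P v\) or \(P^T v\) from the potentials with a signed log-sum-exp when lse_mode is True, and from the scalings \(u = e^{f/\epsilon}\), \(v = e^{g/\epsilon}\) and the kernel \(K = e^{-C/\epsilon}\) otherwise. Because this sketch starts from a dense cost matrix, it still builds \(n \times m\) intermediates; it only illustrates the two computational routes that lse_mode switches between.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)


def apply_transport(f, g, cost, eps, vec, axis, lse_mode):
    """Hypothetical helper: P @ vec if axis == 1, P.T @ vec if axis == 0.

    P_ij = exp((f_i + g_j - C_ij) / eps) is neither returned nor stored.
    """
    if lse_mode:
        # Log domain: return_sign=True handles vectors with negative entries.
        logits = (f[:, None] + g[None, :] - cost) / eps
        weights = vec[None, :] if axis == 1 else vec[:, None]
        lse, sign = logsumexp(logits, axis=axis, b=weights, return_sign=True)
        return sign * jnp.exp(lse)
    # Kernel mode: P = diag(u) K diag(v) with scalings u, v.
    K = jnp.exp(-cost / eps)
    u, v = jnp.exp(f / eps), jnp.exp(g / eps)
    if axis == 1:
        return u * (K @ (v * vec))
    return v * (K.T @ (u * vec))


keys = jax.random.split(jax.random.PRNGKey(2), 5)
n, m, eps = 5, 4, 0.3
cost = jax.random.uniform(keys[0], (n, m))
f = 0.1 * jax.random.normal(keys[1], (n,))
g = 0.1 * jax.random.normal(keys[2], (m,))
w_m = jax.random.normal(keys[3], (m,))  # mixed-sign vector of size m
w_n = jax.random.normal(keys[4], (n,))  # mixed-sign vector of size n
P_dense = jnp.exp((f[:, None] + g[None, :] - cost) / eps)  # reference only
```

The log-domain route is the numerically safer one for small \(\epsilon\), where \(e^{f/\epsilon}\) and \(K\) can overflow or underflow separately.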

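The Jacobian's structure, with diagonal blocks \(A, D\) and dense off-diagonal blocks, makes it possible to exploit Schur complements. Since \(A\) and \(D\) are diagonal, and thus cheap to invert, eliminating one of them leaves a linear system of size \(m\) (Schur complement of \(A\)) or of size \(n\) (Schur complement of \(D\)). Depending on the sizes involved, it is therefore better to instantiate the Schur complement of the first or of the second diagonal block.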
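A minimal dense sketch of the Schur-complement idea for a generic system \([A, B; C, D]\) with diagonal \(A, D\). The helper schur_solve is hypothetical and uses random, diagonally dominant blocks rather than transport matrices; it always eliminates the larger diagonal block so that the remaining dense system has size \(\min(n, m)\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def schur_solve(a_diag, B, C, d_diag, r1, r2):
    """Hypothetical helper: solve [diag(a_diag), B; C, diag(d_diag)] [x; y] = [r1; r2]."""
    n, m = B.shape
    if n >= m:
        # Eliminate x with the Schur complement of A: an m x m system.
        S = jnp.diag(d_diag) - C @ (B / a_diag[:, None])
        y = jnp.linalg.solve(S, r2 - C @ (r1 / a_diag))
        x = (r1 - B @ y) / a_diag
    else:
        # Eliminate y with the Schur complement of D: an n x n system.
        S = jnp.diag(a_diag) - B @ (C / d_diag[:, None])
        x = jnp.linalg.solve(S, r1 - B @ (r2 / d_diag))
        y = (r2 - C @ x) / d_diag
    return x, y


def random_system(key, n, m):
    # Diagonal entries >= 10 dominate the off-diagonal row sums (< max(n, m) <= 6
    # here), which keeps the full matrix invertible.
    ks = jax.random.split(key, 6)
    a_diag = 10.0 + jax.random.uniform(ks[0], (n,))
    d_diag = 10.0 + jax.random.uniform(ks[1], (m,))
    B = jax.random.uniform(ks[2], (n, m))
    C = jax.random.uniform(ks[3], (m, n))
    r1 = jax.random.normal(ks[4], (n,))
    r2 = jax.random.normal(ks[5], (m,))
    return a_diag, B, C, d_diag, r1, r2


def dense_solve(a_diag, B, C, d_diag, r1, r2):
    # Reference: assemble the full (n+m) x (n+m) matrix and solve directly.
    J = jnp.block([[jnp.diag(a_diag), B], [C, jnp.diag(d_diag)]])
    z = jnp.linalg.solve(J, jnp.concatenate([r1, r2]))
    return z[: a_diag.shape[0]], z[a_diag.shape[0]:]


tall = random_system(jax.random.PRNGKey(3), 6, 3)  # n > m: 3 x 3 Schur system
wide = random_system(jax.random.PRNGKey(4), 3, 6)  # n < m: 3 x 3 Schur system
x_tall, y_tall = schur_solve(*tall)
x_wide, y_wide = schur_solve(*wide)
xr_tall, yr_tall = dense_solve(*tall)
xr_wide, yr_wide = dense_solve(*wide)
```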

These linear systems are solved with the user-defined solver; by default, lineax solvers are used when lineax is available, with a fallback on JAX's own solvers otherwise.
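A sketch of the JAX fallback route, assuming only jax.scipy.sparse.linalg.gmres: the block operator is passed as a matrix-free function, and the result is negated to "apply minus the inverse". The helper names jacobian_matvec and apply_minus_inverse are hypothetical, the blocks are random stand-ins for \([A, B; C, D]\), and no lineax call is shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

jax.config.update("jax_enable_x64", True)

n, m = 5, 3
ks = jax.random.split(jax.random.PRNGKey(5), 6)
a_diag = 10.0 + jax.random.uniform(ks[0], (n,))
d_diag = 10.0 + jax.random.uniform(ks[1], (m,))
B = jax.random.uniform(ks[2], (n, m))
C = jax.random.uniform(ks[3], (m, n))  # C != B.T: the operator is not symmetric


def jacobian_matvec(z):
    # Matrix-free product with J = [diag(a_diag), B; C, diag(d_diag)].
    x, y = z[:n], z[n:]
    return jnp.concatenate([a_diag * x + B @ y, C @ x + d_diag * y])


def apply_minus_inverse(gr):
    # Hypothetical helper: returns -J^{-1} gr for a 2-tuple gr, as a 2-tuple.
    rhs = jnp.concatenate(gr)
    # GMRES handles non-symmetric operators; restart = n + m spans the full space.
    z, _ = gmres(jacobian_matvec, rhs, tol=1e-12, restart=n + m)
    return -z[:n], -z[n:]


gr = (jax.random.normal(ks[4], (n,)), jax.random.normal(ks[5], (m,)))
out_f, out_g = apply_minus_inverse(gr)

# Dense reference, only for checking this small example.
J = jnp.block([[jnp.diag(a_diag), B], [C, jnp.diag(d_diag)]])
ref = -jnp.linalg.solve(J, jnp.concatenate(gr))
```

GMRES is used here because the preconditioned Jacobian is generally not symmetric; when \(h\) is the identity, a symmetric solver such as conjugate gradient could be applied to the negated operator instead.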

Parameters:

  • gr (Tuple[Array, Array]) – 2-tuple, (vector of size n, vector of size m).

  • ot_prob (LinearProblem) – the instantiation of the regularized transport problem.

  • f (Array) – potential, w.r.t. marginal a.

  • g (Array) – potential, w.r.t. marginal b.

  • lse_mode (bool) – log-sum-exp mode if True, kernel mode otherwise.

Returns:

A tuple of two vectors, with the same sizes as the two vectors in gr.