ott.solvers.linear.implicit_differentiation.ImplicitDiff.solve
- ImplicitDiff.solve(gr, ot_prob, f, g, lse_mode)
Apply minus the inverse of the Hessian of `reg_ot_cost` w.r.t. `[f, g]`.

This function is used to carry out implicit differentiation of the outputs of the Sinkhorn algorithm, notably the dual Kantorovich potentials `f` and `g`. That differentiation requires solving a linear system that uses (and inverts) the Jacobian of the (preconditioned) first-order conditions of the regularized OT problem.

Given a `precondition_fun`, written here for short as \(h\), the first-order conditions for the dual energy

\[E(K, \epsilon, a, b, f, g) := -\langle a, \phi_a^{*}(-f)\rangle - \langle b, \phi_b^{*}(-g)\rangle - \epsilon \langle e^{f/\epsilon}, K e^{g/\epsilon}\rangle\]

form the basis of the Sinkhorn algorithm. To differentiate optimal solutions of that problem, we exploit the fact that \(h(\nabla E = 0)\) holds at the optimum, and differentiate that identity to recover variations (Jacobians) of the optimal solutions \(f^\star, g^\star\) as a function of changes in the inputs. The Jacobian of \(h(\nabla_{f,g} E = 0)\) is a linear operator which, if it were instantiated as a matrix, would be of size \((n+m) \times (n+m)\). When \(h\) is the identity, that matrix is the Hessian of \(E\); it is symmetric and negative semi-definite (\(E\) is concave), with block structure \([A, B; B^T, D]\). More generally, for other functions \(h\), the Jacobian of these preconditioned first-order conditions is no longer symmetric (except if `a == b`) and has the structure \([A, B; C, D]\). That system can still be inverted using more generic solvers. By default, \(h = \epsilon \log\), as proposed in [Cuturi et al., 2020].

Up to a common scalar factor, in both cases \(A\) and \(D\) are diagonal matrices, equal to the row and column marginals of the coupling respectively, multiplied by the derivative of \(h\) evaluated at those marginals and, in the unbalanced case, corrected by the second derivative of the part of the objective that ties the potentials to the marginals (the terms in `phi_star`). When \(h\) is the identity, \(B\) and \(B^T\) are respectively the OT matrix and its transpose, i.e. \(n \times m\) and \(m \times n\) matrices. When \(h\) is not the identity, \(B\) (resp. \(C\)) is the OT matrix (resp. its transpose), rescaled on the left by \(h'\) applied elementwise to the row (resp. column) marginal of the transport.

Note that these transport matrices are never instantiated: the implementation relies instead on calls to the `app_transport` method of the `Geometry` object `geom`, which uses either potentials or scalings depending on `lse_mode`.

The block structure of the Jacobian (diagonal blocks on the diagonal, dense off-diagonal blocks) makes it possible to exploit Schur complements. Eliminating \(A\) leaves an \(m \times m\) Schur complement \(D - C A^{-1} B\), while eliminating \(D\) leaves an \(n \times n\) one, \(A - B D^{-1} C\); depending on the sizes involved, the smaller of the two is the cheaper one to instantiate.
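The following dense sketch illustrates the balanced case with \(h\) the identity, i.e. \(\phi_a^{*}(-f) = -f\) and \(\phi_b^{*}(-g) = -g\). It is not OTT's implementation (which never forms these matrices): it runs a plain kernel-mode Sinkhorn loop, checks that \(\nabla E\) vanishes at the result, verifies the \([A, B; B^T, D]\) block description against `jax.hessian`, and solves a regularized system by eliminating the diagonal block \(A\). The helper names (`dual_energy`, `sinkhorn`, `schur_solve`) and the ridge value are illustrative assumptions; a ridge is needed because, in the balanced case, the Hessian is singular along the direction \((\mathbf{1}, -\mathbf{1})\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the checks below tight


def dual_energy(f, g, a, b, K, eps):
  # Balanced case, phi_a^*(-f) = -f and phi_b^*(-g) = -g, so
  # E = <a, f> + <b, g> - eps * <e^{f/eps}, K e^{g/eps}>.
  return a @ f + b @ g - eps * jnp.exp(f / eps) @ (K @ jnp.exp(g / eps))


def sinkhorn(a, b, K, eps, n_iter=1000):
  # Kernel-mode Sinkhorn: exact maximization of E in f, then in g.
  def step(_, fg):
    f, g = fg
    f = eps * (jnp.log(a) - jnp.log(K @ jnp.exp(g / eps)))
    g = eps * (jnp.log(b) - jnp.log(K.T @ jnp.exp(f / eps)))
    return f, g

  return jax.lax.fori_loop(0, n_iter, step, (jnp.zeros_like(a), jnp.zeros_like(b)))


def schur_solve(A, B, D, r, s):
  # Solves [[diag(A), B], [B^T, diag(D)]] [x; y] = [r; s] by eliminating the
  # diagonal block A: only the m x m Schur complement D - B^T A^{-1} B is formed.
  S = jnp.diag(D) - B.T @ (B / A[:, None])
  y = jnp.linalg.solve(S, s - B.T @ (r / A))
  x = (r - B @ y) / A
  return x, y


kx, ky, kr = jax.random.split(jax.random.PRNGKey(0), 3)
xs, ys = jax.random.uniform(kx, (5, 2)), jax.random.uniform(ky, (4, 2))
eps, n, m = 0.5, 5, 4
K = jnp.exp(-jnp.sum((xs[:, None] - ys[None]) ** 2, axis=-1) / eps)
a, b = jnp.full(n, 1 / n), jnp.full(m, 1 / m)

f, g = sinkhorn(a, b, K, eps)
grad_f, grad_g = jax.grad(dual_energy, argnums=(0, 1))(f, g, a, b, K, eps)  # ~ 0

# Hessian from the block description: A = diag(P 1), B = P, D = diag(P^T 1),
# all scaled by -1/eps, where P is the optimal coupling.
P = jnp.exp(f / eps)[:, None] * K * jnp.exp(g / eps)[None, :]
H_blocks = -jnp.block([[jnp.diag(P.sum(1)), P], [P.T, jnp.diag(P.sum(0))]]) / eps
H_auto = jax.hessian(lambda z: dual_energy(z[:n], z[n:], a, b, K, eps))(
    jnp.concatenate([f, g]))

# Apply (-H + ridge * I)^{-1} to a right-hand side, densely and via Schur.
ridge = 1e-3  # illustrative; -H is singular along (1, -1) in the balanced case
rhs = jax.random.normal(kr, (n + m,))
u_dense = jnp.linalg.solve(-H_blocks + ridge * jnp.eye(n + m), rhs)
ux, uy = schur_solve(P.sum(1) / eps + ridge, P / eps, P.sum(0) / eps + ridge,
                     rhs[:n], rhs[n:])
```

Here \(n > m\), so eliminating \(A\) yields the smaller (\(4 \times 4\)) Schur complement; with \(n < m\), eliminating \(D\) would be the cheaper choice.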
These linear systems are solved using the user-defined `solver`, which by default relies on `lineax` solvers when available and falls back on `jax` otherwise.

- Parameters:
  - gr (Tuple[Array, Array]) – 2-tuple (vector of size n, vector of size m).
  - ot_prob (LinearProblem) – the instantiation of the regularized transport problem.
  - f (Array) – potential, w.r.t. marginal a.
  - g (Array) – potential, w.r.t. marginal b.
  - lse_mode (bool) – log-sum-exp mode if True, kernel mode otherwise.
- Returns:
  A tuple of two vectors, with the same sizes as the two vectors in `gr` (n and m).
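As a rough illustration of what `solve` computes (minus the inverse Hessian applied to `gr`, returned as two vectors of sizes \(n\) and \(m\)), here is a matrix-free sketch for the balanced case with \(h\) the identity, using `jax.scipy.sparse.linalg.cg`. In the spirit of not instantiating transport matrices, the coupling \(P = \mathrm{diag}(e^{f/\epsilon}) K \mathrm{diag}(e^{g/\epsilon})\) is only applied through scalings and kernel products. This is not OTT's code: `minus_hessian_matvec`, `apply_minus_inverse_hessian` and the ridge are hypothetical, and the actual method works from the geometry in `ot_prob` rather than from a dense kernel `K`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


def minus_hessian_matvec(f, g, K, eps, ridge):
  # Hypothetical helper: returns z -> (-H + ridge * I) z for the balanced dual,
  # touching the coupling P = diag(u) K diag(v) only through products with K.
  u, v = jnp.exp(f / eps), jnp.exp(g / eps)
  row_marg = u * (K @ v)    # P 1
  col_marg = v * (K.T @ u)  # P^T 1

  def matvec(z):
    zf, zg = z
    p_zg = u * (K @ (v * zg))     # P zg
    pt_zf = v * (K.T @ (u * zf))  # P^T zf
    top = (row_marg * zf + p_zg) / eps + ridge * zf
    bottom = (pt_zf + col_marg * zg) / eps + ridge * zg
    return top, bottom

  return matvec


def apply_minus_inverse_hessian(gr, f, g, K, eps, ridge=1e-3):
  # CG applies: -H is positive semi-definite, so -H + ridge * I is SPD.
  sol, _ = cg(minus_hessian_matvec(f, g, K, eps, ridge), gr, tol=1e-12)
  return sol  # tuple (vector of size n, vector of size m), like gr


# Toy problem, with potentials from a short kernel-mode Sinkhorn loop.
kx, ky, kg = jax.random.split(jax.random.PRNGKey(1), 3)
xs, ys = jax.random.uniform(kx, (6, 2)), jax.random.uniform(ky, (3, 2))
eps = 0.5
K = jnp.exp(-jnp.sum((xs[:, None] - ys[None]) ** 2, axis=-1) / eps)
a, b = jnp.full(6, 1 / 6), jnp.full(3, 1 / 3)


def sinkhorn_step(_, fg):
  f, g = fg
  f = eps * (jnp.log(a) - jnp.log(K @ jnp.exp(g / eps)))
  g = eps * (jnp.log(b) - jnp.log(K.T @ jnp.exp(f / eps)))
  return f, g


f, g = jax.lax.fori_loop(0, 500, sinkhorn_step, (jnp.zeros(6), jnp.zeros(3)))
gr = (jax.random.normal(kg, (6,)), jnp.ones(3))
out_f, out_g = apply_minus_inverse_hessian(gr, f, g, K, eps)
```

Passing `gr` as a tuple works because `cg` accepts pytrees, so the two blocks never need to be concatenated.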