# Source code for ott.solvers.linear.implicit_differentiation

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp

from ott import utils
from ott.math import unbalanced_functions as uf

if TYPE_CHECKING:
  from ott.problems.linear import linear_problem

# A linear operator: maps a vector to a vector of the same kind.
LinOp_t = Callable[[jnp.ndarray], jnp.ndarray]
# A linear solver, called as ``solver(lin, b, lin_t, symmetric, **kwargs)``
# (``lin_t`` being the optional transpose of ``lin``); it returns ``x`` such
# that ``lin(x) = b``.
Solver_t = Callable[[LinOp_t, jnp.ndarray, Optional[LinOp_t], bool],
                    jnp.ndarray]

# Public API of this module.
__all__ = ["ImplicitDiff", "solve_jax_cg"]

@utils.register_pytree_node
class ImplicitDiff:
  """Implicit differentiation of Sinkhorn algorithm.

  Args:
    solver: Callable to compute the solution to a linear problem. The callable
      expects a linear function, a vector, optionally another linear function
      that implements the transpose of that function, and a boolean flag to
      specify symmetry. This solver is by default one of :class:`lineax.CG` or
      :class:`lineax.NormalCG` solvers, if the package can be imported, as
      described in :func:`~ott.solvers.linear.lineax_implicit.solve_lineax`.
      The :mod:`jax` alternative is described in
      :func:`~ott.solvers.linear.implicit_differentiation.solve_jax_cg`.
      Note that `lineax` solvers better handle poorly conditioned problems,
      which typically arise when differentiating the solutions of balanced OT
      problems (when ``tau_a==tau_b==1.0``). Relying on
      :func:`~ott.solvers.linear.implicit_differentiation.solve_jax_cg`
      for such cases might require hand-tuning ridge parameters,
      in particular ``ridge_kernel`` and ``ridge_identity`` as described in
      its doc. These parameters can be passed using ``solver_kwargs`` below.
    solver_kwargs: keyword arguments passed on to the solver.
    symmetric: flag used to figure out whether the linear system solved in the
      implicit function theorem is symmetric or not. This happens when
      ``tau_a==tau_b``, and when ``a == b``, or the ``precondition_fun`` is
      the identity. The flag is ``False`` by default. It is only honored when
      ``precondition_fun`` is given: with the default preconditioner the
      system is always treated as non-symmetric. It needs to be set manually
      by the user in the more favorable case where the system is guaranteed
      to be symmetric.
    precondition_fun: Function used to precondition, on both sides, the linear
      system derived from first-order conditions of the regularized OT
      problem. That linear system typically involves an equality between
      marginals (or a simple transform of these marginals when the problem is
      unbalanced) and another function of the potentials. When that function
      is specified, that function is applied on both sides of the equality,
      before being further differentiated to provide the Jacobians needed for
      implicit function theorem differentiation. Defaults to
      ``epsilon * log``.
  """

  solver: Optional[Solver_t] = None
  solver_kwargs: Optional[Dict[str, Any]] = None
  symmetric: bool = False
  precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None

  def solve(
      self,
      gr: Tuple[jnp.ndarray, jnp.ndarray],
      ot_prob: "linear_problem.LinearProblem",
      f: jnp.ndarray,
      g: jnp.ndarray,
      lse_mode: bool,
  ) -> jnp.ndarray:
    r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``].

    This function is used to carry out implicit differentiation of
    ``sinkhorn`` outputs, notably optimal potentials ``f`` and ``g``. That
    differentiation requires solving a linear system, using (and inverting)
    the Jacobian of (preconditioned) first-order conditions w.r.t. the
    reg-OT problem.

    Given a ``precondition_fun``, written here for short as :math:`h`,
    the first order conditions for the dual energy

    .. math::

      E(K, \epsilon, a, b, f, g) := -\langle a, \phi_a^{*}(-f)\rangle
      - \langle b, \phi_b^{*}(-g)\rangle
      - \epsilon \langle e^{f/\epsilon}, K e^{g/\epsilon}\rangle

    form the basis of the Sinkhorn algorithm. To differentiate optimal
    solutions to that problem, we exploit the fact that
    :math:`h(\nabla E = 0)` and differentiate that identity to recover
    variations (Jacobians) of optimal solutions :math:`f^\star, g^\star` as a
    function of changes in the inputs. The Jacobian of
    :math:`h(\nabla_{f,g} E = 0)` is a linear operator which, if it were to
    be instantiated as a matrix, would be of size :math:`(n+m) \times (n+m)`.
    When :math:`h` is the identity, that matrix is the Hessian of :math:`E`,
    is symmetric and negative-definite (:math:`E` is concave) and is
    structured as :math:`[A, B; B^T, D]`. More generally, for other functions
    :math:`h`, the Jacobian of these preconditioned first order conditions is
    no longer symmetric (except if ``a==b``), and now has a structure
    :math:`[A, B; C, D]`. That system can still be inverted by more generic
    solvers. By default, :math:`h = \epsilon \log`, as proposed in
    :cite:`cuturi:20a`.

    In both cases :math:`A` and :math:`D` are diagonal matrices, equal to the
    row and column marginals respectively, multiplied by the derivatives of
    :math:`h` evaluated at those marginals, corrected (if handling the
    unbalanced case) by the second derivative of the part of the objective
    that ties potentials to the marginals (terms in ``phi_star``). When
    :math:`h` is the identity, :math:`B` and :math:`B^T` are equal
    respectively to the OT matrix and its transpose, i.e. :math:`n \times m`
    and :math:`m \times n` matrices. When :math:`h` is not the identity,
    :math:`B` (resp. :math:`C`) is equal to the OT matrix (resp. its
    transpose), rescaled on the left by the application elementwise of
    :math:`h'` to the row (respectively column) marginal sum of the
    transport.

    Note that we take great care in not instantiating these transport
    matrices, to rely instead on the ``app_transport`` function returned by
    ``ot_prob.get_transport_functions(lse_mode)`` (which uses either
    potentials or scalings, depending on ``lse_mode``).

    The Jacobian's diagonal + off-diagonal blocks structure allows to exploit
    Schur complements. Depending on the sizes involved, it is better to
    instantiate the Schur complement of the first or of the second diagonal
    block.

    These linear systems are solved using the user-defined ``solver``, using
    by default :mod:`lineax` solvers when available, or falling back on
    :mod:`jax` when not.

    Args:
      gr: 2-tuple, (vector of size ``n``, vector of size ``m``).
      ot_prob: the instantiation of the regularized transport problem.
      f: potential, w.r.t marginal a.
      g: potential, w.r.t marginal b.
      lse_mode: bool, log-sum-exp mode if True, kernel else.

    Returns:
      A single vector of size ``n + m``: the negated solution blocks for
      ``f`` (first ``n`` entries) and ``g`` (last ``m`` entries),
      concatenated.
    """
    solver = _get_solver() if self.solver is None else self.solver
    solver_kwargs = {} if self.solver_kwargs is None else self.solver_kwargs
    geom = ot_prob.geom
    eps = geom.epsilon
    marginal_a, marginal_b, app_transport = (
        ot_prob.get_transport_functions(lse_mode)
    )
    precond_fun = self._get_precond_fun(eps)
    # The user's ``symmetric`` flag is only trusted with a custom
    # preconditioner; the default one is treated as non-symmetric.
    symmetric = False if self.precondition_fun is None else self.symmetric
    derivative = jax.vmap(jax.grad(precond_fun))

    # Marginals and h' at the marginals do not depend on the vector the linear
    # operators below are applied to: compute them once here rather than on
    # every operator application inside the iterative solver.
    marg_a = marginal_a(f, g)
    marg_b = marginal_b(f, g)
    deriv_a = derivative(marg_a)
    deriv_b = derivative(marg_b)

    # Off-diagonal blocks (B and C) of the linear system, and their
    # transposes when the system is not symmetric.
    vjp_fg = lambda z: app_transport(f, g, z * deriv_b, axis=1) / eps
    vjp_gf = lambda z: app_transport(f, g, z * deriv_a, axis=0) / eps
    if not symmetric:
      vjp_fgt = lambda z: app_transport(f, g, z, axis=0) * deriv_b / eps
      vjp_gft = lambda z: app_transport(f, g, z, axis=1) * deriv_a / eps

    # Diagonal blocks (A and D), stored as vectors.
    diag_hess_a = marg_a * deriv_a / eps + uf.diag_jacobian_of_marginal_fit(
        ot_prob.a, f, ot_prob.tau_a, eps, derivative
    )
    diag_hess_b = marg_b * deriv_b / eps + uf.diag_jacobian_of_marginal_fit(
        ot_prob.b, g, ot_prob.tau_b, eps, derivative
    )

    n, m = geom.shape
    # TODO(cuturi) consider materializing linear operator schur if size allows.
    # Forks on using Schur complement of either A or D, depending on size.
    if n > m:  # if n is bigger, run m x m linear system.
      inv_vjp_ff = lambda z: z / diag_hess_a
      vjp_gg = lambda z: z * diag_hess_b
      schur = lambda z: vjp_gg(z) - vjp_gf(inv_vjp_ff(vjp_fg(z)))
      if not symmetric:
        schur_t = lambda z: vjp_gg(z) - vjp_fgt(inv_vjp_ff(vjp_gft(z)))
      else:
        schur_t = None
      res = gr[1] - vjp_gf(inv_vjp_ff(gr[0]))
      sch = solver(schur, res, schur_t, symmetric, **solver_kwargs)
      vjp_gr_f = inv_vjp_ff(gr[0] - vjp_fg(sch))
      vjp_gr_g = sch
    else:  # otherwise, run n x n linear system.
      vjp_ff = lambda z: z * diag_hess_a
      inv_vjp_gg = lambda z: z / diag_hess_b
      schur = lambda z: vjp_ff(z) - vjp_fg(inv_vjp_gg(vjp_gf(z)))
      if not symmetric:
        schur_t = lambda z: vjp_ff(z) - vjp_gft(inv_vjp_gg(vjp_fgt(z)))
      else:
        schur_t = None
      res = gr[0] - vjp_fg(inv_vjp_gg(gr[1]))
      sch = solver(schur, res, schur_t, symmetric, **solver_kwargs)
      vjp_gr_g = inv_vjp_gg(gr[1] - vjp_gf(sch))
      vjp_gr_f = sch

    return jnp.concatenate((-vjp_gr_f, -vjp_gr_g))

  def first_order_conditions(
      self,
      prob: "linear_problem.LinearProblem",
      f: jnp.ndarray,
      g: jnp.ndarray,
      lse_mode: bool,
  ) -> jnp.ndarray:
    r"""Compute vector of first order conditions for the reg-OT problem.

    The output of this vector should be close to zero at optimality.
    Upon completion of the Sinkhorn forward pass, its norm (using the norm
    parameter defined using ``norm_error``) should be below the threshold
    parameter.

    This error will be itself assumed to be close to zero when using implicit
    differentiation.

    Args:
      prob: definition of the linear optimal transport problem.
      f: jnp.ndarray, first potential
      g: jnp.ndarray, second potential
      lse_mode: bool

    Returns:
      a jnp.ndarray of size ``n + m`` quantifying deviation to optimality for
      variables ``f`` and ``g``.
    """
    geom = prob.geom
    marginal_a, marginal_b, _ = prob.get_transport_functions(lse_mode)
    grad_a = uf.grad_of_marginal_fit(prob.a, f, prob.tau_a, geom.epsilon)
    grad_b = uf.grad_of_marginal_fit(prob.b, g, prob.tau_b, geom.epsilon)
    precond_fun = self._get_precond_fun(geom.epsilon)

    # Only entries with positive weight in ``a`` / ``b`` contribute; the
    # others are set to 0.
    result_a = jnp.where(
        prob.a > 0, precond_fun(marginal_a(f, g)) - precond_fun(grad_a), 0.0
    )
    result_b = jnp.where(
        prob.b > 0, precond_fun(marginal_b(f, g)) - precond_fun(grad_b), 0.0
    )
    return jnp.concatenate((result_a, result_b))

  def gradient(
      self,
      prob: "linear_problem.LinearProblem",
      f: jnp.ndarray,
      g: jnp.ndarray,
      lse_mode: bool,
      gr: Tuple[jnp.ndarray, jnp.ndarray],
  ) -> Tuple["linear_problem.LinearProblem"]:
    """Apply VJP to recover gradient in reverse mode differentiation.

    Args:
      prob: the linear OT problem the potentials were computed for.
      f: first optimal potential.
      g: second optimal potential.
      lse_mode: bool, log-sum-exp mode if True, kernel else.
      gr: 2-tuple of cotangents w.r.t. ``f`` and ``g``.

    Returns:
      A 1-tuple holding the cotangent w.r.t. ``prob`` (as returned by the
      pullback of :func:`jax.vjp`).
    """
    # Applies first part of vjp to gr: inverse part of implicit function theorem
    vjp_gr = self.solve(gr, prob, f, g, lse_mode)
    # Instantiates vjp of first order conditions of the objective, as a
    # function of the problem (geom, a and b, against which we differentiate).
    foc_prob = lambda p: self.first_order_conditions(p, f, g, lse_mode)
    # Carries pullback onto original inputs, here geom, a and b.
    _, pull_prob = jax.vjp(foc_prob, prob)
    return pull_prob(vjp_gr)

  def replace(self, **kwargs: Any) -> "ImplicitDiff":
    """Return a copy of this object with the given fields replaced."""
    return dataclasses.replace(self, **kwargs)

  def _get_precond_fun(
      self, epsilon: float
  ) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Return ``precondition_fun``, defaulting to ``x -> epsilon * log(x)``."""
    if self.precondition_fun is None:
      return lambda x: epsilon * jnp.log(x)
    return self.precondition_fun
def solve_jax_cg(
    lin: "LinOp_t",
    b: jnp.ndarray,
    lin_t: Optional["LinOp_t"] = None,
    symmetric: bool = False,
    ridge_identity: float = 0.0,
    ridge_kernel: float = 0.0,
    **kwargs: Any
) -> jnp.ndarray:
  """Wrapper around JAX native linear solvers.

  Solves ``lin(x) = b`` with :func:`jax.scipy.sparse.linalg.cg`. When
  ``symmetric`` is ``False``, CG is run on the normal equations
  ``lin_t(lin(x)) = lin_t(b)``, which requires ``lin_t``.

  Args:
    lin: Linear operator
    b: vector. Returned ``x`` is such that ``lin(x)=b``
    lin_t: Linear operator, corresponding to transpose of ``lin``. Required
      when ``symmetric`` is ``False``.
    symmetric: whether ``lin`` is symmetric.
    ridge_identity: handles rank deficient transport matrices (this happens
      typically when rows/cols in cost/kernel matrices are collinear, or,
      equivalently when two points from either measure are close).
    ridge_kernel: promotes zero-sum solutions. Only use if
      ``tau_a = tau_b = 1.0``
    kwargs: arguments passed to :func:`~jax.scipy.sparse.linalg.cg`.

  Returns:
    The approximate solution ``x`` computed by CG.

  Raises:
    ValueError: if ``symmetric`` is ``False`` and ``lin_t`` is ``None``.
  """
  if not symmetric and lin_t is None:
    raise ValueError(
        "`lin_t` must be provided when solving a non-symmetric system."
    )
  if symmetric:
    op, rhs = lin, b
  else:
    # Normal equations A^T A x = A^T b: the right-hand side must also be
    # mapped by the transpose, otherwise CG solves A^T A x = b instead.
    op = lambda x: lin_t(lin(x))
    rhs = lin_t(b)
  if ridge_kernel > 0.0 or ridge_identity > 0.0:
    lin_reg = lambda x: op(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x
  else:
    lin_reg = op
  return jax.scipy.sparse.linalg.cg(lin_reg, rhs, **kwargs)[0]
def _get_solver() -> Solver_t:
  """Pick the default linear solver used by ``ImplicitDiff.solve``.

  Returns:
    :func:`~ott.solvers.linear.lineax_implicit.solve_lineax` when its module
    can be imported (it relies on :mod:`lineax`), otherwise
    :func:`solve_jax_cg`.
  """
  try:
    from ott.solvers.linear import lineax_implicit
  except ImportError:
    # The lineax-based wrapper is unavailable: fall back on plain JAX CG.
    return solve_jax_cg
  return lineax_implicit.solve_lineax