class ott.solvers.linear.implicit_differentiation.ImplicitDiff(solver=None, solver_kwargs=None, symmetric=False, precondition_fun=None)

Implicit differentiation of the Sinkhorn algorithm.
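For reference, the standard implicit function theorem behind this approach can be written as below. The notation is ours, not the library's: F stands for the first-order conditions in the potentials (f, g), θ for the inputs of the OT problem (for example the cost or the marginals), and gr for an incoming cotangent in reverse-mode differentiation.

```latex
F\big((f^\star, g^\star), \theta\big) = 0
\quad\Longrightarrow\quad
\frac{\partial (f^\star, g^\star)}{\partial \theta}
  = -\big[\partial_{(f,g)} F\big]^{-1} \, \partial_\theta F,
\qquad
u = -\big[\partial_{(f,g)} F\big]^{-\top} \mathrm{gr},
\qquad
\mathrm{gr}^{\top} \frac{\partial (f^\star, g^\star)}{\partial \theta}
  = u^{\top} \, \partial_\theta F .
```

The only linear system that must be solved is the one involving the Jacobian of F with respect to the potentials; its symmetry is what the symmetric flag refers to.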

  • solver (Optional[Callable[[Callable[[Array], Array], Array, Optional[Callable[[Array], Array]], bool], Array]]) – Callable that computes the solution of a linear system. It expects a linear function, a vector, optionally another linear function implementing the transpose of the first one, and a boolean flag specifying symmetry. By default, this solver is one of the lineax.CG or lineax.NormalCG solvers if the lineax package can be imported, as described in solve_lineax(); the JAX alternative is described in solve_jax_cg(). Note that lineax solvers handle poorly conditioned problems better; such problems typically arise when differentiating the solutions of balanced OT problems (when tau_a == tau_b == 1.0). Relying on solve_jax_cg() in such cases may require hand-tuning ridge parameters, in particular ridge_kernel and ridge_identity, as described in its documentation. These parameters can be passed using solver_kwargs below.

  • solver_kwargs (Optional[Dict[str, Any]]) – keyword arguments passed on to the solver.

  • symmetric (bool) – flag indicating whether the linear system solved in the implicit function theorem is symmetric. This is the case when tau_a == tau_b and, in addition, either a == b or precondition_fun is the identity. The flag is False by default and is additionally checked against tau_a == tau_b. It must be set manually by the user in the more favorable case where the system is guaranteed to be symmetric.

  • precondition_fun (Optional[Callable[[Array], Array]]) – Function used to precondition, on both sides, the linear system derived from the first-order conditions of the regularized OT problem. That system typically involves an equality between marginals (or a simple transform of these marginals when the problem is unbalanced) and another function of the potentials. When precondition_fun is specified, it is applied to both sides of that equality before differentiation, which yields the Jacobians needed for implicit-function-theorem differentiation.
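The solver argument is characterized only by its contract: a callable receiving a linear map, a right-hand side, an optional transpose map and a symmetry flag. The NumPy sketch below illustrates that contract with a matrix-free conjugate gradient, switching to CG on the normal equations when the system is not declared symmetric. The name hypothetical_solver and its ridge_identity argument are illustrative assumptions that loosely mimic, but are not, solve_jax_cg(); an actual OTT solver operates on JAX arrays.

```python
from typing import Callable, Optional

import numpy as np

LinOp = Callable[[np.ndarray], np.ndarray]


def cg(matvec: LinOp, b: np.ndarray, tol: float = 1e-10, maxiter: int = 1000) -> np.ndarray:
    """Plain conjugate gradient for a symmetric positive definite operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def hypothetical_solver(
    lin: LinOp,
    b: np.ndarray,
    lin_t: Optional[LinOp] = None,
    symmetric: bool = False,
    ridge_identity: float = 0.0,  # hypothetical ridge, adds ridge_identity * I
) -> np.ndarray:
    """Solve lin(x) = b using only matrix-vector products (illustrative only)."""
    if symmetric:
        # Symmetric case: CG directly on (A + ridge * I).
        return cg(lambda v: lin(v) + ridge_identity * v, b)
    # Non-symmetric case: "normal CG", i.e. CG on (A^T A + ridge * I) x = A^T b.
    if lin_t is None:
        raise ValueError("a transpose operator is required when symmetric=False")
    return cg(lambda v: lin_t(lin(v)) + ridge_identity * v, lin_t(b))


b = np.array([1.0, 2.0, 3.0])

# Symmetric positive definite system.
A = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
x_sym = hypothetical_solver(lambda v: A @ v, b, symmetric=True)

# Non-symmetric system, which needs the transpose map.
B = np.array([[3.0, 1.0, 0.0], [0.0, 2.0, 1.0], [1.0, 0.0, 4.0]])
x_nonsym = hypothetical_solver(lambda v: B @ v, b, lin_t=lambda v: B.T @ v)
```

In this sketch, ridge_identity adds a multiple of the identity, trading a small bias for better conditioning; it is meant to echo the intent of the ridge parameters of solve_jax_cg(), whose exact definitions should be checked in its documentation.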
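To see why preconditioning both sides of an equality leaves its solution unchanged while reshaping the Jacobian, consider a toy residual r(x) = K exp(x) matched to a target a, with φ = log as the preconditioner. The map r, the matrix K and the choice of φ are assumptions for illustration in NumPy, not OTT's code. For an injective φ, the roots of r(x) - a and φ(r(x)) - φ(a) coincide, and the chain rule rescales each row of the Jacobian by φ'(r(x)); the snippet checks the latter with finite differences.

```python
import numpy as np

# Hypothetical positive "marginal" map r(x) = K @ exp(x) and target a.
K = np.array([[1.0, 0.5], [0.2, 1.0]])
a = np.array([1.5, 1.2])


def r(x):
    return K @ np.exp(x)


def jac_r(x):
    # d r_i / d x_j = K_ij * exp(x_j)
    return K * np.exp(x)[None, :]


# Preconditioner applied to both sides of r(x) = a (log chosen for illustration).
phi = np.log


def dphi(y):
    return 1.0 / y


# Root of r(x) = a: exp(x) = K^{-1} a, which is positive for this K and a.
x_star = np.log(np.linalg.solve(K, a))
plain_residual = r(x_star) - a
precond_residual = phi(r(x_star)) - phi(a)

# Chain rule: the Jacobian of phi(r(x)) - phi(a) is diag(phi'(r(x))) @ J_r(x).
x0 = np.array([0.1, -0.3])
jac_precond = dphi(r(x0))[:, None] * jac_r(x0)

# Central finite differences of the preconditioned condition, one column per x_j.
h = 1e-6
jac_fd = np.stack(
    [(phi(r(x0 + h * e)) - phi(r(x0 - h * e))) / (2 * h) for e in np.eye(2)],
    axis=1,
)
```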


first_order_conditions(prob, f, g, lse_mode)

Compute the vector of first-order conditions for the regularized OT problem.
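As a simplified sketch (balanced problem, dense cost matrix, no precondition_fun, NumPy instead of JAX), the first-order conditions can be read as the gap between the marginals of the coupling P_ij = exp((f_i + g_j - C_ij) / eps) built from the potentials and the target marginals a and b; this gap vanishes at the Sinkhorn fixed point. The helper names are hypothetical, and the vector OTT actually returns may be transformed, for instance by precondition_fun or by unbalanced terms.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z), axis))."""
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def first_order_conditions_sketch(f, g, cost, a, b, eps):
    """Balanced entropic OT: marginals of P = exp((f + g - C) / eps) minus (a, b)."""
    log_p = (f[:, None] + g[None, :] - cost) / eps
    row_marginal = np.exp(logsumexp(log_p, axis=1))  # P @ 1
    col_marginal = np.exp(logsumexp(log_p, axis=0))  # P.T @ 1
    return np.concatenate([row_marginal - a, col_marginal - b])


def sinkhorn_lse(cost, a, b, eps, n_iter=500):
    """Log-domain Sinkhorn; each update enforces one marginal exactly."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - cost) / eps, axis=1)
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - cost) / eps, axis=0)
    return f, g


rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(4, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
cost = cost / cost.max()  # normalized squared Euclidean cost
a, b = np.full(5, 1 / 5), np.full(4, 1 / 4)
eps = 0.5

f0, g0 = np.zeros(5), np.zeros(4)
f, g = sinkhorn_lse(cost, a, b, eps)
residual_init = np.linalg.norm(first_order_conditions_sketch(f0, g0, cost, a, b, eps))
residual_final = np.linalg.norm(first_order_conditions_sketch(f, g, cost, a, b, eps))
```

Implicit differentiation then assumes this residual is close to zero at the returned potentials, which is why its accuracy depends on the forward pass having converged.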

gradient(prob, f, g, lse_mode, gr)

Apply a VJP to recover the gradient in reverse-mode differentiation.


solve(gr, ot_prob, f, g, lse_mode)

Apply minus inverse of [hessian reg_ot_cost w.r.t.
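The two methods above split the reverse pass along the implicit function theorem: a linear solve that applies minus the inverse (transpose) Jacobian of the optimality conditions to the incoming cotangent, followed by a VJP with respect to the problem inputs. The NumPy sketch below applies the same two steps to a toy root-finding problem F(x, θ) = A x + x³ - θ = 0 rather than to the Sinkhorn conditions; the problem, the helper names and the dense np.linalg.solve (in place of a matrix-free solver) are assumptions for illustration, and the result is checked against finite differences.

```python
import numpy as np

# Toy optimality condition F(x, theta) = A @ x + x**3 - theta = 0 (elementwise cube).
A = np.array([[3.0, 1.0], [0.5, 2.0]])


def F(x, theta):
    return A @ x + x**3 - theta


def dF_dx(x):
    return A + np.diag(3.0 * x**2)


def solve_root(theta, n_iter=50):
    """Newton's method for the root x*(theta)."""
    x = np.zeros_like(theta)
    for _ in range(n_iter):
        x = x - np.linalg.solve(dF_dx(x), F(x, theta))
    return x


def implicit_vjp(theta, gr):
    """Gradient of <gr, x*(theta)> w.r.t. theta via the implicit function theorem."""
    x = solve_root(theta)
    # Step 1 (linear solve): u = -(dF/dx)^{-T} gr.
    u = -np.linalg.solve(dF_dx(x).T, gr)
    # Step 2 (VJP): gradient = (dF/dtheta)^T u, and here dF/dtheta = -I.
    return -u


theta = np.array([1.0, -0.5])
gr = np.array([1.0, 2.0])
grad_ift = implicit_vjp(theta, gr)

# Central finite differences of theta -> <gr, x*(theta)> for comparison.
h = 1e-6
grad_fd = np.array([
    (gr @ solve_root(theta + h * e) - gr @ solve_root(theta - h * e)) / (2 * h)
    for e in np.eye(2)
])
```

Note that the forward solver is never differentiated through: only the Jacobians at the solution are used, which is the point of implicit differentiation.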