class ott.solvers.linear.implicit_differentiation.ImplicitDiff(solver=None, solver_kwargs=None, symmetric=False, precondition_fun=None)
Implicit differentiation of the Sinkhorn algorithm.
solver – Callable to compute the solution to a linear problem. The callable expects a linear function, a vector, optionally another linear function that implements the transpose of that function, and a boolean flag to specify symmetry. By default, this solver is a lineax.NormalCG solver if the lineax package can be imported; the jax alternative is described in solve_jax_cg(). Note that lineax solvers better handle poorly conditioned problems, which typically arise when differentiating the solutions of balanced OT problems (when tau_a==tau_b==1.0). Relying on solve_jax_cg() for such cases might require hand-tuning ridge parameters, in particular ridge_identity, as described in its documentation. These parameters can be passed using solver_kwargs.
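The callable contract described for solver can be illustrated with a small NumPy sketch. The helper hypothetical_linear_solver is an assumption made for illustration. It is neither solve_jax_cg() nor lineax, and its ridge_identity argument only mirrors the name mentioned above, so its exact meaning there may differ. When the system is not flagged as symmetric, the helper runs conjugate gradient on the normal equations A^T A x = A^T b, which is the idea behind lineax.NormalCG. A ridge term adds a multiple of the identity to improve conditioning.

```python
from typing import Callable, Optional

import numpy as np

LinOp = Callable[[np.ndarray], np.ndarray]


def conjugate_gradient(matvec: LinOp, b: np.ndarray, tol: float = 1e-10,
                       maxiter: int = 1000) -> np.ndarray:
    """Plain CG for a symmetric positive (semi-)definite operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def hypothetical_linear_solver(lin: LinOp, b: np.ndarray,
                               lin_t: Optional[LinOp] = None,
                               symmetric: bool = False,
                               ridge_identity: float = 0.0) -> np.ndarray:
    """Solve lin(x) = b from lin, b, an optional transpose and a symmetry flag.

    Hypothetical helper for illustration; not OTT's or lineax's implementation.
    """
    if symmetric:
        # Symmetric (assumed positive semi-definite) system: CG applies
        # directly to (A + ridge_identity * I) x = b.
        return conjugate_gradient(lambda x: lin(x) + ridge_identity * x, b)
    if lin_t is None:
        raise ValueError("a transpose is required when symmetric=False")
    # Otherwise solve the normal equations (A^T A + ridge_identity * I) x = A^T b,
    # which are symmetric positive semi-definite, so CG still applies.
    return conjugate_gradient(lambda x: lin_t(lin(x)) + ridge_identity * x, lin_t(b))


A = np.array([[4.0, 1.0, 0.0], [2.0, 5.0, 1.0], [0.0, 1.0, 3.0]])
rhs = np.array([1.0, 2.0, 3.0])
x = hypothetical_linear_solver(lambda v: A @ v, rhs, lin_t=lambda v: A.T @ v)
print(np.allclose(A @ x, rhs))  # True
```

Routing non-symmetric systems through the normal equations is what makes the transpose argument necessary. It also squares the condition number, which is one reason a ridge term can help on badly conditioned problems.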
symmetric (bool) – Flag indicating whether the linear system solved in the implicit function theorem is symmetric. This happens when tau_a==tau_b and, in addition, either a == b or precondition_fun is the identity. The flag is False by default and is also tested against tau_a==tau_b. It needs to be set manually by the user in the more favorable case where the system is guaranteed to be symmetric.
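The symmetry condition can be checked numerically on a small balanced problem. The NumPy sketch below makes three assumptions: tau_a == tau_b == 1, an identity precondition_fun, and first-order conditions of the form F(f, g) = [P 1 - a, P^T 1 - b] with P_ij = exp((f_i + g_j - C_ij) / eps). It uses plain NumPy rather than JAX so that it runs without OTT. This form of the conditions and all helper names (sinkhorn_potentials, balanced_conditions, jacobian_fd) are illustrative assumptions, not OTT's API. At the Sinkhorn solution, the finite-difference Jacobian is symmetric and matches the closed form. It also has the null vector (1, ..., 1, -1, ..., -1), because P is unchanged when f is shifted by +c and g by -c. That null vector is one reason balanced problems are poorly conditioned.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log-sum-exp along one axis."""
    mx = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(x - mx).sum(axis=axis))


def sinkhorn_potentials(C, a, b, eps, n_iter=500):
    """Log-domain Sinkhorn for the balanced problem.

    Returns potentials (f, g) with plan P_ij = exp((f_i + g_j - C_ij) / eps).
    """
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = eps * (np.log(a) - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - C) / eps, axis=0))
    return f, g


def balanced_conditions(z, C, a, b, eps):
    """Assumed first-order conditions F(f, g) = [P 1 - a, P^T 1 - b]."""
    n = a.size
    f, g = z[:n], z[n:]
    P = np.exp((f[:, None] + g[None, :] - C) / eps)
    return np.concatenate([P.sum(axis=1) - a, P.sum(axis=0) - b])


def jacobian_fd(fun, z, step=1e-6):
    """Central finite-difference Jacobian; column k holds dF/dz_k."""
    cols = []
    for k in range(z.size):
        e = np.zeros_like(z)
        e[k] = step
        cols.append((fun(z + e) - fun(z - e)) / (2 * step))
    return np.stack(cols, axis=1)


rng = np.random.default_rng(0)
n, m, eps = 4, 5, 0.5
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(m, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
a = rng.uniform(0.5, 1.5, n)
a /= a.sum()
b = rng.uniform(0.5, 1.5, m)
b /= b.sum()

f, g = sinkhorn_potentials(C, a, b, eps)
z = np.concatenate([f, g])
J = jacobian_fd(lambda v: balanced_conditions(v, C, a, b, eps), z)

# Closed form of the same Jacobian: [[diag(P 1), P], [P^T, diag(P^T 1)]] / eps.
P = np.exp((f[:, None] + g[None, :] - C) / eps)
J_closed = np.block([[np.diag(P.sum(axis=1)), P], [P.T, np.diag(P.sum(axis=0))]]) / eps

# Shifting f by +c and g by -c leaves P unchanged, hence a null vector of J.
v = np.concatenate([np.ones(n), -np.ones(m)])

print(np.allclose(J, J.T, atol=1e-6))       # True: the system is symmetric
print(np.allclose(J, J_closed, atol=1e-6))  # True
print(np.allclose(J @ v, 0.0, atol=1e-6))   # True
```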
precondition_fun – Function used to precondition, on both sides, the linear system derived from the first-order conditions of the regularized OT problem. That linear system typically involves an equality between marginals (or a simple transform of these marginals when the problem is unbalanced) and another function of the potentials. When precondition_fun is specified, it is applied to both sides of that equality, which are then differentiated to provide the Jacobians needed for implicit function theorem differentiation.
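The NumPy sketch below makes the effect of preconditioning concrete. It applies h = log on both sides of the (assumed) balanced conditions, that is, it uses [log(P 1) - log(a), log(P^T 1) - log(b)]. The choice of log and the helper names are illustrative assumptions, not necessarily OTT's default. Because h is injective, the preconditioned conditions vanish at the same potentials. By the chain rule, their Jacobian is diag(h'(P 1), h'(P^T 1)) times the unpreconditioned Jacobian, and this row scaling generally destroys symmetry. In this sketch, symmetry is recovered for equal uniform weights a == b, in line with the condition stated for symmetric.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log-sum-exp along one axis."""
    mx = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(x - mx).sum(axis=axis))


def sinkhorn_potentials(C, a, b, eps, n_iter=500):
    """Balanced log-domain Sinkhorn; P_ij = exp((f_i + g_j - C_ij) / eps)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = eps * (np.log(a) - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - C) / eps, axis=0))
    return f, g


def preconditioned_conditions(z, C, a, b, eps, h=np.log):
    """Apply the preconditioner h to both sides: [h(P 1) - h(a), h(P^T 1) - h(b)]."""
    n = a.size
    P = np.exp((z[:n, None] + z[None, n:] - C) / eps)
    return np.concatenate([h(P.sum(axis=1)) - h(a), h(P.sum(axis=0)) - h(b)])


def log_preconditioned_jacobian(z, C, n, eps):
    """Chain rule: diag(h'(marginals)) @ J with h = log, i.e. h'(u) = 1 / u."""
    P = np.exp((z[:n, None] + z[None, n:] - C) / eps)
    r, c = P.sum(axis=1), P.sum(axis=0)
    J = np.block([[np.diag(r), P], [P.T, np.diag(c)]]) / eps
    return np.diag(1.0 / np.concatenate([r, c])) @ J


def jacobian_fd(fun, z, step=1e-6):
    """Central finite-difference Jacobian; column k holds dF/dz_k."""
    cols = []
    for k in range(z.size):
        e = np.zeros_like(z)
        e[k] = step
        cols.append((fun(z + e) - fun(z - e)) / (2 * step))
    return np.stack(cols, axis=1)


def analyse(a, b, eps=0.5, seed=0):
    """Return F_h(z*), its finite-difference Jacobian and the chain-rule Jacobian."""
    rng = np.random.default_rng(seed)
    x, y = rng.uniform(size=(a.size, 2)), rng.uniform(size=(b.size, 2))
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    z = np.concatenate(sinkhorn_potentials(C, a, b, eps))

    def fun(v):
        return preconditioned_conditions(v, C, a, b, eps)

    return fun(z), jacobian_fd(fun, z), log_preconditioned_jacobian(z, C, a.size, eps)


uniform = np.full(4, 0.25)
G_u, J_u, J_u_chain = analyse(uniform, uniform)

rng = np.random.default_rng(1)
a = rng.uniform(0.5, 1.5, 4)
a /= a.sum()
b = rng.uniform(0.5, 1.5, 5)
b /= b.sum()
G_s, J_s, J_s_chain = analyse(a, b)

print(np.allclose(J_s, J_s_chain, atol=1e-6))  # True: chain rule holds
print(np.allclose(J_u, J_u.T, atol=1e-6))      # True: uniform a == b
print(np.allclose(J_s, J_s.T, atol=1e-6))      # False: a != b
```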
first_order_conditions(prob, f, g, lse_mode)
Compute the vector of first-order conditions for the reg-OT problem.
gradient(prob, f, g, lse_mode, gr)
Apply a VJP to recover the gradient in reverse-mode differentiation.
solve(gr, ot_prob, f, g, lse_mode)
Apply minus inverse of [hessian
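The NumPy sketch below shows how the three methods fit together. It differentiates a scalar function of the balanced Sinkhorn plan with respect to the cost matrix. It illustrates the technique under stated assumptions and is not OTT's code:

- The conditions are taken as F(f, g) = [P 1 - a, P^T 1 - b].
- The "solve" step applies minus a pseudo-inverse of their symmetric Jacobian J, computed with np.linalg.lstsq. This is a swapped-in choice, because J is singular in the balanced case.
- The "gradient" step contracts the result with dF/dC.

The loss L = sum(W * P) and all helper names are hypothetical. A central finite-difference check confirms the result.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log-sum-exp along one axis."""
    mx = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(x - mx).sum(axis=axis))


def sinkhorn_plan(C, a, b, eps, n_iter=1000):
    """Balanced log-domain Sinkhorn; returns P_ij = exp((f_i + g_j - C_ij) / eps)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = eps * (np.log(a) - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)


def loss(C, a, b, eps, W):
    """Hypothetical scalar loss L = sum(W * P*(C))."""
    return np.sum(W * sinkhorn_plan(C, a, b, eps))


def implicit_grad_cost(C, a, b, eps, W):
    """dL/dC via the implicit function theorem applied to F(f, g) = 0."""
    P = sinkhorn_plan(C, a, b, eps)
    n = a.size
    # Jacobian of F(f, g) = [P 1 - a, P^T 1 - b] w.r.t. (f, g); symmetric here.
    J = np.block([[np.diag(P.sum(axis=1)), P], [P.T, np.diag(P.sum(axis=0))]]) / eps
    # Cotangent gr = dL/d(f, g), the input of the "solve" step.
    gr = np.concatenate([(W * P).sum(axis=1), (W * P).sum(axis=0)]) / eps
    # "solve": lam = -J^+ gr. J is singular (f + c, g - c give the same P),
    # so a minimum-norm least-squares solve stands in for the inverse.
    lam = np.linalg.lstsq(J, -gr, rcond=None)[0]
    lam_a, lam_b = lam[:n], lam[n:]
    # "gradient": explicit dL/dC plus the VJP lam^T dF/dC; both terms are
    # proportional to -P / eps entrywise.
    return -(P / eps) * (W + lam_a[:, None] + lam_b[None, :])


rng = np.random.default_rng(0)
n, m, eps = 3, 4, 0.5
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(m, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
W = rng.normal(size=(n, m))

G = implicit_grad_cost(C, a, b, eps, W)

# Central finite differences, one cost entry at a time.
step = 1e-5
G_fd = np.zeros_like(C)
for idx in np.ndindex(*C.shape):
    E = np.zeros_like(C)
    E[idx] = step
    G_fd[idx] = (loss(C + E, a, b, eps, W) - loss(C - E, a, b, eps, W)) / (2 * step)

print(np.allclose(G, G_fd, rtol=1e-4, atol=1e-7))  # True
```

The cotangent gr is orthogonal to the null vector (1, ..., 1, -1, ..., -1), so the pseudo-inverse solve is well defined. Any component of lam along that null vector cancels in lam_a[:, None] + lam_b[None, :], so the gradient does not depend on how the singular direction is handled.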