ott.solvers.linear.implicit_differentiation.ImplicitDiff
- class ott.solvers.linear.implicit_differentiation.ImplicitDiff(solver=None, solver_kwargs=None, symmetric=False, precondition_fun=None)
Implicit differentiation of the Sinkhorn algorithm.
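As a hedged sketch of the idea behind this class (the notation is ours, not taken from the OTT source, and assumes a balanced problem with no preconditioning): the converged potentials (f, g) satisfy first-order conditions F(f, g; θ) = 0, where θ collects inputs such as the cost matrix C and the marginals a, b. The implicit function theorem then gives derivatives of the potentials without unrolling the Sinkhorn iterations:

```latex
\begin{aligned}
P_{ij}(f, g) &= \exp\!\big((f_i + g_j - C_{ij}) / \varepsilon\big), \\
F(f, g; \theta) &= \begin{pmatrix} P(f, g)\,\mathbf{1}_m - a \\ P(f, g)^\top \mathbf{1}_n - b \end{pmatrix} = 0, \\
\frac{\partial (f, g)}{\partial \theta} &= -\big[\partial_{(f,g)} F\big]^{-1} \, \partial_\theta F, \\
\text{reverse mode:}\quad \big[\partial_{(f,g)} F\big]^\top z &= -\bar{h}, \qquad \bar{\theta} = \big(\partial_\theta F\big)^\top z .
\end{aligned}
```

Here \bar{h} is the incoming cotangent on (f, g). Up to sign, ∂_{(f,g)}F is the Hessian of the entropic dual objective <f, a> + <g, b> - ε Σ_ij P_ij, which matches the "minus inverse of [hessian ..." wording of solve. In the balanced case this matrix is singular along the direction (1, -1), since shifting f by c and g by -c leaves P unchanged, so the inverse has to be understood in a least-squares or ridge-regularized sense.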
- Parameters:
solver (Optional[Callable[[Callable[[Array], Array], Array, Optional[Callable[[Array], Array]], bool], Array]]) – Callable to compute the solution to a linear problem. The callable expects a linear function, a vector, optionally another linear function that implements the transpose of that function, and a boolean flag to specify symmetry. By default, this solver is one of the lineax.CG or lineax.NormalCG solvers, if the package can be imported, as described in solve_lineax(). The jax alternative is described in solve_jax_cg(). Note that lineax solvers handle poorly conditioned problems better; such problems typically arise when differentiating the solutions of balanced OT problems (when tau_a == tau_b == 1.0). Relying on solve_jax_cg() in such cases might require hand-tuning ridge parameters, in particular ridge_kernel and ridge_identity, as described in its documentation. These parameters can be passed using solver_kwargs below.
solver_kwargs (Optional[Dict[str, Any]]) – Keyword arguments passed on to the solver.
symmetric (bool) – Flag indicating whether the linear system solved in the implicit function theorem is symmetric. This happens when tau_a == tau_b and, in addition, either a == b or precondition_fun is the identity. The flag is False by default and is additionally checked against tau_a == tau_b. Users need to set it manually in the more favorable case where the system is guaranteed to be symmetric.
precondition_fun (Optional[Callable[[Array], Array]]) – Function used to precondition, on both sides, the linear system derived from the first-order conditions of the regularized OT problem. That linear system typically involves an equality between marginals (or a simple transform of these marginals when the problem is unbalanced) and another function of the potentials. When precondition_fun is specified, it is applied to both sides of that equality before they are further differentiated to provide the Jacobians needed for implicit function theorem differentiation.
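The solver contract described above can be illustrated with a small conjugate-gradient sketch in plain JAX. It follows the documented argument order (linear map, right-hand side, optional transpose map, symmetry flag). The name ridge_cg_solver and its ridge keyword are hypothetical: this is not OTT's solve_jax_cg(), and ridge is not its ridge_kernel or ridge_identity parameter. When the system is not flagged as symmetric, the sketch solves ridge-regularized normal equations, similar in spirit to lineax.NormalCG.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

LinOp = Callable[[jnp.ndarray], jnp.ndarray]


def ridge_cg_solver(
    lin: LinOp,
    b: jnp.ndarray,
    lin_t: Optional[LinOp] = None,
    symmetric: bool = False,
    ridge: float = 1e-8,  # hypothetical keyword, e.g. supplied via solver_kwargs
) -> jnp.ndarray:
    """Hypothetical solver for lin(x) = b following the documented call order."""
    if symmetric:
        # CG assumes a symmetric positive (semi-)definite operator; the ridge
        # shifts its spectrum away from zero.
        op = lambda x: lin(x) + ridge * x
        rhs = b
    else:
        if lin_t is None:
            # Derive the transpose map from `lin` with JAX's linear transposition.
            transpose = jax.linear_transpose(lin, b)
            lin_t = lambda y: transpose(y)[0]
        # Solve the ridge-regularized normal equations (A^T A + ridge I) x = A^T b.
        op = lambda x: lin_t(lin(x)) + ridge * x
        rhs = lin_t(b)
    x, _ = cg(op, rhs, tol=1e-6)
    return x


# Usage: one symmetric positive definite and one non-symmetric 3x3 system.
spd = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
nonsym = jnp.array([[3.0, 1.0, 0.0], [0.0, 2.0, 1.0], [1.0, 0.0, 4.0]])
rhs = jnp.array([1.0, 2.0, 3.0])
x_sym = ridge_cg_solver(lambda v: spd @ v, rhs, symmetric=True, ridge=0.0)
x_nonsym = ridge_cg_solver(lambda v: nonsym @ v, rhs, ridge=0.0)
```

With OTT installed, such a callable would presumably be passed as ImplicitDiff(solver=ridge_cg_solver, solver_kwargs={"ridge": 1e-6}), assuming the class forwards solver_kwargs to the solver as keyword arguments, as the parameter description suggests.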
Methods
first_order_conditions(prob, f, g, lse_mode) – Compute the vector of first-order conditions for the reg-OT problem.
gradient(prob, f, g, lse_mode, gr) – Apply a VJP to recover the gradient in reverse-mode differentiation.
replace(**kwargs)
solve(gr, ot_prob, f, g, lse_mode) – Apply minus inverse of [hessian reg_ot_cost w.r.t.
Attributes
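The first_order_conditions and solve entries can be illustrated on a toy balanced problem with a self-contained JAX sketch. It does not call OTT: sinkhorn_potentials and the free function first_order_conditions below are hypothetical stand-ins (not the method of the same name), the conditions are assumed to be plain marginal residuals (balanced case, identity preconditioning), and the potentials come from log-domain Sinkhorn updates similar in spirit to lse_mode=True. The sketch checks that the conditions vanish at convergence and that their Jacobian with respect to (f, g) is symmetric but singular along (1, -1).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def first_order_conditions(f, g, cost, a, b, epsilon):
    """Marginal residuals of the plan P_ij = exp((f_i + g_j - C_ij) / epsilon)."""
    log_plan = (f[:, None] + g[None, :] - cost) / epsilon
    row_marginal = jnp.exp(logsumexp(log_plan, axis=1))  # P @ 1
    col_marginal = jnp.exp(logsumexp(log_plan, axis=0))  # P.T @ 1
    return jnp.concatenate([row_marginal - a, col_marginal - b])


def sinkhorn_potentials(cost, a, b, epsilon, num_iters=1000):
    """Log-domain Sinkhorn iterations returning dual potentials (f, g)."""

    def body(_, fg):
        f, g = fg
        # Each update makes one marginal of the plan match its target exactly.
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    return jax.lax.fori_loop(0, num_iters, body, init)


# Toy balanced problem: 4 source points and 3 target points on a line.
x = jnp.linspace(0.0, 1.0, 4)
y = jnp.array([0.1, 0.5, 0.8])
cost = (x[:, None] - y[None, :]) ** 2
a = jnp.full(4, 1.0 / 4)
b = jnp.full(3, 1.0 / 3)
epsilon = 0.2

f, g = sinkhorn_potentials(cost, a, b, epsilon)
residual = first_order_conditions(f, g, cost, a, b, epsilon)

# Jacobian of the conditions with respect to the stacked potentials (f, g).
n = a.shape[0]
jac = jax.jacobian(
    lambda fg: first_order_conditions(fg[:n], fg[n:], cost, a, b, epsilon)
)(jnp.concatenate([f, g]))

# Shifting f by +c and g by -c leaves the plan unchanged: (1, -1) is a null direction.
null_dir = jnp.concatenate([jnp.ones_like(a), -jnp.ones_like(b)])
```

Because of that null direction, a plain inverse does not exist in this sketch; a least-squares solver or a small ridge term, as discussed for the solver parameter, is one way to make the solve step well posed.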