ott.solvers.linear.implicit_differentiation.ImplicitDiff
- class ott.solvers.linear.implicit_differentiation.ImplicitDiff(solver_fun=<function cg>, ridge_kernel=0.0, ridge_identity=0.0, symmetric=False, precondition_fun=None)
Implicit differentiation of the Sinkhorn algorithm.
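The principle behind implicit differentiation can be illustrated on a scalar toy problem. If x*(theta) is defined implicitly by optimality conditions F(x, theta) = 0, the implicit function theorem gives dx*/dtheta = -(dF/dx)^(-1) dF/dtheta, so the gradient can be computed from the solution alone, without differentiating through the solver's iterations. In ImplicitDiff, F plays the role of the Sinkhorn first-order conditions and x the potentials (f, g). The sketch below only shows the principle with jax.custom_vjp; the names F, newton_solve and root are hypothetical and this is not ott's implementation.

```python
import jax


def F(x, theta):
    # Optimality condition F(x, theta) = 0 that defines x*(theta) implicitly.
    return x**3 + x - theta


def newton_solve(theta, x0=0.0, num_iters=20):
    # Plain Newton iterations; custom_vjp below ensures they are never differentiated.
    x = x0
    for _ in range(num_iters):
        x = x - F(x, theta) / jax.grad(F, argnums=0)(x, theta)
    return x


@jax.custom_vjp
def root(theta):
    return newton_solve(theta)


def root_fwd(theta):
    x = newton_solve(theta)
    # Save the solution (and theta) as residuals for the backward pass.
    return x, (x, theta)


def root_bwd(residuals, cotangent):
    x, theta = residuals
    dF_dx = jax.grad(F, argnums=0)(x, theta)
    dF_dtheta = jax.grad(F, argnums=1)(x, theta)
    # Implicit function theorem: dx*/dtheta = -dF_dtheta / dF_dx.
    return (-cotangent * dF_dtheta / dF_dx,)


root.defvjp(root_fwd, root_bwd)

x_star = root(2.0)  # approximately 1.0, since 1**3 + 1 = 2
dx_dtheta = jax.grad(root)(2.0)  # approximately 0.25 = 1 / (3 * 1**2 + 1)
```

The backward pass touches only the converged solution, so its cost does not grow with the number of Newton steps; the same trade-off motivates implicit over unrolled differentiation of Sinkhorn.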
- Parameters:
  - solver_fun (Callable[[Array, Array], Tuple[Array, ...]]) – Callable used to solve the linear system; it should return a tuple whose first element is the solution, i.e. (solution, …).
  - ridge_kernel (float) – promotes zero-sum solutions. Only used if tau_a = tau_b = 1.0.
  - ridge_identity (float) – handles rank-deficient transport matrices. This typically happens when rows or columns of the cost/kernel matrices are collinear or, equivalently, when two points from either measure are close.
  - symmetric (bool) – flag indicating whether the linear system solved in the implicit function theorem is symmetric. This happens when either a == b or precondition_fun is the identity. False by default; at the moment, it must be set manually by the user in the more favorable case where the system is guaranteed to be symmetric.
  - precondition_fun (Optional[Callable[[Array], Array]]) – TODO(marcocuturi)
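The parameters above control how the linear system of the implicit function theorem is solved. The sketch below is a hypothetical helper (regularized_solve is not part of ott) showing one plausible way they combine, under our own assumptions: ridge_kernel adds ridge_kernel * (1^T z) * 1 to the operator, which lifts the all-ones null direction so that a zero-sum right-hand side yields a zero-sum solution; ridge_identity adds a plain Tikhonov term ridge_identity * z; and symmetric selects jax.scipy.sparse.linalg.cg (the default solver_fun, which needs a symmetric positive-definite operator) directly or, as our own choice for the non-symmetric case, CG on the normal equations. Like solver_fun, cg returns a tuple whose first element is the solution.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def regularized_solve(matvec, rhs, ridge_kernel=0.0, ridge_identity=0.0, symmetric=False):
    """Hypothetical helper: solve (A + ridge_kernel * 11^T + ridge_identity * I) z = rhs."""

    def regularized_matvec(z):
        # jnp.sum(z) broadcasts, adding ridge_kernel * (1^T z) to every entry.
        return matvec(z) + ridge_kernel * jnp.sum(z) + ridge_identity * z

    if symmetric:
        # CG requires a symmetric positive-definite operator.
        solution, _ = cg(regularized_matvec, rhs)
        return solution

    # Otherwise run CG on the normal equations M^T M z = M^T rhs. The transpose
    # M^T is the VJP of the map, which is exact because the map is linear.
    _, vjp_fn = jax.vjp(regularized_matvec, rhs)

    def transpose(y):
        return vjp_fn(y)[0]

    solution, _ = cg(lambda z: transpose(regularized_matvec(z)), transpose(rhs))
    return solution


# Laplacian of a 3-node path graph: singular, with the all-ones vector in its kernel.
laplacian = jnp.array([[1.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]])
rhs = jnp.array([2.0, -1.0, -1.0])  # sums to zero

z_sym = regularized_solve(lambda v: laplacian @ v, rhs, ridge_kernel=1.0, symmetric=True)
z_gen = regularized_solve(lambda v: laplacian @ v, rhs, ridge_kernel=1.0)
# Both are approximately [5/3, -1/3, -4/3], which sums to zero.
```

Without the ridge term, laplacian @ z = rhs has infinitely many solutions differing by constant shifts; with ridge_kernel = 1.0 the operator becomes positive definite and the unique zero-sum solution is returned.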
Methods
- first_order_conditions(prob, f, g, lse_mode) – Compute vector of first order conditions for the reg-OT problem.
- gradient(prob, f, g, lse_mode, gr) – Apply VJP to recover gradient in reverse mode differentiation.
- replace(**kwargs) – param kwargs:
- solve(gr, ot_prob, f, g, lse_mode) – Apply minus inverse of [hessian reg_ot_cost w.r.t. …
- solver_fun(b[, x0, tol, atol, maxiter, M]) – Use Conjugate Gradient iteration to solve Ax = b.
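To make first_order_conditions and solve more concrete, consider the balanced case (tau_a = tau_b = 1.0) in log-sum-exp mode. We assume here that the first-order conditions reduce to the marginal constraints on the plan P_ij = exp((f_i + g_j - C_ij) / epsilon), namely P 1 - a = 0 and P^T 1 - b = 0. The self-contained sketch below uses our own function names and argument order, not ott's signatures: it runs log-domain Sinkhorn, checks that the conditions vanish at the fixed point, and confirms that their Jacobian with respect to (f, g) annihilates the direction (1, -1). That null direction is why the linear system inverted by solve is singular without the ridge parameters.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def transport_plan(f, g, cost, epsilon):
    # P_ij = exp((f_i + g_j - C_ij) / epsilon)
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def first_order_conditions(f, g, cost, a, b, epsilon):
    # Assumed balanced-case conditions: both marginal residuals must vanish.
    plan = transport_plan(f, g, cost, epsilon)
    return jnp.concatenate([plan.sum(axis=1) - a, plan.sum(axis=0) - b])


def sinkhorn_log(cost, a, b, epsilon, num_iters=500):
    # Alternating log-domain updates matching the row, then the column, marginals.
    def body(_, fg):
        f, g = fg
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    return jax.lax.fori_loop(0, num_iters, body, init)


cost = jnp.array([[0.0, 1.0], [0.5, 0.2], [1.0, 0.0]])
a = jnp.array([0.2, 0.3, 0.5])
b = jnp.array([0.6, 0.4])
eps = 0.5
f, g = sinkhorn_log(cost, a, b, eps)

# Residuals of the first-order conditions are close to zero at the fixed point.
residuals = first_order_conditions(f, g, cost, a, b, eps)

# Jacobian of the conditions w.r.t. the stacked potentials (f, g).
n = a.shape[0]
jac = jax.jacobian(
    lambda fg: first_order_conditions(fg[:n], fg[n:], cost, a, b, eps)
)(jnp.concatenate([f, g]))

# Shifting f by +c and g by -c leaves P unchanged, so this product is close to zero.
shift = jnp.concatenate([jnp.ones_like(a), -jnp.ones_like(b)])
null_residual = jac @ shift
```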
Attributes