ott.solvers.linear.implicit_differentiation.ImplicitDiff#

class ott.solvers.linear.implicit_differentiation.ImplicitDiff(solver_fun=<function cg>, ridge_kernel=0.0, ridge_identity=0.0, symmetric=False, precondition_fun=None)[source]#

Implicit differentiation of the Sinkhorn algorithm.

Parameters:
  • solver_fun (Callable[[Array, Array], Tuple[Array, ...]]) – callable solver for the linear system arising in the implicit function theorem; it should return a tuple whose first entry is the solution, i.e. (solution, …).

  • ridge_kernel (float) – ridge parameter that promotes zero-sum solutions; only used in the balanced case, i.e. when tau_a = tau_b = 1.0.

  • ridge_identity (float) – ridge parameter that handles rank-deficient transport matrices. These typically arise when rows/columns of the cost/kernel matrix are collinear or, equivalently, when two points from either measure are very close.

  • symmetric (bool) – flag indicating whether the linear system solved in the implicit function theorem is symmetric. This is the case when either a == b or when precondition_fun is the identity. False by default; at the moment it must be set manually by the user in the more favorable case where the system is guaranteed to be symmetric.

  • precondition_fun (Optional[Callable[[Array], Array]]) – TODO(marcocuturi)
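
As a quick orientation, here is a minimal sketch of how an ImplicitDiff instance is typically passed to a Sinkhorn solver so that gradients of the regularized OT cost flow through the implicit function theorem rather than through unrolled iterations. The point-cloud sizes, epsilon and ridge values below are illustrative assumptions, not library defaults.

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import implicit_differentiation as imp_diff
from ott.solvers.linear import sinkhorn

# Small ridge terms to stabilize the linear solve (illustrative values).
implicit = imp_diff.ImplicitDiff(ridge_kernel=1e-6, ridge_identity=1e-6)
solver = sinkhorn.Sinkhorn(implicit_diff=implicit)

def reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  geom = pointcloud.PointCloud(x, y, epsilon=1e-2)  # illustrative epsilon
  prob = linear_problem.LinearProblem(geom)
  return solver(prob).reg_ot_cost

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (31, 2))
y = jax.random.normal(key_y, (37, 2))

# Reverse-mode gradient w.r.t. the first point cloud, computed by
# implicitly differentiating the Sinkhorn fixed point.
grad_x = jax.grad(reg_ot_cost)(x, y)
```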

Methods

first_order_conditions(prob, f, g, lse_mode)

Compute the vector of first-order conditions for the reg-OT problem.

gradient(prob, f, g, lse_mode, gr)

Apply a VJP to recover the gradient in reverse-mode differentiation.

replace(**kwargs)

Return a copy of the object with the fields given in kwargs overridden by new values.

solve(gr, ot_prob, f, g, lse_mode)

Apply minus the inverse of the Hessian of reg_ot_cost w.r.t. the potentials f, g to the vector gr.

solver_fun(b[, x0, tol, atol, maxiter, M])

Use Conjugate Gradient iteration to solve Ax = b.
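
Continuing the sketch above (and reusing x, y, solver, implicit, and the imports from it), the snippet below illustrates, under the same assumptions, how these methods might be used: evaluating the first-order conditions at the computed potentials, and using replace to derive a modified configuration, here swapping solver_fun for a conjugate-gradient solve with a tighter tolerance. The ridge and tolerance values are again illustrative.

```python
import functools
from jax.scipy.sparse import linalg as jsp_linalg

geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
prob = linear_problem.LinearProblem(geom)
out = solver(prob)

# First-order conditions of the reg-OT problem evaluated at the Sinkhorn
# potentials; the residual should be close to zero at convergence.
foc = implicit.first_order_conditions(prob, out.f, out.g, lse_mode=True)
print(jnp.max(jnp.abs(foc)))

# `replace` returns a new ImplicitDiff with the given fields overridden:
# a larger identity ridge and a CG solver with a tighter tolerance.
implicit_cg = implicit.replace(
    ridge_identity=1e-5,
    solver_fun=functools.partial(jsp_linalg.cg, tol=1e-8, maxiter=1_000),
)
```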

Attributes

precondition_fun

ridge_identity

ridge_kernel

symmetric