ott.solvers.linear.implicit_differentiation.solve_jax_cg

ott.solvers.linear.implicit_differentiation.solve_jax_cg(lin, b, lin_t=None, symmetric=False, ridge_identity=0.0, ridge_kernel=0.0, **kwargs)

Wrapper around JAX's native conjugate gradient (CG) linear solver.

Parameters:
  • lin (Callable[[Array], Array]) – Linear operator

  • b (Array) – Right-hand side vector. The returned x is such that lin(x) = b.

  • lin_t (Optional[Callable[[Array], Array]]) – Linear operator corresponding to the transpose of lin.

  • symmetric (bool) – Whether lin is symmetric.

  • ridge_kernel (float) – Promotes zero-sum solutions. Only use if tau_a = tau_b = 1.0 (the balanced case).

  • ridge_identity (float) – Handles rank-deficient transport matrices. This typically happens when rows or columns of the cost/kernel matrix are collinear or, equivalently, when two points from either measure are close.

  • kwargs (Any) – Keyword arguments passed to jax.scipy.sparse.linalg.cg().

Return type:

Array
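
The parameters above can be illustrated with a minimal sketch of how a CG wrapper of this shape might be put together. This is not the OTT implementation: the function name cg_solve_sketch is hypothetical, and the way the two ridge terms enter the operator, as well as the use of the normal equations for a non-symmetric lin, are assumptions made for illustration. The only library call relied on is jax.scipy.sparse.linalg.cg, which returns a (solution, info) pair.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def cg_solve_sketch(lin, b, lin_t=None, symmetric=False,
                    ridge_identity=0.0, ridge_kernel=0.0, **kwargs):
  """Hypothetical stand-in for illustration, NOT the OTT implementation."""
  if symmetric:
    op, rhs = lin, b
  else:
    # Assumption: a non-symmetric system is solved through the normal
    # equations lin_t(lin(x)) = lin_t(b), whose operator is symmetric.
    op, rhs = (lambda x: lin_t(lin(x))), lin_t(b)

  def op_reg(x):
    # Assumption: ridge_identity adds a multiple of the identity, and
    # ridge_kernel adds a multiple of the all-ones matrix (penalizing sum(x)).
    return op(x) + ridge_identity * x + ridge_kernel * jnp.sum(x)

  x, _ = cg(op_reg, rhs, **kwargs)  # extra kwargs (tol, maxiter, ...) go to cg
  return x


# Symmetric positive definite system: lin can be used directly.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_sym = cg_solve_sketch(lambda v: A @ v, b, symmetric=True, tol=1e-6)

# Non-symmetric system: the transpose operator lin_t must be supplied.
B = jnp.array([[2.0, 1.0], [0.0, 3.0]])
x_gen = cg_solve_sketch(lambda v: B @ v, b, lin_t=lambda v: B.T @ v, tol=1e-6)
```

In this sketch both ridge terms default to zero, so the operator is left unchanged unless a ridge is requested. With a positive ridge_identity the regularized operator stays invertible even when lin is rank deficient, at the price of solving a slightly perturbed system.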