ott.solvers.linear.implicit_differentiation.solve_jax_cg
- ott.solvers.linear.implicit_differentiation.solve_jax_cg(lin, b, lin_t=None, symmetric=False, ridge_identity=0.0, ridge_kernel=0.0, **kwargs)
Wrapper around JAX native linear solvers.
- Parameters:
    - b (Array) – vector. Returned x is such that lin(x) = b.
    - lin_t (Optional[Callable[[Array], Array]]) – Linear operator, corresponding to transpose of lin.
    - symmetric (bool) – whether lin is symmetric.
    - ridge_kernel (float) – promotes zero-sum solutions. Only use if tau_a = tau_b = 1.0.
    - ridge_identity (float) – handles rank deficient transport matrices (this happens typically when rows/cols in cost/kernel matrices are collinear, or, equivalently when two points from either measure are close).
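As an illustration of how the parameters above could interact, here is a minimal sketch built only on `jax.scipy.sparse.linalg.cg`. It is not the OTT implementation: the function name `ridge_cg_sketch` is hypothetical, and the way the two ridge terms enter the operator (an identity shift and a term proportional to the sum of the entries of x) is an assumption made for this example, as is the use of the normal equations when `lin` is not symmetric.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def ridge_cg_sketch(lin, b, lin_t=None, symmetric=False,
                    ridge_identity=0.0, ridge_kernel=0.0, **kwargs):
    """Hypothetical stand-in for illustration, NOT the OTT function."""
    # CG expects a symmetric operator. Assumption: when lin is not
    # symmetric, solve the normal equations lin_t(lin(x)) = lin_t(b).
    if symmetric:
        op, rhs = lin, b
    else:
        op, rhs = (lambda x: lin_t(lin(x))), lin_t(b)

    def op_reg(x):
        # Assumed regularization: ridge_identity * x shifts the spectrum
        # away from zero; ridge_kernel * sum(x) penalizes the component
        # of x along the all-ones vector (favoring zero-sum solutions).
        return op(x) + ridge_identity * x + ridge_kernel * jnp.sum(x)

    # Extra keyword arguments (e.g. tol, maxiter) go to jax's cg.
    x, _ = cg(op_reg, rhs, **kwargs)
    return x


# Symmetric positive-definite operator.
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_sym = ridge_cg_sketch(lambda v: A @ v, b, symmetric=True)

# Non-symmetric operator: the transpose must be supplied via lin_t.
M = jnp.array([[1.0, 2.0], [0.0, 1.0]])
c = jnp.array([3.0, 1.0])
x_gen = ridge_cg_sketch(lambda v: M @ v, c, lin_t=lambda v: M.T @ v)

# Rank-deficient symmetric operator, stabilized by a small identity ridge.
S = jnp.ones((2, 2))
d = jnp.array([2.0, 2.0])
x_reg = ridge_cg_sketch(lambda v: S @ v, d, symmetric=True,
                        ridge_identity=1e-3)
```

In this sketch the ridge terms trade a small bias in the solution for a better-conditioned system, which is why `x_reg` only approximately satisfies `S @ x = d`.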
- Return type: