ott.solvers.linear.lineax_implicit.solve_lineax
- ott.solvers.linear.lineax_implicit.solve_lineax(lin, b, lin_t=None, symmetric=False, nonsym_solver=None, ridge_identity=0.0, ridge_kernel=0.0, **kwargs)
Solve a linear system using conjugate gradients.
This implementation uses a JAX-native CG solver that works correctly inside JAX transformations (such as the backward pass of a VJP). This avoids the equinox closure-conversion issues that affect lineax on certain JAX versions.
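To illustrate the idea of a matrix-free CG solve that can be differentiated through, here is a minimal sketch. It assumes `jax.scipy.sparse.linalg.cg` as a stand-in JAX-native solver; ott's internal CG may be implemented differently. The operator is passed only as a matvec callable, as with `lin`, and `jax.grad` runs a backward pass through the solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Small symmetric positive-definite matrix, accessed only through a matvec.
A_mat = jnp.array([[4.0, 1.0, 0.0],
                   [1.0, 3.0, 1.0],
                   [0.0, 1.0, 2.0]])

def solve_cg(b):
    # Matrix-free operator: CG only needs the map x -> A @ x.
    lin = lambda x: A_mat @ x
    x, _ = cg(lin, b, tol=1e-6, maxiter=50)  # cg returns (solution, info)
    return x

def loss(b):
    # Scalar function of the solution, so we can differentiate w.r.t. b.
    return jnp.sum(solve_cg(b) ** 2)

b = jnp.array([1.0, 2.0, 3.0])
x = solve_cg(b)
g = jax.grad(loss)(b)  # backward pass differentiates through the CG solve
```

Because `A_mat` is symmetric, the gradient of `sum(x**2)` with respect to `b` is `2 * A^{-1} x`, so the backward pass itself amounts to another CG solve with the same operator.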
- Parameters:
  - lin (Callable) – Linear operator.
  - b (Array) – Right-hand side vector; the returned x satisfies lin(x) = b.
  - lin_t (Optional[Callable]) – Linear operator corresponding to the transpose of lin.
  - symmetric (bool) – Whether lin is symmetric.
  - nonsym_solver (Optional[Any]) – Unused; kept for API compatibility.
  - ridge_kernel (float) – Promotes zero-sum solutions. Only use if tau_a = tau_b = 1.0.
  - ridge_identity (float) – Handles rank-deficient transport matrices (this typically happens when rows/columns of the cost/kernel matrices are collinear or, equivalently, when two points from either measure are close).
  - kwargs (Any) – Arguments passed to the CG solver (rtol, atol, maxiter).
- Return type:
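The roles of `ridge_kernel`, `ridge_identity` and `lin_t` can be illustrated with a small helper. `ridge_cg_solve` below is hypothetical and is not ott's implementation. It assumes that the ridge terms enter the operator as `ridge_kernel * 1 1^T + ridge_identity * I`, and that a nonsymmetric system is handled by running CG on the normal equations, which needs the transpose `lin_t`. It again uses `jax.scipy.sparse.linalg.cg`, whose relative tolerance keyword is `tol`, not `rtol`.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def ridge_cg_solve(lin, b, lin_t=None, symmetric=False,
                   ridge_identity=0.0, ridge_kernel=0.0, **kwargs):
    """Hypothetical sketch (not ott's code): CG with optional ridge terms."""
    def lin_reg(x):
        # ridge_kernel * sum(x) adds ridge_kernel * (1 1^T) x, penalizing the
        # constant direction; ridge_identity * x shifts the spectrum off zero.
        return lin(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x

    if symmetric:
        x, _ = cg(lin_reg, b, **kwargs)
        return x

    # Both ridge terms are symmetric, so the regularized transpose is lin_t
    # plus the same ridge terms.
    def lin_t_reg(y):
        return lin_t(y) + ridge_kernel * jnp.sum(y) + ridge_identity * y

    # Nonsymmetric case: CG on the normal equations A^T A x = A^T b.
    normal_op = lambda x: lin_t_reg(lin_reg(x))
    x, _ = cg(normal_op, lin_t_reg(b), **kwargs)
    return x

# Path-graph Laplacian: singular, with kernel spanned by the all-ones vector.
L = jnp.array([[ 1.0, -1.0,  0.0,  0.0],
               [-1.0,  2.0, -1.0,  0.0],
               [ 0.0, -1.0,  2.0, -1.0],
               [ 0.0,  0.0, -1.0,  1.0]])
b_lap = jnp.array([1.0, 0.0, 0.0, -1.0])  # zero-sum right-hand side
x_lap = ridge_cg_solve(lambda x: L @ x, b_lap, symmetric=True,
                       ridge_kernel=1.0, tol=1e-5, maxiter=100)

# Nonsymmetric system solved through lin_t.
A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])
b_ns = jnp.array([1.0, 2.0])
x_ns = ridge_cg_solve(lambda x: A @ x, b_ns, lin_t=lambda y: A.T @ y,
                      tol=1e-5, maxiter=100)
```

In the Laplacian example, `L` alone is singular: any constant shift of a solution is also a solution. With a zero-sum right-hand side, adding `ridge_kernel * 1 1^T` makes the operator positive definite and forces the constant component of the solution to zero. CG therefore returns the zero-sum solution `[1.5, 0.5, -0.5, -1.5]`, which is the sense in which `ridge_kernel` "promotes zero-sum solutions".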