ott.solvers.linear.implicit_differentiation.ImplicitDiff.solver_fun
- ImplicitDiff.solver_fun(b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)
Use Conjugate Gradient iteration to solve Ax = b.
The numerics of JAX’s cg should exactly match SciPy’s cg (up to numerical precision), but note that the interface is slightly different: you supply the linear operator A as a function (or a dense array) instead of a sparse matrix or LinearOperator.
Derivatives of cg are implemented via implicit differentiation with another cg solve, rather than by differentiating through the solver. They will be accurate only if both solves converge.
- Parameters:
A (ndarray, function, or matmul-compatible object) – 2D array or function that calculates the linear map (matrix-vector product) Ax when called like A(x) or A @ x. A must represent a Hermitian, positive-definite matrix, and must return array(s) with the same structure and shape as its argument.
b (array or tree of arrays) – Right-hand side of the linear system representing a single vector. Can be stored as an array or Python container of array(s) with any shape.
x0 (array or tree of arrays) – Starting guess for the solution. Must have the same structure as b.
tol (float, optional) – Relative tolerance for convergence: iteration stops once norm(residual) <= max(tol*norm(b), atol). We do not implement SciPy’s “legacy” behavior, so JAX’s tolerance will differ from SciPy unless you explicitly pass atol to SciPy’s cg.
atol (float, optional) – Absolute tolerance for convergence, used in the same criterion norm(residual) <= max(tol*norm(b), atol). The same caveat about SciPy’s “legacy” behavior applies.
maxiter (integer) – Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved.
M (ndarray, function, or matmul-compatible object) – Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning dramatically improves the rate of convergence, which implies that fewer iterations are needed to reach a given error tolerance.
- Returns:
x (array or tree of arrays) – The computed solution, which has converged unless maxiter was reached first. Has the same structure as b.
info (None) – Placeholder for convergence information. In the future, JAX will report the number of iterations when convergence is not achieved, like SciPy.
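The interface described above can be sketched as follows. A may be a dense array or a function computing the matrix-vector product, and b may be a tree of arrays. This is a minimal example. It assumes that solver_fun is JAX's jax.scipy.sparse.linalg.cg, as the docstring suggests, and it calls that function directly. The matrix, the right-hand sides, and the helper names matvec and tree_matvec are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Enable 64-bit floats so the comparison with the direct solve is tight.
jax.config.update("jax_enable_x64", True)

# A small symmetric positive-definite matrix (illustrative values).
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

# A can be passed as a dense array...
x_dense, info = cg(A, b, tol=1e-10)

# ...or as a function that computes the matrix-vector product A @ x.
def matvec(x):
    return A @ x

x_fun, _ = cg(matvec, b, tol=1e-10)

# Reference solution from a dense direct solve.
x_ref = jnp.linalg.solve(A, b)
print(jnp.allclose(x_dense, x_ref), jnp.allclose(x_fun, x_ref), info)

# b may also be a pytree; the operator must return the same structure.
# Here the operator is the diagonal map diag(2, 2, 5) split across two leaves.
def tree_matvec(t):
    return {"u": 2.0 * t["u"], "v": 5.0 * t["v"]}

b_tree = {"u": jnp.array([2.0, 4.0]), "v": jnp.array([10.0])}
x_tree, _ = cg(tree_matvec, b_tree)
print(x_tree)
```

The second value returned by cg is the info placeholder described under Returns. It is None, so convergence has to be checked separately, for example by computing the residual b - A @ x yourself.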
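The following sketch illustrates the note on derivatives. It differentiates a loss through a CG solve with jax.grad and compares the result with the gradient obtained through a dense jnp.linalg.solve. As before, it assumes the underlying solver is jax.scipy.sparse.linalg.cg. The matrix, the loss, and the function names loss_cg and loss_direct are hypothetical examples. A tight tol is used because the gradient is only as accurate as the forward and backward CG solves.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

# Illustrative symmetric positive-definite matrix.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])

def loss_cg(b):
    # jax.grad differentiates this with a second (transposed) CG solve,
    # not by unrolling the CG iterations.
    x, _ = cg(A, b, tol=1e-12)
    return jnp.sum(x ** 2)

def loss_direct(b):
    # Reference: the same loss computed with a dense direct solve.
    x = jnp.linalg.solve(A, b)
    return jnp.sum(x ** 2)

b = jnp.array([1.0, 2.0, 3.0])
g_cg = jax.grad(loss_cg)(b)
g_ref = jax.grad(loss_direct)(b)
print(jnp.allclose(g_cg, g_ref))
```

If tol is loosened or maxiter is too small for either solve to converge, the two gradients can be expected to drift apart. This matches the docstring's caveat that derivatives are accurate only when both solves converge.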
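The preconditioner M can be a function that approximately applies the inverse of A. The block below is a hedged sketch that uses a Jacobi (diagonal) preconditioner. Jacobi is a common generic choice, not something prescribed by this API. The example applies it to a badly scaled symmetric positive-definite matrix. It again calls jax.scipy.sparse.linalg.cg directly, and the matrix and the jacobi helper are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

# Badly scaled SPD matrix: the diagonal spans four orders of magnitude,
# and the small off-diagonals keep it strictly diagonally dominant.
n = 50
d = jnp.logspace(0, 4, n)
off = 0.1 * jnp.ones(n - 1)
A = jnp.diag(d) + jnp.diag(off, 1) + jnp.diag(off, -1)
b = jnp.ones(n)

# Jacobi preconditioner: approximates the inverse of A by 1 / diag(A).
inv_diag = 1.0 / jnp.diag(A)

def jacobi(x):
    return inv_diag * x

x_pc, _ = cg(A, b, tol=1e-10, maxiter=200, M=jacobi)

# Compare against a dense direct solve.
x_ref = jnp.linalg.solve(A, b)
print(jnp.allclose(x_pc, x_ref))
```

Because this A is strictly diagonally dominant, Jacobi scaling brings the preconditioned spectrum close to 1, which is the situation where M helps most. Rerunning with M=None and a small maxiter is one way to observe the effect of M on this matrix. Matrices without this structure may need a different preconditioner.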