# ott.solvers.linear.implicit_differentiation.ImplicitDiff.solver_fun

ImplicitDiff.solver_fun(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)

Use Conjugate Gradient iteration to solve `Ax = b`.
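For readers unfamiliar with the method, here is a textbook, unpreconditioned conjugate gradient loop. It is an illustrative sketch only, not JAX's implementation; the function name `cg_sketch` is hypothetical, and it assumes a real, symmetric positive-definite operator passed as a matrix-vector product function.

```python
import jax.numpy as jnp


def cg_sketch(matvec, b, tol=1e-5, maxiter=100):
    """Textbook CG for a real symmetric positive-definite operator (illustrative)."""
    x = jnp.zeros_like(b)
    r = b - matvec(x)                      # initial residual
    p = r                                  # first search direction
    rs = jnp.vdot(r, r)                    # squared residual norm
    threshold = (tol * jnp.linalg.norm(b)) ** 2
    for _ in range(maxiter):
        if rs <= threshold:                # stop once norm(r) <= tol * norm(b)
            break
        Ap = matvec(p)
        alpha = rs / jnp.vdot(p, Ap)       # exact line search along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = jnp.vdot(r, r)
        p = r + (rs_new / rs) * p          # next A-conjugate search direction
        rs = rs_new
    return x


A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])           # symmetric positive definite
b = jnp.array([1.0, 2.0, 3.0])
x = cg_sketch(lambda v: A @ v, b)
```

In exact arithmetic CG terminates in at most `n` iterations for an `n x n` operator, which is why the 3x3 example above needs only a handful of steps.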

The numerics of JAX’s `cg` should exactly match SciPy’s `cg` (up to numerical precision), but note that the interface is slightly different: you need to supply the linear operator `A` as a function instead of a sparse matrix or `LinearOperator`.
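A minimal sketch, assuming `solver_fun` is JAX's `jax.scipy.sparse.linalg.cg` (whose documentation this page reproduces). It shows the same system solved with `A` given as a matrix-vector function and as a dense 2D array; the matrix, right-hand side, and the helper name `matvec` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Illustrative symmetric positive-definite matrix: B @ B.T + 5 * I.
key = jax.random.PRNGKey(0)
B = jax.random.normal(key, (5, 5))
A_dense = B @ B.T + 5.0 * jnp.eye(5)
b = jnp.arange(1.0, 6.0)


# A supplied as a function that computes the product A @ x.
def matvec(x):
    return A_dense @ x


x_fun, _ = cg(matvec, b)
# A supplied directly as a dense 2D array.
x_mat, _ = cg(A_dense, b)

# Reference solution from a direct dense solve.
x_ref = jnp.linalg.solve(A_dense, b)
```

Passing a function lets `A` stay implicit (for example, a kernel or Jacobian-vector product), so the matrix never has to be materialized.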

Derivatives of `cg` are implemented via implicit differentiation with another `cg` solve, rather than by differentiating through the solver's iterations. These derivatives are accurate only if both solves converge.
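The sketch below, assuming `solver_fun` is JAX's `jax.scipy.sparse.linalg.cg`, compares the gradient of a scalar loss taken through `cg` with the gradient taken through a direct `jnp.linalg.solve`. The loss functions and matrix are illustrative; the two gradients should agree closely when both CG solves converge.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Illustrative symmetric positive-definite matrix.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])


def loss_cg(b):
    # Reverse-mode differentiation here uses a second CG solve
    # instead of backpropagating through the CG iterations.
    x, _ = cg(lambda v: A @ v, b)
    return jnp.sum(x ** 2)


def loss_exact(b):
    # Same loss, but with a direct dense solve for comparison.
    x = jnp.linalg.solve(A, b)
    return jnp.sum(x ** 2)


b = jnp.array([1.0, 2.0, 3.0])
g_cg = jax.grad(loss_cg)(b)
g_exact = jax.grad(loss_exact)(b)
```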

Parameters:
• A (ndarray, function, or matmul-compatible object) – 2D array or function that computes the linear map (matrix-vector product) `Ax` when called as `A(x)` or `A @ x`. `A` must represent a Hermitian, positive-definite matrix and must return array(s) with the same structure and shape as its argument.

• b (array or tree of arrays) – Right hand side of the linear system representing a single vector. Can be stored as an array or Python container of array(s) with any shape.

• x0 (array or tree of arrays, optional) – Starting guess for the solution. Must have the same structure as `b`. Defaults to zeros.

• tol (float, optional) – Relative tolerance for convergence: iteration stops once `norm(residual) <= max(tol*norm(b), atol)`. We do not implement SciPy’s “legacy” behavior, so JAX’s tolerance will differ from SciPy’s unless you explicitly pass `atol` to SciPy’s `cg`.

• atol (float, optional) – Absolute tolerance for convergence, used in the same criterion `norm(residual) <= max(tol*norm(b), atol)`. Because SciPy’s “legacy” behavior is not implemented, pass `atol` explicitly to SciPy’s `cg` when comparing results.

• maxiter (int, optional) – Maximum number of iterations. Iteration stops after `maxiter` steps even if the specified tolerance has not been reached.

• M (ndarray, function, or matmul-compatible object, optional) – Preconditioner for `A`; it should approximate the inverse of `A`. Effective preconditioning dramatically improves the rate of convergence, so fewer iterations are needed to reach a given error tolerance.
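The following sketch, assuming `solver_fun` is JAX's `jax.scipy.sparse.linalg.cg`, passes a Jacobi (diagonal) preconditioner as `M` and then checks the true residual against the `max(tol*norm(b), atol)` criterion described above. The tridiagonal test matrix and the helper names `jacobi` and `residual_norm` are illustrative. Since `info` is always `None`, the iteration count is not reported, so the example checks accuracy only.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 50
# Illustrative SPD tridiagonal matrix with a widely varying diagonal
# (Gershgorin: all eigenvalues are at least 1 - 0.2 = 0.8).
diag = jnp.linspace(1.0, 100.0, n)
A = jnp.diag(diag) + 0.1 * (jnp.eye(n, k=1) + jnp.eye(n, k=-1))
b = jnp.ones(n)


# Jacobi preconditioner: divide by diag(A), a cheap approximation of A^{-1}.
def jacobi(r):
    return r / diag


tol, atol = 1e-5, 0.0
x_plain, _ = cg(A, b, tol=tol, atol=atol, maxiter=500)
x_prec, _ = cg(A, b, tol=tol, atol=atol, maxiter=500, M=jacobi)


def residual_norm(x):
    # True residual, recomputed from scratch rather than taken from the solver.
    return jnp.linalg.norm(b - A @ x)


threshold = jnp.maximum(tol * jnp.linalg.norm(b), atol)
```

Both solves target the same stopping rule; the preconditioner only changes how quickly the iterates get there.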

Returns:

• x (array or tree of arrays) – The computed solution, converged only if the tolerance was met within `maxiter` iterations. Has the same structure as `b`.

• info (None) – Placeholder for convergence information. In the future, JAX will report the number of iterations when convergence is not achieved, like SciPy.
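A short sketch of the return values, assuming `solver_fun` is JAX's `jax.scipy.sparse.linalg.cg`: here `b` is a dict (a tree of arrays), the illustrative operator `matvec` returns a dict with the same structure, and the returned `x` mirrors that structure while `info` is `None`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


# Illustrative block-diagonal SPD operator acting on a dict of arrays:
# diag(2, 4) on "u" and 3 * I on "v".
def matvec(x):
    return {"u": jnp.array([2.0, 4.0]) * x["u"], "v": 3.0 * x["v"]}


b = {"u": jnp.array([2.0, 8.0]), "v": jnp.array([3.0, 6.0, 9.0])}
x, info = cg(matvec, b)

# The exact solution of this system is u = [1, 2], v = [1, 2, 3].
expected = {"u": jnp.array([1.0, 2.0]), "v": jnp.array([1.0, 2.0, 3.0])}
```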