ott.math.matrix_square_root.sqrtm(x, threshold=1e-06, min_iterations=0, inner_iterations=10, max_iterations=1000, regularization=1e-06)

Higham's algorithm to compute the matrix square root of a positive definite (p.d.) matrix.

See [Higham, 1997], eq. 2.6b.
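For reference, the coupled Newton-Schulz iteration that we take eq. 2.6b to denote can be written as below. This is a sketch under assumptions: the scaling constant c, and the way the regularizer ε enters it, are illustrative choices rather than a statement of what this function does internally. For a symmetric p.d. x, dividing by c ≥ ||x||_F places the eigenvalues of Y_0 in (0, 1), where the iteration converges quadratically.

```latex
\begin{aligned}
c &= (1 + \epsilon)\,\lVert x \rVert_F, \qquad Y_0 = x / c, \qquad Z_0 = I, \\
Y_{k+1} &= \tfrac{1}{2}\, Y_k \left(3I - Z_k Y_k\right), \qquad
Z_{k+1} = \tfrac{1}{2} \left(3I - Z_k Y_k\right) Z_k, \\
Y_k &\to (x / c)^{1/2}, \qquad Z_k \to (x / c)^{-1/2}, \\
x^{1/2} &\approx \sqrt{c}\, Y_k, \qquad x^{-1/2} \approx Z_k / \sqrt{c}.
\end{aligned}
```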

Parameters:

  • x (Array) – a square p.s.d. matrix, or a batch of square p.s.d. matrices of the same size.

  • threshold (float) – convergence tolerance for the Newton-Schulz iterations.

  • min_iterations (int) – minimum number of iterations to run before the error is first computed.

  • inner_iterations (int) – the error is re-evaluated every inner_iterations iterations.

  • max_iterations (int) – maximum number of iterations.

  • regularization (float) – small regularizer added to the norm of x before normalization.
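The parameters above can be tied together in a minimal JAX sketch of the same iteration for a single matrix. The function name newton_schulz_sqrtm is hypothetical, and several details are assumptions rather than this library's code: the multiplicative (1 + regularization) scaling of the Frobenius norm, the relative residual ||x - c Y^2||_F / ||x||_F used as the error, the rule that stopping is allowed only once at least min_iterations updates have run, and the -1 placeholder left in error slots that are never filled.

```python
import jax
import jax.numpy as jnp


def newton_schulz_sqrtm(x, threshold=1e-6, min_iterations=0, inner_iterations=10,
                        max_iterations=1000, regularization=1e-6):
  """Hypothetical sketch: coupled Newton-Schulz square root of one p.d. matrix.

  Returns (sqrt_x, inv_sqrt_x, errors); unused entries of errors stay at -1.
  """
  eye = jnp.eye(x.shape[-1], dtype=x.dtype)
  # Assumption: the regularizer enlarges the Frobenius norm multiplicatively.
  c = jnp.linalg.norm(x) * (1.0 + regularization)

  def step(_, yz):
    # One coupled update: T = (3I - Z Y) / 2, Y <- Y T, Z <- T Z.
    y, z = yz
    t = 0.5 * (3.0 * eye - z @ y)
    return y @ t, t @ z

  def rel_error(y):
    # Assumed error: relative residual of the rescaled square, ||x - c Y^2|| / ||x||.
    return jnp.linalg.norm(x - c * (y @ y)) / jnp.linalg.norm(x)

  # max_iterations is rounded up to a whole number of inner blocks.
  num_blocks = -(-max_iterations // inner_iterations)

  def cond(state):
    it, block, err, _, _, _ = state
    may_stop = (it >= min_iterations) & (err <= threshold)
    return (block < num_blocks) & ~may_stop

  def body(state):
    it, block, _, y, z, errors = state
    # Run inner_iterations updates, then evaluate the error once.
    y, z = jax.lax.fori_loop(0, inner_iterations, step, (y, z))
    err = rel_error(y)
    return it + inner_iterations, block + 1, err, y, z, errors.at[block].set(err)

  init = (jnp.int32(0), jnp.int32(0), jnp.asarray(jnp.inf, x.dtype),
          x / c, eye, -jnp.ones(num_blocks, x.dtype))
  _, _, _, y, z, errors = jax.lax.while_loop(cond, body, init)
  # Undo the scaling: x^(1/2) = sqrt(c) Y and x^(-1/2) = Z / sqrt(c).
  return jnp.sqrt(c) * y, z / jnp.sqrt(c), errors
```

Because the sketch handles one matrix, a batch can be processed with jax.vmap(newton_schulz_sqrtm)(xs). Evaluating the error only once per inner block trades up to inner_iterations - 1 extra updates past convergence for fewer residual computations.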

Return type:

Tuple[Array, Array, Array]


Square root matrix of x (or x’s if batch), its inverse, errors along iterates.