The ott.math module holds low-level computational primitives that
appear in some of the more advanced optimal transport problems.
ott.math.fixed_point_loop implements a fixed-point iteration while loop
that can be automatically differentiated, and that may
be of more general interest to other JAX users.
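To make the idea concrete: jax.lax.while_loop does not support reverse-mode differentiation, so a differentiable fixed-point loop needs some other construction. The sketch below is only an illustration and does not use the ott.math.fixed_point_loop API. It assumes a fixed iteration budget and unrolls the updates with jax.lax.scan. The helper names fixed_point and sqrt_via_fixed_point are hypothetical.

```python
import jax
import jax.numpy as jnp


def fixed_point(step, x0, num_iters):
    """Apply `step` a fixed number of times (hypothetical helper)."""

    def body(x, _):
        return step(x), None

    # scan with a static length can be differentiated in reverse mode.
    x, _ = jax.lax.scan(body, x0, None, length=num_iters)
    return x


def sqrt_via_fixed_point(a):
    # Heron's update x <- (x + a / x) / 2 has sqrt(a) as its fixed point.
    step = lambda x: 0.5 * (x + a / x)
    return fixed_point(step, jnp.ones_like(a), num_iters=20)


a = jnp.float32(2.0)
value = sqrt_via_fixed_point(a)
# Differentiate through the unrolled iterations.
grad = jax.grad(sqrt_via_fixed_point)(a)
```

Here value should be close to sqrt(2) and grad close to 1 / (2 sqrt(2)). A loop with an early-exit convergence test needs more care than this fixed-budget sketch.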
ott.math.matrix_square_root contains an implementation of the
matrix square root using Newton-Schulz iterations. That implementation is
itself differentiable, using either implicit differentiation or unrolling of
the updates of these iterations.
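As a rough illustration of the unrolling option, the following sketch runs the coupled Newton-Schulz updates for a fixed number of steps and differentiates through them. It is not the ott.math.matrix_square_root implementation. The function name sqrtm_newton_schulz, the Frobenius-norm rescaling, and the iteration count are assumptions made for this example, which expects a positive-definite input.

```python
import jax
import jax.numpy as jnp


def sqrtm_newton_schulz(a, num_iters=12):
    """Approximate square root of a positive-definite matrix (sketch)."""
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    # Rescale so the eigenvalues of a / norm lie in (0, 1].
    norm = jnp.linalg.norm(a)  # Frobenius norm
    y0, z0 = a / norm, eye

    def body(carry, _):
        y, z = carry
        t = 0.5 * (3.0 * eye - z @ y)
        # y tends to sqrt(a / norm), z tends to its inverse.
        return (y @ t, t @ z), None

    (y, _), _ = jax.lax.scan(body, (y0, z0), None, length=num_iters)
    # Undo the rescaling.
    return y * jnp.sqrt(norm)


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
root = sqrtm_newton_schulz(a)
# Gradient of trace(sqrt(A)) obtained by unrolling the iterations.
grad_trace = jax.grad(lambda m: jnp.trace(sqrtm_newton_schulz(m)))(a)
```

Unrolling stores every intermediate iterate for the backward pass. Implicit differentiation avoids that memory cost by differentiating the equation that the converged result satisfies.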
ott.math.utils contains various low-level routines re-implemented for
use in JAX. Of particular interest are the custom jvp/vjp
re-implementations of logsumexp and norm, whose differentiability behavior
differs from that of the standard JAX implementations.
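One way such a difference can arise is sketched below with jax.custom_jvp. The sketch is a guess at the general pattern, not the code in ott.math.utils. The name safe_norm and the choice of a zero derivative at the origin are assumptions for this example. The forward pass calls jnp.linalg.norm, and the custom rule supplies a finite derivative where the Euclidean norm is not differentiable.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_norm(x):
    # Forward pass: plain Euclidean norm of a vector.
    return jnp.linalg.norm(x)


@safe_norm.defjvp
def safe_norm_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = safe_norm(x)
    # Assumed convention: use the zero subgradient at the origin
    # instead of dividing by zero.
    safe_y = jnp.where(y == 0.0, 1.0, y)
    direction = jnp.where(y == 0.0, jnp.zeros_like(x), x / safe_y)
    return y, jnp.vdot(direction, x_dot)


grad_at_zero = jax.grad(safe_norm)(jnp.zeros(3))
grad_elsewhere = jax.grad(safe_norm)(jnp.array([3.0, 4.0]))
```

With this rule the gradient at the origin is a vector of zeros, and away from the origin it is the usual x / ||x||.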
Implementation of a fixed point loop.
Matrix Square Root
Higham algorithm to compute the matrix square root of a positive-definite (p.d.) matrix.
Computes order ord norm of vector, using jnp.linalg in forward pass.