The ott.math module holds low-level computational primitives that appear in some of the more advanced optimal transport problems. ott.math.fixed_point_loop implements a fixed-point iteration while loop that can be automatically differentiated, and which may be of more general interest to other JAX users. ott.math.matrix_square_root implements the matrix square root using Newton-Schulz iterations. This implementation is itself differentiable, using either implicit differentiation or unrolling of the iteration updates. ott.math.utils contains various low-level routines reimplemented for use in JAX. Of particular interest are the custom JVP/VJP reimplementations of logsumexp and norm, whose differentiability behavior differs from that of the standard JAX implementations.

Fixed-point Iteration

fixed_point_loop.fixpoint_iter(cond_fn, ...)

Implementation of a fixed-point loop.
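A differentiable fixed-point loop needs special care because `jax.lax.while_loop` supports forward-mode but not reverse-mode differentiation. The minimal sketch below, which follows the fixed-point pattern from the JAX custom-derivatives tutorial rather than the library's own code, attaches a `jax.custom_vjp` that differentiates the fixed point x* = f(a, x*) implicitly instead of through the loop. The names `fixed_point`, `_iterate` and `newton_sqrt`, and the `tol`/`max_iter` defaults, are hypothetical and do not reproduce the signature of `fixpoint_iter`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _iterate(f, a, x_init, tol=1e-6, max_iter=100):
    """Hypothetical helper: iterate x <- f(a, x) until iterates stop changing."""

    def cond_fn(carry):
        i, x_prev, x = carry
        return jnp.logical_and(i < max_iter, jnp.max(jnp.abs(x - x_prev)) > tol)

    def body_fn(carry):
        i, _, x = carry
        return i + 1, x, f(a, x)

    _, _, x_star = jax.lax.while_loop(cond_fn, body_fn, (0, x_init, f(a, x_init)))
    return x_star


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_init):
    """Hypothetical helper: return x* with x* = f(a, x*), differentiable in a."""
    return _iterate(f, a, x_init)


def _fixed_point_fwd(f, a, x_init):
    x_star = _iterate(f, a, x_init)
    return x_star, (a, x_star)


def _fixed_point_bwd(f, res, x_star_bar):
    a, x_star = res
    # Implicit function theorem: the cotangent u solves u = x_star_bar + (df/dx)^T u,
    # itself a fixed-point problem; the parameter cotangent is then (df/da)^T u.
    _, vjp_a = jax.vjp(lambda a_: f(a_, x_star), a)
    _, vjp_x = jax.vjp(lambda x_: f(a, x_), x_star)
    u = _iterate(lambda _, v: x_star_bar + vjp_x(v)[0], None, x_star_bar)
    (a_bar,) = vjp_a(u)
    # The initial guess does not change x*, so its cotangent is zero.
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(_fixed_point_fwd, _fixed_point_bwd)


def newton_sqrt(a):
    """sqrt(a) as the fixed point of Newton's update x <- (x + a / x) / 2."""
    return fixed_point(lambda a_, x: 0.5 * (x + a_ / x), a, a)
```

With this rule, `jax.grad(newton_sqrt)(jnp.float32(2.0))` should be close to 1 / (2 * sqrt(2)), the derivative of the square root at 2, even though no iteration of the forward `while_loop` is differentiated.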

Matrix Square Root

matrix_square_root.sqrtm(x[, threshold, ...])

Higham algorithm to compute the matrix square root of a positive-definite matrix.
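For intuition about the Newton-Schulz approach, here is a minimal sketch of the coupled iteration described by Higham, not the library's `sqrtm`: with T_k = (3I - Z_k Y_k) / 2, it updates Y_{k+1} = Y_k T_k and Z_{k+1} = T_k Z_k, so that Y approaches the square root and Z the inverse square root. The input is first divided by its Frobenius norm so that, for a symmetric positive-definite matrix, all eigenvalues lie in (0, 1] and the iteration converges. A fixed number of `jax.lax.fori_loop` steps keeps the loop reverse-differentiable by unrolling. The name `newton_schulz_sqrtm` and the default of 30 iterations are hypothetical choices.

```python
import jax
import jax.numpy as jnp


def newton_schulz_sqrtm(a, num_iters=30):
    """Sketch: return (sqrt(a), inv_sqrt(a)) for a symmetric positive-definite a."""
    dim = a.shape[-1]
    eye = jnp.eye(dim, dtype=a.dtype)
    # Scale by the Frobenius norm so the eigenvalues of y0 lie in (0, 1].
    norm = jnp.linalg.norm(a)
    y0 = a / norm
    z0 = eye

    def body(_, carry):
        y, z = carry
        t = 0.5 * (3.0 * eye - z @ y)
        return y @ t, t @ z

    # Static trip count: fori_loop lowers to scan, so reverse mode unrolls it.
    y, z = jax.lax.fori_loop(0, num_iters, body, (y0, z0))
    # Undo the scaling: sqrt(c * B) = sqrt(c) * sqrt(B).
    return y * jnp.sqrt(norm), z / jnp.sqrt(norm)
```

Because the trip count is fixed, `jax.grad` differentiates through every update (the "unrolling" option); the implicit-differentiation option would instead solve a linear equation at the converged square root.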


utils.norm(x[, ord, axis, keepdims])

Computes the order-ord norm of a vector, using jnp.linalg in the forward pass.
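The differentiability difference is easiest to see at the origin: differentiating `jnp.linalg.norm` at an all-zero vector yields NaN, because the square root has an infinite derivative at 0. The sketch below is a hypothetical `safe_l2_norm` that keeps `jnp.linalg.norm` in the forward pass and defines a custom JVP returning a zero tangent at the origin. It only handles the Euclidean norm of the whole flattened array, unlike `utils.norm`, which also takes `ord`, `axis` and `keepdims`; whether `utils.norm` uses exactly this convention at zero is an assumption here.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_l2_norm(x):
    """Euclidean norm of x (flattened), computed with jnp.linalg in the forward pass."""
    return jnp.linalg.norm(jnp.ravel(x))


@safe_l2_norm.defjvp
def _safe_l2_norm_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    n = safe_l2_norm(x)
    # d||x|| = <x, dx> / ||x||. At x = 0 the numerator is already 0, so dividing
    # by 1 instead of 0 there gives a zero tangent rather than 0 / 0 = NaN.
    safe_n = jnp.where(n > 0.0, n, 1.0)
    return n, jnp.vdot(x, dx) / safe_n
```

With this rule, `jax.grad(safe_l2_norm)(jnp.zeros(3))` is a zero vector, whereas `jax.grad(jnp.linalg.norm)(jnp.zeros(3))` contains NaNs. Away from the origin both agree, e.g. the gradient at `[3.0, 4.0]` is `x / ||x||`.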


utils.softmin(x, gamma[, axis])

Soft-min operator.
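The entry leaves the formula implicit. A common definition of the soft-min with temperature gamma, and the one assumed in this sketch (not confirmed from the library source), is softmin_gamma(x) = -gamma * log(sum_i exp(-x_i / gamma)). Since the sum lies between exp(-min(x) / gamma) and n * exp(-min(x) / gamma), the result lies between min(x) - gamma * log(n) and min(x), so it tends to min(x) as gamma goes to 0. The name `softmin_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def softmin_sketch(x, gamma, axis=None):
    """Smooth lower bound on min(x) along axis; tends to min(x) as gamma -> 0."""
    # logsumexp is evaluated stably, so a small gamma does not overflow exp().
    return -gamma * logsumexp(-x / gamma, axis=axis)


x = jnp.array([1.0, 2.0, 3.0])
# The gradient is softmax(-x / gamma): nonnegative weights that sum to 1.
weights = jax.grad(softmin_sketch)(x, 0.5)
```

Unlike `jnp.min`, whose gradient puts all weight on a single entry, the soft-min spreads its gradient over all entries, which is what makes it useful as a smooth surrogate.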

utils.lambertw(z[, tol, max_iter])

Principal branch of the Lambert W function.
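The `tol` and `max_iter` parameters point to an iterative solver. As an illustration only, the sketch below assumes Halley's method on w * exp(w) - z = 0, restricted to real z >= 0, and attaches a custom JVP obtained by implicit differentiation, dW/dz = 1 / (exp(W) * (1 + W)), so the `while_loop` never has to be differentiated. The name `lambertw_sketch`, the `log1p(z)` starting point and the `TOL`/`MAX_ITER` values are assumptions, not the library's choices.

```python
import jax
import jax.numpy as jnp

TOL = 1e-6      # hypothetical relative tolerance
MAX_ITER = 50   # hypothetical iteration cap


@jax.custom_jvp
def lambertw_sketch(z):
    """Principal branch W(z) for real z >= 0, solving w * exp(w) = z (sketch)."""

    def halley_step(w):
        ew = jnp.exp(w)
        f = w * ew - z
        # Halley's update with f' = exp(w)(w + 1) and f'' = exp(w)(w + 2).
        return w - f / (ew * (w + 1.0) - (w + 2.0) * f / (2.0 * w + 2.0))

    def cond_fn(carry):
        i, w_prev, w = carry
        not_done = jnp.max(jnp.abs(w - w_prev)) > TOL * (1.0 + jnp.max(jnp.abs(w)))
        return jnp.logical_and(i < MAX_ITER, not_done)

    def body_fn(carry):
        i, _, w = carry
        return i + 1, w, halley_step(w)

    w0 = jnp.log1p(z)  # rough starting point, adequate for z >= 0
    _, _, w = jax.lax.while_loop(cond_fn, body_fn, (0, w0, halley_step(w0)))
    return w


@lambertw_sketch.defjvp
def _lambertw_jvp(primals, tangents):
    (z,), (dz,) = primals, tangents
    w = lambertw_sketch(z)
    # Differentiating w * exp(w) = z gives exp(w) * (1 + w) * dw = dz.
    return w, dz / (jnp.exp(w) * (1.0 + w))
```

For example, `lambertw_sketch(jnp.float32(1.0))` should be close to the omega constant 0.56714, and `lambertw_sketch(jnp.float32(jnp.e))` close to 1.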