ott.math

The ott.math module holds low-level computational primitives that appear in some of the more advanced optimal transport problems.

Function fixpoint_iter() implements a fixed-point iteration while loop that can be automatically differentiated, and which may be of more general interest to other JAX users. Function sqrtm() implements the matrix square root using Newton-Schulz iterations. That implementation is itself differentiable, using either implicit differentiation or unrolling of the iteration updates.

ott.math.utils contains various low-level routines re-implemented for use in JAX. Of particular interest are the custom JVP/VJP re-implementations of logsumexp and norm, whose differentiability behavior differs from that of the standard JAX implementations.

Fixed-point Iteration

fixed_point_loop.fixpoint_iter(cond_fn, ...)

Implementation of a fixed-point loop.
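A differentiable fixed-point loop needs special care because jax.lax.while_loop supports forward-mode but not reverse-mode differentiation. The sketch below shows one option. It attaches a custom VJP that differentiates the fixed point implicitly, so the backward pass solves the adjoint equation u = g + (∂f/∂x)ᵀu with the same iteration. The helper fixed_point, its (f, theta, x0) arguments and the tolerance are hypothetical illustrations. They are not the signature or the internals of fixpoint_iter.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _iterate(f, theta, x0, tol=1e-6, max_iter=1000):
    """Apply x <- f(x, theta) until two successive iterates differ by less than tol."""

    def cond(state):
        i, x, x_prev = state
        return (i < max_iter) & (jnp.max(jnp.abs(x - x_prev)) > tol)

    def body(state):
        i, x, _ = state
        return i + 1, f(x, theta), x

    _, x, _ = jax.lax.while_loop(cond, body, (jnp.array(0), f(x0, theta), x0))
    return x


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, theta, x0):
    """Hypothetical helper: returns x* such that x* = f(x*, theta)."""
    return _iterate(f, theta, x0)


def _fixed_point_fwd(f, theta, x0):
    x_star = _iterate(f, theta, x0)
    return x_star, (theta, x_star)


def _fixed_point_bwd(f, res, g):
    theta, x_star = res
    # Implicit function theorem: solve u = g + (df/dx)^T u at x*, itself a fixed point.
    _, vjp_x = jax.vjp(lambda x: f(x, theta), x_star)
    u = _iterate(lambda u, _: g + vjp_x(u)[0], None, g)
    # Push u through df/dtheta; x0 does not change x*, so its cotangent is zero.
    _, vjp_theta = jax.vjp(lambda th: f(x_star, th), theta)
    return vjp_theta(u)[0], jnp.zeros_like(x_star)


fixed_point.defvjp(_fixed_point_fwd, _fixed_point_bwd)


def cos_map(x, a):
    return a * jnp.cos(x)  # fixed point solves x = a * cos(x)


x_star = fixed_point(cos_map, jnp.float32(1.0), jnp.zeros(()))
dx_da = jax.grad(lambda a: fixed_point(cos_map, a, jnp.zeros(())))(jnp.float32(1.0))
```

For a = 1 the fixed point is about 0.739. The implicit gradient should match the closed form cos(x*) / (1 + a sin(x*)), which is about 0.4416. The backward iteration converges whenever f is a contraction in x.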

Matrix Square Root

matrix_square_root.sqrtm
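The coupled Newton-Schulz iteration Y_{k+1} = ½ Y_k (3I − Z_k Y_k), Z_{k+1} = ½ (3I − Z_k Y_k) Z_k drives Y toward A^{1/2} and Z toward A^{-1/2}. It can be sketched as follows, assuming a symmetric positive definite input that is first scaled by its Frobenius norm so the iteration converges. The function sqrtm_newton_schulz and its num_iters parameter are hypothetical. Because the loop runs a fixed number of steps with jax.lax.fori_loop, reverse-mode differentiation works by unrolling, one of the two strategies mentioned for sqrtm.

```python
import jax
import jax.numpy as jnp


def sqrtm_newton_schulz(a, num_iters=30):
    """Approximate (a^{1/2}, a^{-1/2}) for a symmetric positive definite matrix a."""
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    # Scale by the Frobenius norm so the eigenvalues lie in (0, 1] and the iteration converges.
    scale = jnp.linalg.norm(a)
    y0, z0 = a / scale, eye

    def step(_, yz):
        y, z = yz
        t = 0.5 * (3.0 * eye - z @ y)
        return y @ t, t @ z  # y -> (a / scale)^{1/2}, z -> (a / scale)^{-1/2}

    y, z = jax.lax.fori_loop(0, num_iters, step, (y0, z0))
    # Undo the scaling: a^{1/2} = sqrt(scale) * y and a^{-1/2} = z / sqrt(scale).
    return jnp.sqrt(scale) * y, z / jnp.sqrt(scale)


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
s, inv_s = sqrtm_newton_schulz(a)
# Unrolled gradient of trace(a^{1/2}); analytically it equals 0.5 * a^{-1/2}.
grad_trace = jax.grad(lambda m: jnp.trace(sqrtm_newton_schulz(m)[0]))(a)
```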

Miscellaneous

lbfgs(fun, x_init[, max_iter, tol])

Runs optax's L-BFGS optimization on a function.

legendre(fun, **kwargs)

Legendre (Fenchel) transform of a function.

velocity_from_brenier_potential(potential, ...)

Get the optimal time-dependent velocity field from a Brenier potential.

utils.norm

utils.logsumexp

utils.softmin

utils.lambertw
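The legendre entry computes the Legendre (Fenchel) transform f*(y) = sup_x ⟨x, y⟩ − f(x). For a convex, differentiable f this can be evaluated numerically by minimizing f(x) − ⟨x, y⟩. The sketch below assumes jax.scipy.optimize.minimize with BFGS as the inner solver. The helper legendre_at is hypothetical and does not reproduce how legendre itself is implemented.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def legendre_at(fun, y, x0=None):
    """Evaluate f*(y) = sup_x <x, y> - f(x) for a convex, differentiable f."""
    x0 = jnp.zeros_like(y) if x0 is None else x0
    # sup_x <x, y> - f(x) = -inf_x [f(x) - <x, y>]
    res = minimize(lambda x: fun(x) - jnp.dot(x, y), x0, method="BFGS")
    # The maximizer res.x is also the gradient of f* at y (Danskin's theorem).
    return -res.fun, res.x


# For f(x) = |x|^2 / 2 the transform is f*(y) = |y|^2 / 2, attained at x = y.
value, maximizer = legendre_at(lambda x: 0.5 * jnp.sum(x ** 2), jnp.array([1.0, 2.0]))
```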
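The velocity_from_brenier_potential entry can be illustrated with McCann's displacement interpolation, one standard way to get a time-dependent velocity field from a Brenier potential φ. Under this interpolation a point x moves along x_t = (1 − t) x + t ∇φ(x) with constant velocity ∇φ(x) − x. To evaluate the field at a position y, one first recovers x from y = ∇ψ_t(x), where ψ_t(x) = (1 − t)‖x‖²/2 + t φ(x). This amounts to a Legendre-type inner minimization. The sketch assumes a convex φ and t < 1. The helper velocity_at is hypothetical, and the library's exact convention may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def velocity_at(potential, t, y):
    """Velocity at time t and position y for the interpolation x -> grad(potential)(x)."""

    # psi_t(x) = (1 - t) |x|^2 / 2 + t * potential(x), so that y = grad(psi_t)(x).
    def psi_t(x):
        return 0.5 * (1.0 - t) * jnp.sum(x ** 2) + t * potential(x)

    # Recover the starting point x by minimizing psi_t(x) - <x, y> (strictly convex for t < 1).
    x = minimize(lambda x: psi_t(x) - jnp.dot(x, y), y, method="BFGS").x
    # Particles move in straight lines at constant speed: v = grad(potential)(x) - x.
    return jax.grad(potential)(x) - x


# Quadratic potential with grad = diag(2, 3) x; at t = 0.5 the exact answer is [2/3, 1].
v = velocity_at(lambda x: x[0] ** 2 + 1.5 * x[1] ** 2, 0.5, jnp.array([1.0, 1.0]))
```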
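For utils.norm: the standard jnp.linalg.norm returns a NaN gradient at the origin, because the derivative of the square root blows up there. Below is a minimal sketch of a custom-JVP norm that instead returns a zero gradient at the origin. The name safe_norm is hypothetical. Choosing zero as the gradient at 0 is an assumption made for illustration, not a documented property of utils.norm.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_norm(x):
    """Euclidean norm of a vector (hypothetical stand-in for utils.norm)."""
    return jnp.sqrt(jnp.sum(x ** 2))


@safe_norm.defjvp
def _safe_norm_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    n = safe_norm(x)
    # Away from 0 the derivative is x / |x|; at 0 this sketch picks the zero vector.
    is_zero = n == 0.0
    safe_n = jnp.where(is_zero, 1.0, n)  # avoids 0 / 0 in the division below
    direction = jnp.where(is_zero, 0.0, x / safe_n)
    return n, jnp.sum(direction * dx)
```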
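For utils.logsumexp with weights, lse = log Σᵢ bᵢ exp(aᵢ) has closed-form derivatives: ∂lse/∂aᵢ = bᵢ exp(aᵢ − lse) and ∂lse/∂bᵢ = exp(aᵢ − lse). The sketch below hard-codes these through jax.custom_jvp, so the gradient with respect to b stays finite even when some weights are zero. The name weighted_logsumexp is hypothetical. It handles 1-D inputs only and is not the library's implementation.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def weighted_logsumexp(a, b):
    """log(sum(b * exp(a))) for 1-D arrays a and b, with max-shift stabilization."""
    m = jnp.max(a)
    m = jnp.where(jnp.isfinite(m), m, 0.0)  # guard against an all -inf input
    return m + jnp.log(jnp.sum(b * jnp.exp(a - m)))


@weighted_logsumexp.defjvp
def _weighted_logsumexp_jvp(primals, tangents):
    a, b = primals
    da, db = tangents
    lse = weighted_logsumexp(a, b)
    e = jnp.exp(a - lse)  # e_i = exp(a_i) / sum_j b_j exp(a_j)
    # d lse = sum_i b_i e_i da_i + sum_i e_i db_i
    return lse, jnp.sum(b * e * da) + jnp.sum(e * db)


a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0, 0.0, 2.0])
```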
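For utils.softmin, a common definition of a soft minimum is softmin_γ(x) = −γ log Σᵢ exp(−xᵢ/γ). It never exceeds min(x), approaches it as γ → 0, and has the softmax of −x/γ as its gradient. The sketch below assumes this definition. The (x, gamma) signature is hypothetical, since the listing does not show the arguments of utils.softmin.

```python
import jax
import jax.numpy as jnp


def softmin(x, gamma):
    """Soft minimum -gamma * log(sum(exp(-x / gamma))) (hypothetical signature)."""
    return -gamma * jax.nn.logsumexp(-x / gamma)


x = jnp.array([1.0, 2.0, 3.0])
# The gradient is a probability vector: softmax(-x / gamma).
g = jax.grad(softmin)(x, 1.0)
```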
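For utils.lambertw: the Lambert W function solves w·exp(w) = z. For z ≥ 0, a few Halley steps on f(w) = w exp(w) − z from the starting point log(1 + z) converge quickly. The sketch below unrolls a fixed number of steps with jax.lax.fori_loop, so it can be differentiated by unrolling. The helper lambertw_halley and its num_iters default are hypothetical. It covers only the principal branch on z ≥ 0 and is not necessarily how utils.lambertw is implemented.

```python
import jax
import jax.numpy as jnp


def lambertw_halley(z, num_iters=20):
    """Principal branch W(z) for z >= 0 via unrolled Halley steps (hypothetical helper)."""
    z = jnp.asarray(z, dtype=jnp.float32)  # float32 keeps the loop carry type fixed
    w0 = jnp.log1p(z)  # rough starting point, exact at z = 0

    def halley_step(_, w):
        ew = jnp.exp(w)
        f = w * ew - z  # residual of w * exp(w) = z
        # Halley update with f' = exp(w)(w + 1) and f'' / (2 f') = (w + 2) / (2 (w + 1))
        return w - f / (ew * (w + 1.0) - (w + 2.0) * f / (2.0 * w + 2.0))

    return jax.lax.fori_loop(0, num_iters, halley_step, w0)


z = jnp.array([0.0, 1.0, jnp.e, 10.0])
w = lambertw_halley(z)
# Analytically W'(z) = W / (z (1 + W)), which equals 1 / (2e) at z = e.
dw = jax.grad(lambertw_halley)(jnp.float32(jnp.e))
```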