ott.math
The ott.math module holds low-level computational primitives that
appear in some of the more advanced optimal transport problems.
The function fixpoint_iter() implements a fixed-point iteration while
loop that can be automatically differentiated, and which may be of more
general interest to other JAX users.
The function sqrtm() implements the matrix square root using
Newton-Schulz iterations. This implementation is itself differentiable,
either through implicit differentiation or by unrolling the updates of
these iterations. ott.math.utils contains various low-level routines
re-implemented for use in JAX. Of particular interest are the custom
JVP/VJP re-implementations of logsumexp and norm, whose behavior in
terms of differentiability differs from that of the standard JAX
implementations.
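To illustrate the kind of difference meant here: differentiating the Euclidean norm at the origin with jnp.linalg.norm yields NaN, because the derivative of the square root is infinite at zero and gets multiplied by a zero. The sketch below uses jax.custom_jvp to return a zero gradient there instead. The name safe_norm and the choice of zero at the origin are assumptions made for illustration; this is not necessarily how ott.math.utils.norm is written or what it returns.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_norm(x):
    """Euclidean norm with a zero gradient at the origin (hypothetical helper)."""
    return jnp.sqrt(jnp.sum(x ** 2))


@safe_norm.defjvp
def _safe_norm_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = safe_norm(x)
    nonzero = y > 0
    # Away from 0 the derivative is x / ||x||. At 0 we pick 0 (a valid
    # subgradient) and use a dummy denominator so 0 / 0 is never evaluated.
    safe_y = jnp.where(nonzero, y, 1.0)
    direction = jnp.where(nonzero, x / safe_y, 0.0)
    return y, jnp.sum(direction * x_dot)


zero = jnp.zeros(3)
grad_std = jax.grad(jnp.linalg.norm)(zero)   # contains NaN: inf * 0 inside sqrt's rule
grad_safe = jax.grad(safe_norm)(zero)        # all zeros
grad_regular = jax.grad(safe_norm)(jnp.array([3.0, 4.0]))  # x / ||x||
```

Away from the origin the custom rule agrees with the usual derivative x / ||x||, so only the behavior at the single non-differentiable point changes.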
Fixed-point Iteration
fixpoint_iter() | Implementation of a fixed-point loop.
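jax.lax.while_loop can be differentiated in forward mode but not in reverse mode, which is why a differentiable fixed-point loop needs special handling. The sketch below shows one common way to get reverse-mode gradients: implicit differentiation with jax.custom_vjp, following the pattern from JAX's custom-derivatives tutorial. The names fixed_point and _iterate, the tolerance, and the iteration cap are hypothetical; this is not the signature of fixpoint_iter() and not necessarily how it computes gradients. It assumes f is a contraction in x, so the fixed point is unique.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _iterate(step, x_init, tol=1e-6, max_iter=1000):
    """Apply `step` until successive iterates differ by at most `tol`."""

    def cond_fn(carry):
        x_prev, x, it = carry
        return (jnp.max(jnp.abs(x - x_prev)) > tol) & (it < max_iter)

    def body_fn(carry):
        _, x, it = carry
        return x, step(x), it + 1

    init = (x_init, step(x_init), jnp.array(0))
    _, x_star, _ = jax.lax.while_loop(cond_fn, body_fn, init)
    return x_star


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, theta, x_init):
    """Solve x = f(theta, x) by fixed-point iteration (hypothetical helper)."""
    return _iterate(lambda x: f(theta, x), x_init)


def _fixed_point_fwd(f, theta, x_init):
    x_star = fixed_point(f, theta, x_init)
    # Keep only theta and the solution, not the iterates, for the backward pass.
    return x_star, (theta, x_star)


def _fixed_point_bwd(f, res, x_star_bar):
    theta, x_star = res
    # Implicit function theorem: solve u = x_star_bar + (df/dx)^T u, which is
    # itself a fixed-point problem, then pull u back through df/dtheta.
    _, vjp_x = jax.vjp(lambda x: f(theta, x), x_star)
    u = _iterate(lambda v: x_star_bar + vjp_x(v)[0], x_star_bar)
    _, vjp_theta = jax.vjp(lambda t: f(t, x_star), theta)
    (theta_bar,) = vjp_theta(u)
    # For a contraction the solution does not depend on x_init.
    return theta_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(_fixed_point_fwd, _fixed_point_bwd)

# Example: x = 0.5 * x + theta has the fixed point x* = 2 * theta.
f = lambda theta, x: 0.5 * x + theta
x_star = fixed_point(f, jnp.array(1.0), jnp.array(0.0))  # close to 2.0
grad = jax.grad(lambda t: fixed_point(f, t, jnp.array(0.0)))(jnp.array(1.0))  # close to 2.0
```

Because the backward pass stores only theta and x_star, its memory cost does not grow with the number of forward iterations.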
Matrix Square Root
sqrtm() | Matrix square root computed with Newton-Schulz iterations.
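A minimal sketch of the coupled Newton-Schulz iteration for a symmetric positive definite matrix, in plain JAX. The name newton_schulz_sqrtm, the fixed iteration count, and the Frobenius-norm scaling are illustrative assumptions, not the signature or defaults of sqrtm(). Since the loop runs a fixed number of steps with jax.lax.fori_loop, reverse-mode differentiation simply unrolls the updates, which corresponds to the second of the two differentiation strategies described in the module overview.

```python
import jax
import jax.numpy as jnp


def newton_schulz_sqrtm(a, num_iters=30):
    """Return (a^{1/2}, a^{-1/2}) for a symmetric positive definite matrix `a`.

    Hypothetical helper: coupled Newton-Schulz iteration with a fixed step count.
    """
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    # Scale by the Frobenius norm so the eigenvalues of y lie in (0, 1];
    # this keeps the spectral norm of I - y below 1, the usual convergence condition.
    scale = jnp.linalg.norm(a)
    y, z = a / scale, eye

    def body(_, carry):
        y, z = carry
        t = 0.5 * (3.0 * eye - z @ y)
        # y converges to (a / scale)^{1/2} and z to (a / scale)^{-1/2}.
        return y @ t, t @ z

    y, z = jax.lax.fori_loop(0, num_iters, body, (y, z))
    return jnp.sqrt(scale) * y, z / jnp.sqrt(scale)


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
sqrt_a, inv_sqrt_a = newton_schulz_sqrtm(a)
# Reverse-mode gradient obtained by unrolling the fixed number of iterations.
grad_trace = jax.grad(lambda m: jnp.trace(newton_schulz_sqrtm(m)[0]))(a)
```

Using a fixed iteration count instead of a tolerance-based while loop is what makes plain reverse-mode differentiation possible here; the trade-off is that backward-pass memory grows with num_iters.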
Miscellaneous
Runs optax's L-BFGS optimization on a function.
Legendre (Fenchel) transform of a function.
Gets the optimal time-dependent velocity field from a Brenier potential.
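The Legendre transform and the velocity field listed under Miscellaneous refer to standard constructions. For reference, their textbook definitions are written out here; the symbols f, \varphi, T_t and v_t are introduced for illustration, and this describes the mathematics, not how ott.math computes these quantities.

```latex
% Legendre (Fenchel) transform of f : \mathbb{R}^d \to \mathbb{R} \cup \{+\infty\}
f^*(y) = \sup_{x \in \mathbb{R}^d} \; \langle x, y \rangle - f(x)

% Worked example: for f(x) = \tfrac12 \|x\|^2 the supremum is attained at x = y, so
f^*(y) = \langle y, y \rangle - \tfrac12 \|y\|^2 = \tfrac12 \|y\|^2

% Brenier map T = \nabla\varphi for a convex potential \varphi.
% Displacement interpolation between the identity and T:
T_t(x) = (1 - t)\,x + t\,\nabla\varphi(x), \qquad t \in [0, 1]

% Each point moves along a straight line at constant velocity \nabla\varphi(x) - x.
% For t < 1, T_t is invertible (it is the gradient of the strongly convex
% (1 - t)\tfrac12\|x\|^2 + t\varphi(x)), so the velocity at position y is
v_t(y) = \nabla\varphi\big(T_t^{-1}(y)\big) - T_t^{-1}(y)
```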