ott.math#
The ott.math module holds low-level computational primitives that appear in some of the more advanced optimal transport problems.
ott.math.fixed_point_loop implements a fixed-point iteration while loop that can be automatically differentiated, and which may be of more general interest to other JAX users.
ott.math.matrix_square_root contains an implementation of the matrix square root using Newton-Schulz iterations. That implementation is itself differentiable, using either implicit differentiation or unrolling of the iteration updates.
ott.math.utils contains various low-level routines re-implemented for use in JAX. Of particular interest are the custom jvp/vjp re-implementations of logsumexp and norm, whose behavior differs, in terms of differentiability, from that of the standard JAX implementations.
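To illustrate how a custom jvp can change differentiability, here is a minimal sketch of a Euclidean norm whose gradient is defined at the origin. It is not the code in ott.math.utils: the name safe_norm and the choice of a zero gradient at zero are assumptions made for this example. The standard jnp.linalg.norm typically yields a non-finite gradient at the zero vector, which is the kind of behavior such a rule can replace.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_norm(x):
    """Euclidean norm of a vector; the forward pass uses jnp.linalg.norm."""
    return jnp.linalg.norm(x)


@safe_norm.defjvp
def _safe_norm_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    n = jnp.linalg.norm(x)
    # Avoid dividing by zero: where the norm vanishes, use a zero gradient
    # (an assumed convention for this sketch, not necessarily OTT's).
    safe_n = jnp.where(n > 0, n, 1.0)
    direction = jnp.where(n > 0, x / safe_n, jnp.zeros_like(x))
    # The tangent output must be linear in dx so that jax.grad can transpose it.
    return n, jnp.vdot(direction, dx)


grad_at_zero = jax.grad(safe_norm)(jnp.zeros(3))
grad_at_point = jax.grad(safe_norm)(jnp.array([3.0, 4.0]))
```

Away from the origin the rule reproduces the usual gradient x / ||x||, so only the behavior at the singular point changes.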
Fixed-point Iteration#
Implementation of a fixed point loop.
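As background, jax.lax.while_loop does not support reverse-mode differentiation, which is why a differentiable fixed-point loop needs dedicated machinery. The sketch below does not reproduce the ott.math.fixed_point_loop API. It shows the simplest workaround, assuming a fixed iteration budget: the loop is written with jax.lax.scan, so gradients flow through the unrolled updates. The helper names fixed_point, babylonian_step and sqrt_via_fixed_point are hypothetical.

```python
import jax
import jax.numpy as jnp


def fixed_point(step_fn, x0, params, num_iters=30):
    """Apply x <- step_fn(x, params) a fixed number of times."""

    def body(x, _):
        return step_fn(x, params), None

    # scan has a static length, so reverse-mode autodiff is supported.
    x_final, _ = jax.lax.scan(body, x0, None, length=num_iters)
    return x_final


def babylonian_step(x, a):
    # The fixed point of this map is sqrt(a), for a > 0.
    return 0.5 * (x + a / x)


def sqrt_via_fixed_point(a):
    return fixed_point(babylonian_step, jnp.ones_like(a), a)


a = jnp.array(4.0)
value = sqrt_via_fixed_point(a)
grad = jax.grad(sqrt_via_fixed_point)(a)  # compare with 1 / (2 * sqrt(a))
```

Unrolling keeps every iterate in memory during the backward pass, and the iteration count cannot depend on a convergence test. Both are limitations that a dedicated fixed-point loop is meant to address.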
Matrix Square Root#
Higham algorithm to compute the matrix square root of a positive-definite (p.d.) matrix.
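For readers unfamiliar with Newton-Schulz iterations, the following is a minimal sketch of the standard coupled form, assuming a symmetric positive-definite input and a fixed number of iterations. The function name sqrtm_newton_schulz and its signature are hypothetical, and the library routine may differ in its normalization, stopping rule and outputs.

```python
import jax
import jax.numpy as jnp


def sqrtm_newton_schulz(a, num_iters=20):
    """Square root of a symmetric positive-definite matrix `a`."""
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    # Rescale by the Frobenius norm so the eigenvalues lie in (0, 1],
    # which places the iteration inside its region of convergence.
    norm = jnp.linalg.norm(a)
    y0 = a / norm
    z0 = eye

    def body(carry, _):
        y, z = carry
        t = 0.5 * (3.0 * eye - z @ y)
        # y converges to sqrt(a / norm), z to its inverse.
        return (y @ t, t @ z), None

    (y, _), _ = jax.lax.scan(body, (y0, z0), None, length=num_iters)
    # Undo the rescaling.
    return y * jnp.sqrt(norm)


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
root = sqrtm_newton_schulz(a)
residual = jnp.max(jnp.abs(root @ root - a))
```

Because the loop is a fixed-length scan, this sketch is differentiable by unrolling. Implicit differentiation would avoid storing the intermediate iterates, at the cost of solving a linear system in the backward pass.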
Miscellaneous#
Computes the norm of order ord of a vector, using jnp.linalg in the forward pass.
Soft-min operator.
Principal branch of the Lambert W function.
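The last two entries can be made concrete with a short sketch. It assumes the common definition of the soft-min, -gamma * logsumexp(-x / gamma), and evaluates the principal branch of the Lambert W function by Newton's method on w * exp(w) = z, restricted to z >= 0. The names softmin and lambertw_principal, their signatures and the fixed iteration count are assumptions for illustration, not the library's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def softmin(x, gamma, axis=None):
    """Smooth approximation of min(x); tends to the minimum as gamma -> 0."""
    return -gamma * logsumexp(-x / gamma, axis=axis)


def lambertw_principal(z, num_iters=20):
    """Solve w * exp(w) = z for w >= 0, assuming z >= 0."""

    def body(w, _):
        ew = jnp.exp(w)
        # Newton update for f(w) = w * exp(w) - z.
        return w - (w * ew - z) / (ew * (w + 1.0)), None

    # log(1 + z) lies at or above the solution for z >= 0, so the Newton
    # iterates decrease monotonically towards it.
    w0 = jnp.log1p(z)
    w, _ = jax.lax.scan(body, w0, None, length=num_iters)
    return w


x = jnp.array([1.0, 2.0, 3.0])
sharp = softmin(x, 0.01)  # close to min(x)
smooth = softmin(x, 1.0)  # lower bound on min(x)
w_e = lambertw_principal(jnp.array(jnp.e))  # W(e) = 1
```

Both functions are built from differentiable JAX primitives, so jax.grad applies to them directly in this simplified form.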