ott.solvers.nn

This module implements the input convex neural network (ICNN) [Amos et al., 2017], which can be used to solve OT problems between point clouds. Simpler alternatives, such as an MLP, are also implemented; these are essentially borrowed from the flax library.
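As background for why input convexity is useful here (the notation below is ours, not the module's): Brenier's theorem ties quadratic-cost OT maps to convex potentials, and the ICNN layers of Amos et al. (2017) are built so that the network output is convex in its input.

```latex
% Brenier: for the cost 0.5 * |x - y|^2 (source measure with a density),
% the optimal transport map is the gradient of a convex potential f
T^{\star}(x) = \nabla f(x), \qquad f \ \text{convex}.

% ICNN (Amos et al., 2017): hidden units z_k, output f(x) = z_K
z_{1} = \sigma_{0}\big(W_{0}^{(x)} x + b_{0}\big), \qquad
z_{k+1} = \sigma_{k}\big(W_{k}^{(z)} z_{k} + W_{k}^{(x)} x + b_{k}\big), \quad k = 1, \dots, K-1.

% f is convex in x whenever every W_k^{(z)} is entrywise non-negative
% and every \sigma_k is convex and non-decreasing.
```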

Neural Dual

neuraldual.W2NeuralDual(dim_data[, ...])

Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.
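One common way to write the objective behind this dual, for the squared Euclidean cost, is 0.5 W_2^2(α, β) = 0.5 E_α|x|^2 + 0.5 E_β|y|^2 − inf over convex f of (E_α f(x) + E_β f*(y)), where f* is the convex conjugate of f. The sketch below evaluates that expression from samples for a user-supplied potential and its exact conjugate. The function name and signature are hypothetical and are not the W2NeuralDual API.

```python
import jax
import jax.numpy as jnp


def w2_dual_objective(f, f_star, x, y):
    """Sample estimate of 0.5*E|x|^2 + 0.5*E|y|^2 - (E f(x) + E f*(y)).

    x: (n, d) source samples, y: (m, d) target samples.
    f and f_star map one vector of shape (d,) to a scalar, and f_star must be
    the exact convex conjugate of f. For convex f the result lower-bounds
    0.5 * W_2^2 between the two empirical measures; it is tight at the
    Brenier potential.
    """
    const = 0.5 * jnp.mean(jnp.sum(x**2, axis=-1)) + 0.5 * jnp.mean(jnp.sum(y**2, axis=-1))
    return const - (jnp.mean(jax.vmap(f)(x)) + jnp.mean(jax.vmap(f_star)(y)))


# Toy check: the target is the source shifted by m, so the optimal map is
# T(x) = x + m = grad f(x) with f(x) = 0.5|x|^2 + <m, x>, f*(y) = 0.5|y - m|^2.
m = jnp.array([1.0, -2.0])
x = jax.random.normal(jax.random.PRNGKey(0), (256, 2))
y = x + m
f = lambda v: 0.5 * jnp.dot(v, v) + jnp.dot(m, v)
f_star = lambda w: 0.5 * jnp.dot(w - m, w - m)
estimate = w2_dual_objective(f, f_star, x, y)  # close to 0.5 * |m|^2 = 2.5
```

Fitting a neural dual then amounts to minimising E f(x) + E f*(y) over a parametric family of convex potentials, such as an ICNN, with f* in practice approximated numerically or by a second network rather than known in closed form.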

Models

models.ModelBase([parent, name])

Base class for the neural solver models.

models.ICNN(dim_data, dim_hidden[, ...])

Input convex neural network (ICNN) architecture with initialization.
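A minimal sketch of an ICNN forward pass in plain JAX, assuming softplus activations and a softplus reparametrisation to keep the hidden-to-hidden weights non-negative. The helper names and parameter layout are hypothetical, and the sketch uses plain random initialization rather than the module's own initialization scheme.

```python
import jax
import jax.numpy as jnp


def init_icnn(key, dim_data, dim_hidden):
    """Random parameters for a hypothetical ICNN with a scalar output."""
    sizes = list(dim_hidden) + [1]
    keys = jax.random.split(key, 2 * len(sizes))
    params, prev = [], None
    for i, width in enumerate(sizes):
        layer = {
            # Skip connection from the input x: unconstrained.
            "w_x": jax.random.normal(keys[2 * i], (width, dim_data)) / jnp.sqrt(dim_data),
            "b": jnp.zeros(width),
        }
        if prev is not None:
            # Raw z-to-z weights; made non-negative with softplus in the forward pass.
            layer["w_z"] = jax.random.normal(keys[2 * i + 1], (width, prev)) / jnp.sqrt(prev)
        params.append(layer)
        prev = width
    return params


def icnn_potential(params, x):
    """Scalar potential f(x), convex in x for any parameter values."""
    z = None
    for i, layer in enumerate(params):
        pre = layer["w_x"] @ x + layer["b"]
        if z is not None:
            # Non-negative weights on convex z keep pre convex in x.
            pre = pre + jax.nn.softplus(layer["w_z"]) @ z
        # Softplus is convex and non-decreasing; the output layer is left linear.
        z = pre if i == len(params) - 1 else jax.nn.softplus(pre)
    return z[0]


params = init_icnn(jax.random.PRNGKey(0), dim_data=2, dim_hidden=[16, 16])
# Candidate transport map: the gradient of the potential with respect to x.
transport_map = jax.grad(icnn_potential, argnums=1)
```

Because convexity is enforced by construction, it holds for every parameter value, so gradient-based training never leaves the set of convex potentials.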

models.MLP(dim_hidden[, is_potential, ...])

A non-convex MLP.
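The is_potential argument in the signature suggests two ways of using an MLP: as a scalar potential whose gradient is the map, or as a direct vector-valued map. The sketch below illustrates both readings in plain JAX; that interpretation and all names here are our assumptions, not confirmed by this page. Since the weights are unconstrained, neither mode guarantees convexity.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Random (weight, bias) pairs for a hypothetical MLP with the given layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)  # unconstrained weights: no convexity guarantee
    w, b = params[-1]
    return w @ x + b


dim_data = 3
# Potential mode: scalar output; the map is its gradient.
pot_params = init_mlp(jax.random.PRNGKey(0), [dim_data, 32, 32, 1])
potential_map = jax.grad(lambda x: mlp(pot_params, x)[0])
# Direct mode: an output of the input's dimension is used as the map itself.
map_params = init_mlp(jax.random.PRNGKey(1), [dim_data, 32, 32, dim_data])
direct_map = lambda x: mlp(map_params, x)
```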

Conjugate Solvers

conjugate_solvers.ConjugateResults(val, ...)

Holds the results of numerically conjugating a function.

conjugate_solvers.FenchelConjugateSolver()

Abstract conjugate solver class.
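For reference, the quantity such a solver approximates is the Fenchel (convex) conjugate. When f is convex, the inner problem is a concave maximisation in x, which is why a generic smooth optimiser can be used, and its maximiser also gives the gradient of the conjugate:

```latex
f^{*}(y) = \sup_{x \in \mathbb{R}^{d}} \; \langle x, y \rangle - f(x)

% If f is convex and differentiable and the supremum is attained at a unique x^{*}(y):
\nabla f\big(x^{*}(y)\big) = y, \qquad
f^{*}(y) = \langle x^{*}(y), y \rangle - f\big(x^{*}(y)\big), \qquad
\nabla f^{*}(y) = x^{*}(y).
```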

conjugate_solvers.FenchelConjugateLBFGS([...])

Solve for the conjugate using jaxopt.LBFGS.
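A minimal sketch of numerically conjugating a function. This page names jaxopt.LBFGS; to keep the example dependent on JAX alone, it swaps in BFGS from jax.scipy.optimize.minimize. ConjugateEstimate and fenchel_conjugate are hypothetical names; only the val field mirrors the ConjugateResults(val, ...) signature above, and the other fields are our assumption.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.optimize import minimize


class ConjugateEstimate(NamedTuple):
    """Hypothetical result container for a numerical conjugate."""
    val: jnp.ndarray       # estimate of f*(y)
    grad: jnp.ndarray      # maximiser x*(y), which equals grad f*(y)
    num_iter: jnp.ndarray  # optimiser iterations


def fenchel_conjugate(f, y, x_init):
    # f*(y) = sup_x <x, y> - f(x), so minimise the negated objective.
    objective = lambda x: f(x) - jnp.dot(x, y)
    res = minimize(objective, x_init, method="BFGS")
    return ConjugateEstimate(val=-res.fun, grad=res.x, num_iter=res.nit)


# Quadratic test case with closed form f*(y) = 0.5 y^T A^{-1} y, maximiser A^{-1} y.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
f = lambda x: 0.5 * x @ A @ x
y = jnp.array([1.0, -1.0])
out = fenchel_conjugate(f, y, x_init=jnp.zeros(2))
```

Warm-starting x_init from a previous solution or from an amortised prediction is a natural way to cut iterations when conjugating repeatedly during training.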