This module implements the input-convex neural network (ICNN) [Amos et al., 2017], which can be used to solve optimal transport (OT) problems between point clouds. Simpler alternatives, such as an MLP, are also implemented, largely borrowed from the Flax library.

Neural Dual

neuraldual.W2NeuralDual(dim_data[, ...])

Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.
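The dual that this solver targets can be illustrated without any network. For the squared Euclidean cost, 0.5*W2^2(mu, nu) = 0.5*E|x|^2 + 0.5*E|y|^2 - min over convex f of (E_mu[f] + E_nu[f^*]), so any convex potential f gives a lower bound that becomes tight at the optimal potential. The sketch below assumes a toy setting (not the solver's API) in which the target cloud is the image of the source under x -> A x with A symmetric positive definite. In that case f(x) = 0.5 x^T A x is optimal and its conjugate is known in closed form. The helper `w2_semidual_estimate` is a hypothetical name.

```python
import numpy as np

def w2_semidual_estimate(x, y, f, f_conj):
    """Hypothetical helper: estimate 0.5*W2^2 from samples for a given convex f.

    Uses 0.5*W2^2 = 0.5*E|x|^2 + 0.5*E|y|^2 - min_f (E_mu[f] + E_nu[f^*]);
    any convex f therefore gives a lower bound, tight at the optimal f.
    """
    moments = 0.5 * np.mean(np.sum(x**2, axis=1)) + 0.5 * np.mean(np.sum(y**2, axis=1))
    return moments - (np.mean(f(x)) + np.mean(f_conj(y)))

rng = np.random.default_rng(0)
dim, n = 3, 500
M = rng.normal(size=(dim, dim))
A = M @ M.T + np.eye(dim)   # symmetric positive definite
A_inv = np.linalg.inv(A)

x = rng.normal(size=(n, dim))   # source point cloud
y = x @ A                       # target: row i equals A @ x_i since A is symmetric

# f(x) = 0.5 x^T A x has gradient A x and conjugate f^*(y) = 0.5 y^T A^{-1} y.
f = lambda p: 0.5 * np.einsum("ni,ij,nj->n", p, A, p)
f_conj = lambda q: 0.5 * np.einsum("ni,ij,nj->n", q, A_inv, q)
half_sq = lambda p: 0.5 * np.sum(p**2, axis=1)   # self-conjugate; gradient is the identity map

primal = 0.5 * np.mean(np.sum((x - y) ** 2, axis=1))   # cost of the map x -> A x
dual_opt = w2_semidual_estimate(x, y, f, f_conj)        # optimal potential
dual_id = w2_semidual_estimate(x, y, half_sq, half_sq)  # suboptimal potential
```

Here `dual_opt` matches `primal` up to rounding, while the identity potential yields a value of about 0, which illustrates the lower-bound property that a learned potential is trained to tighten.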


models.ModelBase([parent, name])

Base class for the neural solver models.
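A model in this setting can either output a scalar potential, whose gradient gives the transport map, or output the map directly. The `is_potential` parameter in the `MLP` signature suggests this distinction. The sketch below shows how a base class could hide that difference behind a single `transport` method. `SketchModelBase`, `QuadraticPotential`, and `transport` are hypothetical names and not the library's `ModelBase` API. Central finite differences stand in for automatic differentiation.

```python
import abc
import numpy as np

class SketchModelBase(abc.ABC):
    """Hypothetical base class; not the library's ModelBase API."""

    @property
    @abc.abstractmethod
    def is_potential(self) -> bool:
        """True if __call__ returns a scalar potential, False if it returns a map."""

    @abc.abstractmethod
    def __call__(self, x: np.ndarray):
        """Evaluate the model at a single point x of shape (dim,)."""

    def transport(self, x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
        """Push x forward: the output itself for maps, the gradient for potentials."""
        x = np.asarray(x, dtype=float)
        if not self.is_potential:
            return np.asarray(self(x))
        # Central finite differences stand in for automatic differentiation.
        grad = np.empty_like(x)
        for i in range(x.size):
            e = np.zeros_like(x)
            e[i] = eps
            grad[i] = (self(x + e) - self(x - e)) / (2.0 * eps)
        return grad


class QuadraticPotential(SketchModelBase):
    """f(x) = 0.5 x^T A x with A symmetric, so grad f(x) = A x."""

    is_potential = True

    def __init__(self, A: np.ndarray):
        self.A = np.asarray(A, dtype=float)

    def __call__(self, x):
        return 0.5 * x @ self.A @ x


A = np.array([[2.0, 0.5], [0.5, 1.0]])
model = QuadraticPotential(A)
x0 = np.array([1.0, -2.0])
mapped = model.transport(x0)   # approximately A @ x0 = [1.0, -1.5]
```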

models.ICNN(dim_data, dim_hidden[, ...])

Input convex neural network (ICNN) architecture with initialization.
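The convexity of an ICNN [Amos et al., 2017] comes from two constraints. The weights acting on the previous hidden layer are non-negative, and the activation is convex and non-decreasing. The weights acting on the input x are unconstrained. A minimal NumPy sketch under those assumptions follows, with a midpoint-convexity check on random pairs. `init_icnn` and `icnn` are hypothetical names. The sketch uses plain random initialization and does not reproduce the initialization scheme that the library's ICNN provides.

```python
import numpy as np

def softplus(h):
    # log(1 + exp(h)) computed stably; convex and non-decreasing.
    return np.logaddexp(0.0, h)

def init_icnn(dim_data, dim_hidden, rng):
    """Hypothetical random parameters; dim_hidden lists the hidden layer widths."""
    layers, prev = [], None
    for width in list(dim_hidden) + [1]:
        Wx = rng.normal(size=(width, dim_data)) / np.sqrt(dim_data)  # unconstrained
        b = np.zeros(width)
        # Weights on the previous hidden layer must be non-negative for convexity.
        Wz = None if prev is None else np.abs(rng.normal(size=(width, prev))) / prev
        layers.append((Wx, Wz, b))
        prev = width
    return layers

def icnn(layers, x):
    """Scalar convex potential f(x) for a single input x of shape (dim_data,)."""
    z = None
    for k, (Wx, Wz, b) in enumerate(layers):
        h = Wx @ x + b
        if Wz is not None:
            h = h + Wz @ z
        # Hidden layers apply softplus; the final layer stays affine in (x, z).
        z = h if k == len(layers) - 1 else softplus(h)
    return float(z[0])

rng = np.random.default_rng(0)
params = init_icnn(dim_data=2, dim_hidden=[16, 16], rng=rng)

# Midpoint convexity: f((a + c) / 2) <= (f(a) + f(c)) / 2 for every pair.
violations = 0
for _ in range(200):
    a, c = rng.normal(size=2), rng.normal(size=2)
    lhs = icnn(params, 0.5 * (a + c))
    rhs = 0.5 * icnn(params, a) + 0.5 * icnn(params, c)
    violations += lhs > rhs + 1e-9
```

Dropping the `np.abs` on the hidden-to-hidden weights removes the guarantee, which is essentially what separates this architecture from the plain MLP.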

models.MLP(dim_hidden[, is_potential, ...])

A non-convex MLP.
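Unlike the ICNN, an MLP places no sign constraints on its weights, so its output is not convex in the input. This sketch assumes that `is_potential=True` means the network outputs a scalar potential and `is_potential=False` means it outputs a vector of the input's dimension, as a map would. The `dim_data` argument, `init_mlp`, and `mlp` are hypothetical and not part of the signature listed above.

```python
import numpy as np

def init_mlp(dim_data, dim_hidden, is_potential, rng):
    """Hypothetical init: output width 1 for a potential, dim_data for a map."""
    widths = [dim_data] + list(dim_hidden) + [1 if is_potential else dim_data]
    return [
        (rng.normal(size=(d_out, d_in)) / np.sqrt(d_in), np.zeros(d_out))
        for d_in, d_out in zip(widths[:-1], widths[1:])
    ]

def mlp(layers, x, is_potential):
    h = np.asarray(x, dtype=float)
    for W, b in layers[:-1]:
        # Unconstrained (possibly negative) weights with ReLU: no convexity guarantee.
        h = np.maximum(W @ h + b, 0.0)
    W, b = layers[-1]
    out = W @ h + b
    return float(out[0]) if is_potential else out

rng = np.random.default_rng(0)
x = rng.normal(size=3)
potential = mlp(init_mlp(3, [32, 32], True, rng), x, is_potential=True)   # scalar
mapped = mlp(init_mlp(3, [32, 32], False, rng), x, is_potential=False)    # shape (3,)
```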

Conjugate Solvers

conjugate_solvers.ConjugateResults(val, ...)

Holds the results of numerically conjugating a function.
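A numerical conjugation produces both a value and the point at which the supremum is attained. For convex, differentiable f with a unique maximizer (for example, strictly convex f), the two are related as follows. The notation is ours, and which of these quantities `ConjugateResults` stores besides `val` is not stated here.

```latex
f^*(y) = \sup_{x} \big( \langle x, y \rangle - f(x) \big),
\qquad
x^*(y) = \arg\max_{x} \big( \langle x, y \rangle - f(x) \big),
\\[4pt]
\nabla f\big(x^*(y)\big) = y,
\qquad
\nabla f^*(y) = x^*(y).
```

The first-order condition on the left explains why a maximizer of the inner problem inverts the gradient map. The identity on the right (Danskin's theorem) explains why the same maximizer also provides the gradient of the conjugate.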


Abstract conjugate solver class.
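The role of an abstract conjugate solver can be sketched as an interface with a single `solve` method. Given f, its gradient, a query point y, and an initial guess, `solve` returns the value f^*(y) together with the maximizer. Every name below (`SketchConjugateResults` and its fields, `SketchConjugateSolver`, `GradientAscentConjugate`) is hypothetical. The concrete subclass uses fixed-step gradient ascent purely for illustration.

```python
import abc
from typing import Callable, NamedTuple

import numpy as np

class SketchConjugateResults(NamedTuple):
    val: float         # estimate of f^*(y)
    grad: np.ndarray   # maximizer x^*(y), which equals grad f^*(y)
    num_iter: int      # ascent steps taken

class SketchConjugateSolver(abc.ABC):
    """Hypothetical interface: numerically compute f^*(y) = sup_x <x, y> - f(x)."""

    @abc.abstractmethod
    def solve(self, f: Callable, grad_f: Callable, y: np.ndarray,
              x_init: np.ndarray) -> SketchConjugateResults:
        ...

class GradientAscentConjugate(SketchConjugateSolver):
    """Fixed-step gradient ascent on the concave objective x -> <x, y> - f(x)."""

    def __init__(self, step=0.1, max_iter=1000, tol=1e-10):
        self.step, self.max_iter, self.tol = step, max_iter, tol

    def solve(self, f, grad_f, y, x_init):
        x = np.asarray(x_init, dtype=float).copy()
        num_iter = 0
        for _ in range(self.max_iter):
            g = y - grad_f(x)   # gradient of <x, y> - f(x)
            if np.linalg.norm(g) < self.tol:
                break
            x = x + self.step * g
            num_iter += 1
        return SketchConjugateResults(val=float(x @ y - f(x)), grad=x, num_iter=num_iter)

y = np.array([1.0, -2.0, 0.5])
res = GradientAscentConjugate().solve(
    f=lambda x: 0.5 * x @ x, grad_f=lambda x: x, y=y, x_init=np.zeros(3))
# For f = 0.5|x|^2 we have f^* = f, so res.val ≈ 0.5 * |y|^2 = 2.625 and res.grad ≈ y.
```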


Solve for the conjugate using jaxopt.LBFGS.
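Computing the conjugate amounts to minimizing f(x) - <x, y>, which is a natural fit for a quasi-Newton method such as L-BFGS. The sketch below is a hand-rolled NumPy L-BFGS (two-loop recursion with Armijo backtracking) that stands in for `jaxopt.LBFGS`. It does not reproduce jaxopt's line search, defaults, or API. `lbfgs_minimize` and `conjugate_lbfgs` are hypothetical names. The result is checked against the closed-form conjugate of a quadratic, f^*(y) = 0.5 y^T A^{-1} y.

```python
import numpy as np

def lbfgs_minimize(fun, grad, x0, history=10, max_iter=200, tol=1e-8):
    """Minimal L-BFGS: two-loop recursion plus backtracking (Armijo) line search."""
    x = np.asarray(x0, dtype=float).copy()
    g = grad(x)
    s_hist, y_hist = [], []
    num_iter = 0
    while num_iter < max_iter and np.linalg.norm(g) > tol:
        # Two-loop recursion: r approximates (inverse Hessian) @ g.
        q = g.copy()
        coeffs = []
        for s, yk in zip(reversed(s_hist), reversed(y_hist)):   # newest pair first
            rho = 1.0 / (yk @ s)
            a = rho * (s @ q)
            q -= a * yk
            coeffs.append((rho, a))
        gamma = (s_hist[-1] @ y_hist[-1]) / (y_hist[-1] @ y_hist[-1]) if s_hist else 1.0
        r = gamma * q
        for s, yk, (rho, a) in zip(s_hist, y_hist, reversed(coeffs)):  # oldest first
            r += s * (a - rho * (yk @ r))
        d = -r
        # Backtracking line search with the Armijo sufficient-decrease condition.
        t, fx = 1.0, fun(x)
        while fun(x + t * d) > fx + 1e-4 * t * (g @ d) and t > 1e-12:
            t *= 0.5
        x_new = x + t * d
        g_new = grad(x_new)
        s, yk = x_new - x, g_new - g
        if s @ yk > 1e-12:   # keep only pairs with positive curvature
            s_hist.append(s)
            y_hist.append(yk)
            if len(s_hist) > history:
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
        num_iter += 1
    return x, num_iter

def conjugate_lbfgs(f, grad_f, y, x_init):
    """f^*(y) = -min_x [f(x) - <x, y>]; returns (value, maximizer, iterations)."""
    x_star, n_it = lbfgs_minimize(lambda x: f(x) - x @ y, lambda x: grad_f(x) - y, x_init)
    return float(x_star @ y - f(x_star)), x_star, n_it

A = np.array([[3.0, 1.0, 0.0], [1.0, 2.0, 0.5], [0.0, 0.5, 1.5]])   # symmetric positive definite
y = np.array([1.0, -1.0, 2.0])
val, x_star, n_it = conjugate_lbfgs(lambda x: 0.5 * x @ A @ x, lambda x: A @ x, y, np.zeros(3))
exact = 0.5 * y @ np.linalg.solve(A, y)   # closed-form conjugate of the quadratic
```

The memory size `history` trades per-iteration cost against the quality of the curvature estimate. In a batched setting, the same solve would be repeated for every target point y.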