This module implements the input convex neural network (ICNN) [Amos et al., 2017], which can be used to solve OT problems between point clouds. Simpler alternatives, such as an MLP, are also implemented, essentially borrowed from the flax library.

Neural Dual

neuraldual.W2NeuralDual(dim_data[, ...])

Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.
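
For context, here is the standard convex-potential (semi-dual) form of the squared 2-Wasserstein distance, which is what motivates parameterizing a convex potential f with an ICNN. This is textbook background under the assumption that both measures have finite second moments; the exact objective and the handling of the conjugate f^* inside `W2NeuralDual` may differ.

```latex
\[
\tfrac{1}{2} W_2^2(\mu, \nu)
  = \tfrac{1}{2}\,\mathbb{E}_{x \sim \mu}\|x\|^2
  + \tfrac{1}{2}\,\mathbb{E}_{y \sim \nu}\|y\|^2
  - \min_{f \text{ convex}} \Big( \mathbb{E}_{x \sim \mu}[f(x)] + \mathbb{E}_{y \sim \nu}[f^*(y)] \Big),
\qquad
f^*(y) = \sup_{x} \big( \langle x, y \rangle - f(x) \big).
\]
```

When mu has a density, Brenier's theorem gives the optimal transport map as T = grad f for a minimizing potential f.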


Models

models.ModelBase([parent, name])

Base class for the neural solver models.

models.ICNN(dim_data, dim_hidden[, ...])

Input convex neural network (ICNN) architecture with initialization.
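
A minimal NumPy sketch of the ICNN idea, assuming the standard construction: every layer sees the input x through unconstrained weights, hidden-to-hidden weights are non-negative, and the activation is convex and non-decreasing (softplus here). `init_icnn` and `icnn_forward` are hypothetical names; this does not reproduce the library's ICNN, its initialization, or its Flax/JAX code.

```python
import numpy as np

def softplus(z):
    # Convex and non-decreasing, so composing it with a convex function stays convex.
    return np.logaddexp(0.0, z)

def init_icnn(dim_data, dim_hidden, seed=0):
    """Random parameters for a hypothetical ICNN (not the library's initializer)."""
    rng = np.random.default_rng(seed)
    sizes = list(dim_hidden) + [1]  # last layer outputs a scalar potential
    params = {"Wx": [], "Wz": [], "b": []}
    prev = None
    for size in sizes:
        params["Wx"].append(rng.normal(size=(size, dim_data)))  # unconstrained
        params["b"].append(rng.normal(size=size))
        # Hidden-to-hidden weights made non-negative via abs; the first layer has none.
        params["Wz"].append(None if prev is None else np.abs(rng.normal(size=(size, prev))))
        prev = size
    return params

def icnn_forward(params, x):
    """Scalar output f(x), convex in x by construction."""
    z = None
    n_layers = len(params["b"])
    for k in range(n_layers):
        pre = params["Wx"][k] @ x + params["b"][k]
        if params["Wz"][k] is not None:
            # Non-negative combination of convex z plus an affine term: still convex.
            pre = pre + params["Wz"][k] @ z
        # Hidden layers apply softplus; the output layer is left affine in (z, x).
        z = pre if k == n_layers - 1 else softplus(pre)
    return z[0]
```

Taking `np.abs` at initialization only guarantees convexity for these particular weights; a trained network needs the non-negativity constraint enforced throughout training, for example by projection or reparameterization.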

models.MLP(dim_hidden[, is_potential, ...])

A generic MLP, typically not convex with respect to its input.
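
For contrast, a NumPy sketch of a plain MLP with unconstrained weights. It assumes that `is_potential` toggles between a scalar output (a potential, whose gradient would give a map) and an output with the same dimension as the input (the map itself); that reading of the flag is our assumption, and `init_mlp` and `mlp_forward` are hypothetical names.

```python
import numpy as np

def init_mlp(dim_data, dim_hidden, is_potential=True, seed=0):
    """Random parameters for a hypothetical plain MLP (not the library's code)."""
    rng = np.random.default_rng(seed)
    # Assumed semantics: a potential maps R^d -> R, otherwise the net maps R^d -> R^d.
    dim_out = 1 if is_potential else dim_data
    sizes = [dim_data] + list(dim_hidden) + [dim_out]
    # One (weight, bias) pair per layer; weight shape is (fan_out, fan_in).
    return [(rng.normal(size=(m, n)) / np.sqrt(n), np.zeros(m))
            for n, m in zip(sizes[:-1], sizes[1:])]

def mlp_forward(params, x):
    h = x
    for W, b in params[:-1]:
        h = np.maximum(W @ h + b, 0.0)  # ReLU; weight signs are unconstrained
    W, b = params[-1]
    return W @ h + b  # affine output layer
```

Because nothing constrains the signs of the weights, even a two-layer instance computing f(x) = -ReLU(x) is concave rather than convex.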

Conjugate Solvers

conjugate_solvers.ConjugateResults(val, ...)

Holds the results of numerically conjugating a function.
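
Numerically conjugating f means solving the inner maximization below for a given y. Assuming f is smooth and strictly convex and the supremum is attained, the maximizer satisfies the first-order condition and equals the gradient of the conjugate. This is background math only; which quantities ConjugateResults stores beyond `val` is not stated here.

```latex
\[
f^*(y) = \sup_{x \in \mathbb{R}^d} \big( \langle x, y \rangle - f(x) \big),
\qquad
x^\star = \arg\max_{x} \big( \langle x, y \rangle - f(x) \big)
\;\Longrightarrow\;
\nabla f(x^\star) = y, \quad \nabla f^*(y) = x^\star .
\]
```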


Abstract conjugate solver class.


Solve for the conjugate using jaxopt.LBFGS.
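
A dependency-free NumPy sketch of numerical conjugation. It deliberately swaps in plain gradient ascent for L-BFGS, and it assumes f is smooth and strongly convex so that a small fixed step converges. `ConjugateSketchResult` and `conjugate_gradient_ascent` are hypothetical names, loosely echoing the `val` field of ConjugateResults(val, ...).

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class ConjugateSketchResult:
    """Hypothetical result container, not the library's ConjugateResults."""
    val: float        # approximation of f*(y)
    grad: np.ndarray  # maximizer x*, equal to grad f*(y) for smooth strictly convex f
    num_iter: int

def conjugate_gradient_ascent(f, grad_f, y, x_init, step=0.1, max_iter=1000, tol=1e-10):
    """Approximate f*(y) = sup_x <x, y> - f(x) by fixed-step gradient ascent."""
    x = np.asarray(x_init, dtype=float)
    for it in range(1, max_iter + 1):
        g = y - grad_f(x)  # gradient of the objective x -> <x, y> - f(x)
        x = x + step * g
        if np.linalg.norm(g) < tol:  # stop once the first-order condition holds
            break
    return ConjugateSketchResult(val=float(x @ y - f(x)), grad=x, num_iter=it)
```

For the quadratic f(x) = 0.5 x^T A x with A positive definite, the conjugate is 0.5 y^T A^{-1} y, which gives a closed form to check against. A quasi-Newton method such as L-BFGS typically needs far fewer iterations than fixed-step ascent.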


Losses

losses.monge_gap(map_fn, reference_points[, ...])

Monge gap regularizer [Uscidda and Cuturi, 2023].
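
The Monge gap of a map T with respect to a reference measure rho and a ground cost c, as we understand the definition in [Uscidda and Cuturi, 2023]: the average displacement cost of T minus the optimal transport cost between rho and its push-forward under T. It is non-negative because the coupling induced by T is feasible, and it vanishes when T is itself c-optimal. How the library estimates W_c is not specified here.

```latex
\[
\mathcal{M}^{c}_{\rho}(T)
  = \int c\big(x, T(x)\big)\, \mathrm{d}\rho(x) \;-\; W_c\big(\rho,\, T_{\sharp}\rho\big) \;\ge\; 0 .
\]
```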

losses.monge_gap_from_samples(source, target)

Monge gap, instantiated in terms of samples before and after applying the map.
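
A brute-force NumPy sketch of the sample-based Monge gap for tiny, equally sized, uniformly weighted point clouds under squared Euclidean cost. It swaps in exact OT by enumerating permutations (valid for uniform weights and equal sizes) in place of the library's OT estimator; `exact_ot_cost` and `monge_gap_from_samples_sketch` are hypothetical names.

```python
import itertools
import numpy as np

def sq_euclidean(x, y):
    return float(np.sum((x - y) ** 2))

def exact_ot_cost(source, target, cost=sq_euclidean):
    """Exact OT cost between two uniform point clouds of equal size n.

    With uniform weights and equal sizes, some optimal coupling is a permutation,
    so minimizing over all n! pairings is exact (only practical for tiny n).
    """
    n = len(source)
    C = np.array([[cost(x, y) for y in target] for x in source])
    return min(C[np.arange(n), list(perm)].mean()
               for perm in itertools.permutations(range(n)))

def monge_gap_from_samples_sketch(source, target, cost=sq_euclidean):
    """Mean cost of the given pairing (x_i, y_i) minus the optimal pairing cost."""
    displacement = np.mean([cost(x, y) for x, y in zip(source, target)])
    return displacement - exact_ot_cost(source, target, cost)
```

Passing `target = T(source)` for a map `T` gives the map-based form. Under squared Euclidean cost a translation x -> x + b has zero gap, since translations are optimal for that cost, whereas a map that swaps two points does not.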