# ott.solvers.nn.neuraldual.NeuralDualSolver

class ott.solvers.nn.neuraldual.NeuralDualSolver(input_dim, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=100, num_inner_iters=10, valid_freq=100, log_freq=100, logging=False, seed=0, pos_weights=True, beta=1.0)

Solver of the ICNN-based Kantorovich dual.

Learn the optimal transport between two distributions, referred to as the source and the target. This is achieved by parameterizing the two Kantorovich potentials, $g$ and $f$, with two input convex neural networks (ICNNs). The gradient $\nabla g$ transports source samples to the target distribution, and $\nabla f$ transports target samples to the source distribution.
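As a rough illustration of this parameterization, the pure-JAX sketch below builds a tiny input convex network and uses `jax.grad` to turn it into a transport map. The helper names `init_icnn` and `icnn`, the layer sizes, and the use of softplus to keep the hidden-to-hidden weights nonnegative are our own assumptions; this is not the architecture that `NeuralDualSolver` uses by default.

```python
import jax
import jax.numpy as jnp


def init_icnn(key, dim, hidden=(16, 16)):
    """Hypothetical ICNN parameters.

    W[i]: unconstrained weights acting on the raw input x at layer i.
    U[i]: weights acting on the previous hidden state; softplus is applied
          to them in the forward pass so the effective weights are >= 0.
    """
    params = {"W": [], "b": [], "U": []}
    prev = None
    for size in list(hidden) + [1]:
        key, k_w, k_u = jax.random.split(key, 3)
        params["W"].append(0.1 * jax.random.normal(k_w, (size, dim)))
        params["b"].append(jnp.zeros(size))
        if prev is not None:
            params["U"].append(0.1 * jax.random.normal(k_u, (size, prev)))
        prev = size
    return params


def icnn(params, x):
    """Scalar function that is convex in x for any parameter values."""
    # First layer: convex, nondecreasing activation of an affine map of x.
    z = jax.nn.softplus(params["W"][0] @ x + params["b"][0])
    num_layers = len(params["U"])
    for i in range(num_layers):
        # Nonnegative mixing of convex features plus an affine term in x.
        pre = (jax.nn.softplus(params["U"][i]) @ z
               + params["W"][i + 1] @ x + params["b"][i + 1])
        # Keep the output layer linear; apply softplus to hidden layers.
        z = pre if i == num_layers - 1 else jax.nn.softplus(pre)
    return z[0]


# Gradient map x -> grad g(x), vectorized over a batch of points.
transport = jax.vmap(jax.grad(icnn, argnums=1), in_axes=(None, 0))

params_g = init_icnn(jax.random.PRNGKey(0), dim=2)
source = jax.random.normal(jax.random.PRNGKey(1), (5, 2))
mapped = transport(params_g, source)  # shape (5, 2): images of the source points
```

Convexity in $x$ holds because every layer adds a nonnegative combination of convex features to an affine function of $x$, and softplus is convex and nondecreasing, so the gradient of the network is a valid candidate for a Brenier-type map.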

The original algorithm is described in .
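For the squared Euclidean cost, the objective behind ICNN-based dual solvers of this kind can be written as a min-max problem over the two potentials. The statement below is our own summary, with source distribution $\mu$, target distribution $\nu$, and the roles of $f$ and $g$ chosen to match the description above; it is not transcribed from the library's loss function.

$$
\min_{f \text{ convex}} \; \max_{g \text{ convex}} \;
\mathbb{E}_{y \sim \nu}\big[f(y)\big]
+ \mathbb{E}_{x \sim \mu}\big[\langle x, \nabla g(x) \rangle - f(\nabla g(x))\big]
$$

For fixed $f$, the inner maximum recovers the convex conjugate $f^*(x) = \sup_y \, \langle x, y \rangle - f(y)$, attained at $y = \nabla g(x)$, so the problem reduces to minimizing $\mathbb{E}_{\nu}[f] + \mathbb{E}_{\mu}[f^*]$ over convex $f$.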

Parameters

Methods

| Method | Description |
| --- | --- |
| `get_step_fn(train[, to_optimize])` | Create a one-step training and evaluation function. |
| `setup(rng, neural_f, neural_g, input_dim, ...)` | Set up all components required to train the network. |
| | Return the Kantorovich dual potentials from the trained potentials. |
| `train_neuraldual(trainloader_source, ...)` | Implementation of the training and validation script. |
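The methods above split training into a per-step function and an outer loop. As a hedged sketch of what one alternating step could look like, the code below estimates $\mathbb{E}_{y \sim \text{target}}[f(y)] + \mathbb{E}_{x \sim \text{source}}[\langle x, \nabla g(x) \rangle - f(\nabla g(x))]$ on a batch and takes `num_inner_iters` gradient-ascent steps on $g$ for every gradient-descent step on $f$. That inner/outer split is our reading of the `num_inner_iters` parameter, not a documented fact. `potential`, `dual_objective`, and `train_step` are hypothetical names, plain SGD stands in for `optimizer_f`/`optimizer_g`, and simple quadratic convex potentials replace the ICNNs to keep the example short.

```python
import jax
import jax.numpy as jnp


def potential(params, z):
    # Hypothetical convex potential 0.5 * ||A z||^2 + b . z
    # (convex for any A, since A^T A is positive semidefinite).
    return 0.5 * jnp.sum((params["A"] @ z) ** 2) + params["b"] @ z


def dual_objective(params_f, params_g, source, target):
    # Batch estimate of E_target[f(y)] + E_source[<x, grad g(x)> - f(grad g(x))].
    f = jax.vmap(potential, in_axes=(None, 0))
    grad_g = jax.vmap(jax.grad(potential, argnums=1), in_axes=(None, 0))
    mapped = grad_g(params_g, source)  # grad g pushes source points towards the target
    corr = jnp.sum(source * mapped, axis=1)
    return jnp.mean(f(params_f, target)) + jnp.mean(corr - f(params_f, mapped))


def train_step(params_f, params_g, source, target, lr=0.05, num_inner_iters=10):
    # Inner loop: gradient ascent on g, which maximizes the objective for fixed f.
    for _ in range(num_inner_iters):
        grads_g = jax.grad(dual_objective, argnums=1)(params_f, params_g, source, target)
        params_g = jax.tree_util.tree_map(lambda p, d: p + lr * d, params_g, grads_g)
    # Outer step: gradient descent on f, which minimizes the objective.
    grads_f = jax.grad(dual_objective, argnums=0)(params_f, params_g, source, target)
    params_f = jax.tree_util.tree_map(lambda p, d: p - lr * d, params_f, grads_f)
    return params_f, params_g, dual_objective(params_f, params_g, source, target)


dim = 2
identity = {"A": jnp.eye(dim), "b": jnp.zeros(dim)}
source = jax.random.normal(jax.random.PRNGKey(0), (64, dim))
target = jax.random.normal(jax.random.PRNGKey(1), (64, dim)) + 2.0  # shifted copy
params_f, params_g = identity, identity
for _ in range(5):
    params_f, params_g, value = train_step(params_f, params_g, source, target)
```

Quadratic potentials keep every iterate convex without any weight constraint, which is why they are used here; with ICNNs, convexity has to be maintained through the way the network weights are parameterized or constrained.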