class ott.core.neuraldual.NeuralDualSolver(input_dim, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=100, num_inner_iters=10, valid_freq=100, log_freq=100, logging=False, seed=0, pos_weights=True, beta=1.0)

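Solver for the Kantorovich dual of the quadratic-cost optimal transport problem, with both dual potentials parameterized by input convex neural networks (ICNNs).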

The algorithm is described in: Optimal transport mapping via input convex neural networks, Makkuva-Taghvaei-Lee-Oh, ICML’20.
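For the quadratic cost, this dual can be written as a min-max problem over a convex potential f on the target space and a convex potential g on the source space, where ∇g acts as the transport map. In this reading, f minimizes E[f(Y)] - E[f(∇g(X))], g minimizes E[f(∇g(X)) - <X, ∇g(X)>], and the squared 2-Wasserstein distance is estimated as 2 E[f(∇g(X)) - f(Y) - <X, ∇g(X)> + ½‖X‖² + ½‖Y‖²]. The NumPy sketch below evaluates these Monte Carlo estimates for toy quadratic potentials f(y) = ½ yᵀAy and g(x) = ½ xᵀBx, whose gradients are known in closed form. The helper `dual_objective`, the quadratic parameterization, and the assignment of f to the target side and g to the source side are assumptions for illustration; the solver itself uses ICNN potentials and JAX autodiff.

```python
import numpy as np


def dual_objective(source, target, A, B):
    """Monte Carlo estimates of the dual terms for quadratic potentials.

    Hypothetical helper (not part of OTT): f(y) = 0.5 * y^T A y lives on the
    target side, g(x) = 0.5 * x^T B x on the source side, so grad g(x) = B x.
    A and B are assumed symmetric positive semi-definite (convex potentials).
    """
    def f(y):
        # Row-wise quadratic form 0.5 * y_i^T A y_i.
        return 0.5 * np.einsum("ni,ij,nj->n", y, A, y)

    grad_g_x = source @ B.T  # row i is B @ x_i, the transported source point
    f_grad_g_x = f(grad_g_x)
    f_y = f(target)
    x_dot_grad_g_x = np.sum(source * grad_g_x, axis=1)

    # Term minimized with respect to f.
    loss_f = np.mean(f_y - f_grad_g_x)
    # Term minimized with respect to g (maximizes <x, grad g(x)> - f(grad g(x))).
    loss_g = np.mean(f_grad_g_x - x_dot_grad_g_x)
    # Estimate of the squared 2-Wasserstein distance.
    w2_sq = 2.0 * np.mean(
        f_grad_g_x - f_y - x_dot_grad_g_x
        + 0.5 * np.sum(target**2, axis=1)
        + 0.5 * np.sum(source**2, axis=1)
    )
    return loss_f, loss_g, w2_sq


rng = np.random.default_rng(0)
source = rng.normal(size=(1000, 2))
target = 2.0 * source   # optimal map x -> 2x is the gradient of g(x) = ||x||^2
A = 0.5 * np.eye(2)     # f(y) = ||y||^2 / 4, the convex conjugate of g
B = 2.0 * np.eye(2)     # grad g(x) = 2x
loss_f, loss_g, w2_sq = dual_objective(source, target, A, B)
```

With the exact optimal pair (g, f = g*), the estimate coincides with the empirical mean of ‖x - T(x)‖², which is a quick sanity check for any implementation of these loss terms.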

Parameters:

  • input_dim (int) – dimensionality of the input data, required to initialize the networks

  • neural_f (Optional[Module]) – network architecture for potential f

  • neural_g (Optional[Module]) – network architecture for potential g

  • optimizer_f (Optional[GradientTransformation]) – optimizer for potential f

  • optimizer_g (Optional[GradientTransformation]) – optimizer for potential g

  • num_train_iters (int) – total number of training iterations

  • num_inner_iters (int) – number of training iterations of g per iteration of f

  • valid_freq (int) – frequency, in training iterations, with which the model is validated

  • log_freq (int) – frequency, in training iterations, with which training and validation metrics are logged

  • logging (bool) – whether to return the training and validation logs

  • seed (int) – random seed for network initializations

  • pos_weights (bool) – whether to train the networks with positive weights; if False, a regularizer is used instead

  • beta (int) – regularization parameter, used only when pos_weights is False
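When pos_weights is False, the network weights are unconstrained, so convexity of the ICNN potentials has to be encouraged in another way. The NumPy sketch below shows two standard options: projecting a kernel onto nonnegative values after an update (clipping), and adding beta times the norm of the negative part of the kernels to a loss. The helper names are hypothetical, and the exact form of the penalty used by the solver (norm or squared norm, which layers, which potential) is an assumption not specified on this page.

```python
import numpy as np


def clip_negative(kernel):
    """Project a weight matrix onto the nonnegative orthant (hypothetical helper)."""
    return np.maximum(kernel, 0.0)


def negative_weight_penalty(kernels, beta=1.0):
    """beta * sum of Frobenius norms of the negative parts of the kernels.

    Hypothetical helper: the penalty actually used by the solver may differ.
    """
    return beta * sum(np.linalg.norm(np.maximum(-k, 0.0)) for k in kernels)


kernel = np.array([[-3.0, 1.0], [2.0, -4.0]])
clipped = clip_negative(kernel)                        # negative entries set to 0
penalty = negative_weight_penalty([kernel], beta=2.0)  # 10.0: negative part has norm 5
```

Clipping enforces the constraint exactly after each step, while the penalty only discourages violations, with beta trading off constraint satisfaction against the dual objective.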


Returns: the NeuralDual containing the optimal dual potentials f and g

Methods:

  • create_train_state(rng, model, optimizer, input) – Create the initial TrainState.

  • get_step_fn(train[, to_optimize]) – Create a one-step training and evaluation function.

  • setup(rng, neural_f, neural_g, input_dim, ...) – Set up all components required to train the NeuralDual.

  • train_neuraldual(trainloader_source, ...) – Run the training and validation loop.
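The constructor arguments num_train_iters, num_inner_iters, valid_freq and log_freq determine the schedule of the training loop: every one of the num_train_iters iterations performs num_inner_iters updates of g per update of f. The pure-Python sketch below only counts these updates and records at which outer steps logging and validation would fire. The Schedule class and training_schedule helper are hypothetical, and the ordering (g updates before the f update) and the exact trigger conditions (step % log_freq == 0 for logging, step != 0 and step % valid_freq == 0 for validation) are assumptions, not taken from this page.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Schedule:
    """Bookkeeping for a hypothetical alternating training loop."""
    g_updates: int = 0
    f_updates: int = 0
    log_steps: List[int] = field(default_factory=list)
    valid_steps: List[int] = field(default_factory=list)


def training_schedule(num_train_iters: int = 100, num_inner_iters: int = 10,
                      valid_freq: int = 100, log_freq: int = 100) -> Schedule:
    sched = Schedule()
    for step in range(num_train_iters):
        # Assumed order: several gradient steps on potential g ...
        for _ in range(num_inner_iters):
            sched.g_updates += 1
        # ... followed by a single gradient step on potential f.
        sched.f_updates += 1
        # Assumed trigger conditions for logging and validation.
        if step % log_freq == 0:
            sched.log_steps.append(step)
        if step != 0 and step % valid_freq == 0:
            sched.valid_steps.append(step)
    return sched


s = training_schedule(num_train_iters=250, num_inner_iters=10,
                      valid_freq=100, log_freq=50)
```

Under these assumptions, g receives num_inner_iters times as many gradient steps as f, which is why num_inner_iters is the main knob for how closely g tracks the inner maximization for the current f.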