ott.solvers.nn.neuraldual.NeuralDualSolver

class ott.solvers.nn.neuraldual.NeuralDualSolver(input_dim, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=100, num_inner_iters=10, valid_freq=100, log_freq=100, logging=False, seed=0, pos_weights=True, beta=1.0)

Solver of the ICNN-based Kantorovich dual.

Learn the optimal transport between two distributions, denoted source and target. The two Kantorovich potentials, g and f, are each parameterized by an input convex neural network (ICNN). \(\nabla g\) transports source samples to the target, and \(\nabla f\) transports target samples to the source.

The original algorithm is described in [Makkuva et al., 2020].
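To make the role of the gradients concrete, here is a minimal sketch that uses a hand-written convex quadratic in place of a trained ICNN. The names `potential_g` and `transport` and the constants `A` and `B` are illustrative assumptions, not part of the library. For g(x) = A x²/2 + B x with A > 0, the gradient is the map x ↦ A x + B, which pushes a standard normal onto a normal with mean B and standard deviation A.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a trained potential g: a convex quadratic.
A, B = 2.0, 1.0  # illustrative constants; A > 0 keeps g convex


def potential_g(x):
  """Convex potential g(x) = 0.5 * A * ||x||^2 + B * sum(x)."""
  return 0.5 * A * jnp.sum(x ** 2) + B * jnp.sum(x)


# The transport map is the gradient of the potential, applied per sample.
transport = jax.vmap(jax.grad(potential_g))

# Source samples drawn from a standard normal in one dimension.
source = jax.random.normal(jax.random.PRNGKey(0), (5000, 1))
mapped = transport(source)  # equals A * source + B for this potential
```

In the solver, the quadratic would be replaced by a learned network, but the pattern of differentiating a scalar potential with respect to its input stays the same.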

Parameters
  • input_dim (int) – input dimensionality of data required for network init

  • neural_f (Optional[Module]) – network architecture for potential f

  • neural_g (Optional[Module]) – network architecture for potential g

  • optimizer_f (Union[Array, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer function for potential f

  • optimizer_g (Union[Array, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer function for potential g

  • num_train_iters (int) – number of total training iterations

  • num_inner_iters (int) – number of training iterations of g per iteration of f

  • valid_freq (int) – frequency with which model is validated

  • log_freq (int) – frequency with which training and validation are logged

  • logging (bool) – option to return logs

  • seed (int) – random seed for network initializations

  • pos_weights (bool) – whether to train the networks with positive weights; if False, a regularizer is used instead

  • beta (float) – regularization parameter when not training with positive weights
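The `pos_weights` and `beta` parameters describe two ways of keeping the networks convex in their input: constrain the relevant weights to be non-negative, or leave them unconstrained and penalize negative entries. The sketch below contrasts the two strategies on a single weight matrix. The helper names `project_nonnegative` and `negativity_penalty`, and the exact form of the penalty, are assumptions made for illustration and may differ from what the solver does internally.

```python
import jax
import jax.numpy as jnp


def project_nonnegative(w):
  """Hard constraint: clip negative entries to zero, e.g. after each update."""
  return jnp.maximum(w, 0.0)


def negativity_penalty(w, beta=1.0):
  """Soft constraint: beta times the Frobenius norm of the negative part."""
  return beta * jnp.linalg.norm(jax.nn.relu(-w))


# Hypothetical weight matrix with two negative entries.
w = jnp.array([[0.5, -0.2], [-1.0, 0.3]])

w_proj = project_nonnegative(w)            # negative entries become 0
penalty = negativity_penalty(w, beta=1.0)  # added to the loss instead
```

The projection guarantees convexity at every step, while the penalty only encourages it; a larger `beta` pushes the weights harder towards the non-negative region.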

Methods

get_step_fn(train[, to_optimize])

Create a one-step training and evaluation function.

setup(rng, neural_f, neural_g, input_dim, ...)

Set up all components required to train the network.

to_dual_potentials()

Return the Kantorovich dual potentials from the trained potentials.

train_neuraldual(trainloader_source, ...)

Implementation of the training and validation script.
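Read together, `num_train_iters` and `num_inner_iters` suggest a nested schedule with several updates of g for each update of f. The sketch below illustrates only that loop structure with placeholder update functions. `run_schedule`, `update_g`, `update_f`, and the counters are hypothetical, and the actual training loop, including validation and logging, may be organized differently.

```python
def run_schedule(num_train_iters, num_inner_iters, update_g, update_f):
  """Alternate: num_inner_iters updates of g, then one update of f."""
  for step in range(num_train_iters):
    for _ in range(num_inner_iters):
      update_g(step)
    update_f(step)


counts = {"g": 0, "f": 0}


def update_g(step):
  counts["g"] += 1  # placeholder for one optimizer step on g


def update_f(step):
  counts["f"] += 1  # placeholder for one optimizer step on f


run_schedule(num_train_iters=100, num_inner_iters=10,
             update_g=update_g, update_f=update_f)
```

Under this assumed nesting, the default values would give g ten times as many updates as f.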