- class ott.solvers.nn.neuraldual.NeuralDualSolver(input_dim, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=100, num_inner_iters=10, valid_freq=100, log_freq=100, logging=False, seed=0, pos_weights=True, beta=1.0)
Solver of the ICNN-based Kantorovich dual.
Learn the optimal transport map between two distributions, referred to as the source and the target. This is achieved by parameterizing the two Kantorovich potentials, g and f, with two input convex neural networks (ICNNs): \(\nabla g\) then transports source cells to target cells, and \(\nabla f\) transports target cells to source cells.
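To illustrate why the gradient of a convex potential can serve as a transport map, the sketch below uses a hand-picked quadratic potential \(g(x) = \tfrac{1}{2} x^\top A x + b^\top x\) with a symmetric positive-definite \(A\). This potential, and the values of `A` and `b`, are assumptions for illustration; the solver itself learns g as an ICNN. Here \(\nabla g(x) = Ax + b\) pushes a standard Gaussian onto \(\mathcal{N}(b, A^2)\), and because it is the gradient of a convex function, Brenier's theorem makes it the optimal map for the squared Euclidean cost between these two Gaussians.

```python
import jax
import jax.numpy as jnp

# Hypothetical hand-picked potential g(x) = 0.5 * x^T A x + b^T x.
# It is convex because A is symmetric positive definite.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
b = jnp.array([1.0, -3.0])

def g(x):
  return 0.5 * x @ A @ x + b @ x

# jax.grad gives grad_g(x) = A x + b for a single point; vmap maps it over a batch.
grad_g = jax.vmap(jax.grad(g))

source = jax.random.normal(jax.random.PRNGKey(0), (100_000, 2))
pushed = grad_g(source)  # source samples transported by grad g

print(pushed.mean(axis=0))             # close to b
print(jnp.cov(pushed, rowvar=False))   # close to A @ A
```

The learned ICNN plays the role of this quadratic potential; only convexity, not a specific functional form, is needed for \(\nabla g\) to be a valid transport map.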
The original algorithm is described in [Makkuva et al., 2020].
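A minimal sketch of the min-max objective for the squared Euclidean cost, written in the convention above (g acts on source samples, f on target samples). The conjugate \(f^*\) in the semi-dual is replaced by the lower bound \(\langle x, \nabla g(x)\rangle - f(\nabla g(x))\), which is tight when \(\nabla g = \nabla f^*\), giving \(W_2^2 = \mathbb{E}\|X\|^2 + \mathbb{E}\|Y\|^2 - 2\inf_f \sup_g \big(\mathbb{E}[\langle X, \nabla g(X)\rangle - f(\nabla g(X))] + \mathbb{E}[f(Y)]\big)\). The helper `dual_losses` and its plain-callable interface are hypothetical, not the solver's internal API.

```python
import jax
import jax.numpy as jnp

def dual_losses(f, g, source, target):
  """Hypothetical helper: minibatch losses for f and g, plus a W2^2 estimate.

  f and g are scalar potentials that take a single point.
  """
  grad_g_s = jax.vmap(jax.grad(g))(source)   # source points transported by grad g
  f_grad_g_s = jax.vmap(f)(grad_g_s)         # f at the transported points
  f_t = jax.vmap(f)(target)                  # f at the target points
  s_dot_grad_g_s = jnp.sum(source * grad_g_s, axis=1)
  s_sq = jnp.sum(source ** 2, axis=1)
  t_sq = jnp.sum(target ** 2, axis=1)

  # g is updated to minimize loss_g (the inner sup over g, negated).
  loss_g = jnp.mean(f_grad_g_s - s_dot_grad_g_s)
  # f is updated to minimize loss_f (the outer inf over f; terms constant in f dropped).
  loss_f = jnp.mean(f_t - f_grad_g_s)
  # Plug-in estimate of the squared 2-Wasserstein distance.
  w2_sq = 2.0 * jnp.mean(f_grad_g_s - f_t - s_dot_grad_g_s + 0.5 * (s_sq + t_sq))
  return loss_f, loss_g, w2_sq

# Usage on a pure translation, where the optimal potentials are known in closed form.
shift = jnp.array([1.0, -2.0])
g_opt = lambda x: 0.5 * x @ x + shift @ x   # grad: x + shift (source -> target)
f_opt = lambda y: 0.5 * y @ y - shift @ y   # grad: y - shift (target -> source)
source = jax.random.normal(jax.random.PRNGKey(0), (1000, 2))
target = source + shift
loss_f, loss_g, w2_sq = dual_losses(f_opt, g_opt, source, target)
print(w2_sq)  # approximately |shift|^2 = 5.0
```

With the exact potentials every per-sample term equals \(\tfrac{1}{2}\|\text{shift}\|^2\), so the estimate recovers \(W_2^2 = \|\text{shift}\|^2\) up to floating-point error.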
input_dim (int) – input dimensionality of the data, required for network initialization
num_train_iters (int) – total number of training iterations
num_inner_iters (int) – number of training iterations of g per iteration of f
valid_freq (int) – frequency with which the model is validated
log_freq (int) – frequency with which training and validation progress is logged
logging (bool) – whether to return the logs
seed (int) – random seed for network initialization
pos_weights (bool) – whether to train the networks with positive weights (True) or with a regularizer instead (False)
beta (float) – regularization parameter, used when not training with positive weights
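To make pos_weights and beta more concrete, here is a hypothetical two-layer input convex network in plain JAX. The output is convex in x when the weights acting on the hidden state are non-negative and the activations are convex and non-decreasing. With `pos_weights=True` this sketch enforces non-negativity by a softplus reparameterization; with `pos_weights=False` the raw weights are left free and a penalty `beta * sum(relu(-W1) ** 2)` discourages negative entries. The layer layout, the softplus choice, and the exact penalty are assumptions and may differ from what the library does.

```python
import jax
import jax.numpy as jnp

def init_params(key, dim, hidden=16):
  """Hypothetical parameters of a two-layer ICNN g: R^dim -> R."""
  k0, k1, k2 = jax.random.split(key, 3)
  return {
      "A0": jax.random.normal(k0, (hidden, dim)),  # input -> hidden, unconstrained
      "A1": jax.random.normal(k1, (1, dim)),       # input -> output skip, unconstrained
      "W1": jax.random.normal(k2, (1, hidden)),    # hidden -> output, must be >= 0
  }

def icnn(params, x, pos_weights=True):
  # Convex, non-decreasing activation applied to an affine map of x.
  z = jax.nn.softplus(params["A0"] @ x)
  # With pos_weights=True, softplus maps the raw weights to positive values,
  # so W1 @ z is a non-negative combination of convex functions of x.
  W1 = jax.nn.softplus(params["W1"]) if pos_weights else params["W1"]
  # Adding an affine term in x preserves convexity.
  return (W1 @ z + params["A1"] @ x)[0]

def negative_weight_penalty(params, beta=1.0):
  # Intended for pos_weights=False: penalize negative entries of W1 instead of
  # constraining them, so convexity is encouraged but not guaranteed.
  return beta * jnp.sum(jax.nn.relu(-params["W1"]) ** 2)

params = init_params(jax.random.PRNGKey(0), dim=2)
x = jnp.array([0.5, -1.0])
print(icnn(params, x), negative_weight_penalty(params, beta=1.0))
```

The constrained variant guarantees convexity by construction, while the penalized variant trades that guarantee for unconstrained optimization, with beta controlling how strongly negative weights are discouraged.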
Create a one-step training and evaluation function.
setup(rng, neural_f, neural_g, input_dim, ...)
Set up all components required to train the networks.
Return the Kantorovich dual potentials from the trained networks.
Implement the training and validation loop.
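A minimal sketch of the alternating scheme described by the parameters: num_inner_iters gradient steps on g for every step on f, repeated num_train_iters times, with logging every log_freq iterations. To stay self-contained it swaps the ICNNs for a toy pair of quadratic potentials with learnable linear terms, uses plain gradient descent instead of optimizer_f/optimizer_g, and solves a pure translation problem; all names here are hypothetical, and validation (valid_freq) is not modeled.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: the target is the source shifted by a constant vector.
shift = jnp.array([1.0, -2.0])
source = jax.random.normal(jax.random.PRNGKey(0), (512, 2))
target = source + shift

def g(params, x):  # grad_x g = x + a
  return 0.5 * x @ x + params["a"] @ x

def f(params, y):  # grad_y f = y + c
  return 0.5 * y @ y + params["c"] @ y

def losses(params_f, params_g):
  grad_g_s = jax.vmap(jax.grad(g, argnums=1), in_axes=(None, 0))(params_g, source)
  f_grad_g_s = jax.vmap(f, in_axes=(None, 0))(params_f, grad_g_s)
  f_t = jax.vmap(f, in_axes=(None, 0))(params_f, target)
  loss_g = jnp.mean(f_grad_g_s - jnp.sum(source * grad_g_s, axis=1))
  loss_f = jnp.mean(f_t - f_grad_g_s)
  return loss_f, loss_g

@jax.jit
def step_g(params_f, params_g, lr):
  # One gradient-descent step on loss_g with f held fixed.
  grads = jax.grad(lambda p: losses(params_f, p)[1])(params_g)
  return jax.tree_util.tree_map(lambda p, d: p - lr * d, params_g, grads)

@jax.jit
def step_f(params_f, params_g, lr):
  # One gradient-descent step on loss_f with g held fixed.
  grads = jax.grad(lambda p: losses(p, params_g)[0])(params_f)
  return jax.tree_util.tree_map(lambda p, d: p - lr * d, params_f, grads)

def train(num_train_iters=100, num_inner_iters=10, log_freq=50, lr=0.5):
  params_f = {"c": jnp.zeros(2)}
  params_g = {"a": jnp.zeros(2)}
  for step in range(num_train_iters):
    for _ in range(num_inner_iters):  # several g updates ...
      params_g = step_g(params_f, params_g, lr)
    params_f = step_f(params_f, params_g, lr)  # ... per f update
    if step % log_freq == 0:
      print(step, losses(params_f, params_g))
  return params_f, params_g

params_f, params_g = train()
print(params_g["a"])  # approaches shift, so grad g(x) = x + a recovers the translation
```

Running several g updates per f update lets g approximately solve its inner problem before f moves, which is the role num_inner_iters plays in the min-max objective.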