ott.core.neuraldual.NeuralDualSolver
- class ott.core.neuraldual.NeuralDualSolver(input_dim, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=100, num_inner_iters=10, valid_freq=100, log_freq=100, logging=False, seed=0, pos_weights=True, beta=1.0)
Solver of the ICNN-based Kantorovich dual, in which the dual potentials f and g are parameterized by input convex neural networks (ICNNs).
The algorithm is described in: Optimal transport mapping via input convex neural networks, Makkuva, Taghvaei, Lee and Oh, ICML 2020. http://proceedings.mlr.press/v119/makkuva20a/makkuva20a.pdf
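For reference, the block below gives our transcription of the minimax problem from the cited paper. It assumes the quadratic cost c(x, y) = ½‖x − y‖², with P the source and Q the target distribution. In the paper's convention, the inner infimum over g is attained when ∇g equals ∇f*, the gradient of the convex conjugate of f. This solver may assign the roles of f and g (and therefore the direction of the transport map) differently, so treat this as the paper's formulation rather than a description of the code.

```latex
\frac{1}{2} W_2^2(P, Q)
  = \sup_{f \in \mathrm{CVX}(P)} \; \inf_{g \in \mathrm{CVX}(Q)} \; \mathcal{V}_{P,Q}(f, g) + C_{P,Q},
\qquad
\mathcal{V}_{P,Q}(f, g)
  = -\,\mathbb{E}_{P}\big[f(X)\big]
    - \mathbb{E}_{Q}\big[\langle Y, \nabla g(Y) \rangle - f(\nabla g(Y))\big],
\qquad
C_{P,Q} = \tfrac{1}{2}\,\mathbb{E}_{P}\|X\|^2 + \tfrac{1}{2}\,\mathbb{E}_{Q}\|Y\|^2 .
```

Both f and g range over convex functions, which is why they are parameterized as ICNNs. The nested sup/inf structure is also what motivates taking several steps on g for every step on f.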
- Parameters
  input_dim (int) – input dimensionality of the data, required to initialize the networks
  neural_f (Optional[Module]) – network architecture for potential f
  neural_g (Optional[Module]) – network architecture for potential g
  optimizer_f (Optional[GradientTransformation]) – optimizer for potential f
  optimizer_g (Optional[GradientTransformation]) – optimizer for potential g
  num_train_iters (int) – total number of training iterations
  num_inner_iters (int) – number of training iterations of g per iteration of f
  valid_freq (int) – frequency with which the model is validated
  log_freq (int) – frequency with which training and validation are logged
  logging (bool) – option to return logs
  seed (int) – random seed for network initializations
  pos_weights (bool) – option to train the networks with positive weights (True) or with a weight regularizer instead (False)
  beta (float) – regularization parameter, used when not training with positive weights
- Returns
the NeuralDual containing the optimal dual potentials f and g
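One plausible reading of num_train_iters, num_inner_iters, valid_freq and log_freq is an alternating schedule. Under that reading, each outer iteration makes num_inner_iters updates of g followed by one update of f, validates every valid_freq iterations and logs every log_freq iterations. The plain-Python sketch below only illustrates that reading. The function run_schedule and the callbacks step_g, step_f and validate are hypothetical stand-ins, not part of the ott API, and the exact ordering inside the solver may differ.

```python
from typing import Callable, Dict, List, Tuple


def run_schedule(
    num_train_iters: int,
    num_inner_iters: int,
    valid_freq: int,
    log_freq: int,
    step_g: Callable[[int], float],
    step_f: Callable[[int], float],
    validate: Callable[[int], float],
) -> Dict[str, List[Tuple[int, float]]]:
    """Hypothetical alternating min-max loop mirroring the solver's parameters."""
    logs: Dict[str, List[Tuple[int, float]]] = {"train_f": [], "valid": []}
    for it in range(num_train_iters):
        # Inner loop: several updates of g while f is held fixed.
        for _ in range(num_inner_iters):
            step_g(it)
        # Outer update: a single step of f against the current g.
        loss_f = step_f(it)
        if it % log_freq == 0:
            logs["train_f"].append((it, loss_f))
        if it % valid_freq == 0:
            logs["valid"].append((it, validate(it)))
    return logs


# Usage: count how often each (dummy) step function is called.
counts = {"g": 0, "f": 0}


def step_g(it: int) -> float:
    counts["g"] += 1
    return 0.0


def step_f(it: int) -> float:
    counts["f"] += 1
    return float(it)


logs = run_schedule(100, 10, valid_freq=50, log_freq=25,
                    step_g=step_g, step_f=step_f, validate=lambda it: 0.0)
print(counts)                             # {'g': 1000, 'f': 100}
print([it for it, _ in logs["train_f"]])  # [0, 25, 50, 75]
```

With the default num_train_iters=100 and num_inner_iters=10, this reading implies 1000 updates of g and 100 of f.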
Methods
  clip_weights_icnn(params)
  create_train_state(rng, model, optimizer, input) – Create the initial TrainState.
  get_step_fn(train[, to_optimize]) – Create a one-step training and evaluation function.
  penalize_weights_icnn(params)
  setup(rng, neural_f, neural_g, input_dim, ...) – Set up all components required to train the NeuralDual.
  train_neuraldual(trainloader_source, ...) – Implementation of the training and validation script.
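The pos_weights and beta parameters, together with clip_weights_icnn and penalize_weights_icnn, correspond to two ways of keeping a potential convex. An ICNN is convex in its input when the weights acting on its hidden activations are non-negative and its activations are convex and non-decreasing. The NumPy sketch below illustrates both options on a toy two-hidden-layer ICNN and then checks midpoint convexity numerically.

The following are assumptions made for illustration and are not taken from ott: the dict-of-arrays parameter layout, the key names (W_x0, W_z1, W_z2, ...), and the quadratic penalty on negative entries. The solver itself works on JAX/Flax parameter trees.

```python
import numpy as np


# Hypothetical layout: keys starting with "W_z" multiply hidden activations
# and must be non-negative for the network to be convex in x.
def init_params(rng: np.random.Generator, dim: int, hidden: int) -> dict:
    return {
        "W_x0": rng.normal(size=(hidden, dim)), "b0": rng.normal(size=hidden),
        "W_z1": rng.normal(size=(hidden, hidden)), "W_x1": rng.normal(size=(hidden, dim)),
        "b1": rng.normal(size=hidden),
        "W_z2": rng.normal(size=hidden), "W_x2": rng.normal(size=dim), "b2": 0.0,
    }


def softplus(u):
    return np.logaddexp(0.0, u)  # log(1 + e^u): convex and non-decreasing


def icnn(params: dict, x: np.ndarray) -> float:
    z1 = softplus(params["W_x0"] @ x + params["b0"])
    z2 = softplus(params["W_z1"] @ z1 + params["W_x1"] @ x + params["b1"])
    return params["W_z2"] @ z2 + params["W_x2"] @ x + params["b2"]


def clip_weights(params: dict) -> dict:
    # pos_weights=True style: project the constrained kernels onto [0, inf).
    return {k: (np.clip(v, 0.0, None) if k.startswith("W_z") else v)
            for k, v in params.items()}


def weight_penalty(params: dict, beta: float = 1.0) -> float:
    # pos_weights=False style: penalize negative entries; added to the loss.
    return beta * sum(np.sum(np.maximum(-v, 0.0) ** 2)
                      for k, v in params.items() if k.startswith("W_z"))


rng = np.random.default_rng(0)
params = init_params(rng, dim=2, hidden=8)
clipped = clip_weights(params)
print(weight_penalty(params) > 0.0)  # True: the random init has negative W_z entries
print(weight_penalty(clipped))       # 0.0

# Midpoint-convexity check on random pairs of points for the clipped network.
x, y = rng.normal(size=(500, 2)), rng.normal(size=(500, 2))
ok = all(icnn(clipped, 0.5 * (a + b)) <= 0.5 * (icnn(clipped, a) + icnn(clipped, b)) + 1e-8
         for a, b in zip(x, y))
print(ok)  # True
```

The two options trade off differently. Clipping is a projection applied after each update, so every iterate is exactly convex. The penalty leaves the parameters unconstrained and only discourages negative weights in proportion to beta, so convexity holds only approximately.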