ott.solvers.nn.neuraldual.NeuralDualSolver
- class ott.solvers.nn.neuraldual.NeuralDualSolver(input_dim, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=100, num_inner_iters=10, valid_freq=100, log_freq=100, logging=False, seed=0, pos_weights=True, beta=1.0)
Solver of the ICNN-based Kantorovich dual.
Learns the optimal transport between two distributions, referred to as the source and the target. This is achieved by parameterizing the two Kantorovich potentials, g and f, with two input convex neural networks (ICNNs). Here, \(\nabla g\) transports source cells to target cells, and \(\nabla f\) transports target cells to source cells.
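As a minimal sketch of this parameterization, the following pure-JAX code builds a small input convex neural network for a potential g and obtains the map \(x \mapsto \nabla g(x)\) with `jax.grad`. It assumes nothing about the library's own ICNN module; the names `init_icnn`, `apply_icnn` and `transport` are hypothetical. Convexity in the input comes from keeping the hidden-to-hidden weights non-negative (here via a softplus reparameterization, one of several possible choices) and using convex, non-decreasing activations.

```python
import jax
import jax.numpy as jnp


def init_icnn(key, dim, hidden=(16, 16)):
    """Initialize a small fully connected ICNN with a scalar output."""
    sizes = list(hidden) + [1]
    keys = jax.random.split(key, 2 * len(sizes))
    params = {"W": [], "U": [], "b": []}
    prev = None
    for i, h in enumerate(sizes):
        # Input skip connections W_k x may have any sign.
        params["W"].append(0.1 * jax.random.normal(keys[2 * i], (h, dim)))
        params["b"].append(jnp.zeros(h))
        if prev is not None:
            # Raw hidden-to-hidden weights; softplus(-3 + noise) later makes them small and positive.
            params["U"].append(0.1 * jax.random.normal(keys[2 * i + 1], (h, prev)) - 3.0)
        prev = h
    return params


def apply_icnn(params, x):
    """Scalar potential g(x) for a single input x of shape (dim,); convex in x."""
    n_layers = len(params["W"])
    z = jax.nn.softplus(params["W"][0] @ x + params["b"][0])
    for k in range(1, n_layers):
        # softplus keeps U_k >= 0, so U_k z + W_k x + b_k is a non-negative
        # combination of convex functions plus an affine term, hence convex.
        U = jax.nn.softplus(params["U"][k - 1])
        pre = U @ z + params["W"][k] @ x + params["b"][k]
        # Convex, non-decreasing activation on hidden layers; linear output layer.
        z = jax.nn.softplus(pre) if k < n_layers - 1 else pre
    return z[0]


# Transport map x -> grad_x g(x), vectorized over a batch of samples.
transport = jax.vmap(jax.grad(apply_icnn, argnums=1), in_axes=(None, 0))

params_g = init_icnn(jax.random.PRNGKey(0), dim=2)
source = jax.random.normal(jax.random.PRNGKey(1), (5, 2))
mapped = transport(params_g, source)  # shape (5, 2): images of the source samples
```

Because g is convex, `transport` is the gradient of a convex function, which by Brenier's theorem is the form the optimal map takes for the squared Euclidean cost; training would adjust `params_g` so that the mapped samples are distributed like the target.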
The original algorithm is described in [Makkuva et al., 2020].
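For the squared Euclidean cost, the min-max problem of [Makkuva et al., 2020] can be sketched as follows, written in this page's convention where \(\nabla g\) maps the source \(\mu\) to the target \(\nu\). The scaling and sign conventions used inside the solver are not stated on this page, so treat this as the standard formulation rather than the exact training loss.

\[
V(f, g) = \mathbb{E}_{X \sim \mu}\big[\langle X, \nabla g(X) \rangle - f(\nabla g(X))\big] + \mathbb{E}_{Y \sim \nu}\big[f(Y)\big]
\]
\[
W_2^2(\mu, \nu) = \mathbb{E}_{X \sim \mu}\|X\|^2 + \mathbb{E}_{Y \sim \nu}\|Y\|^2 - 2 \min_{f \text{ convex}} \max_{g \text{ convex}} V(f, g)
\]

For fixed f, the inner maximum over g is attained when \(\nabla g(x) \in \arg\max_y \{\langle x, y \rangle - f(y)\}\), for instance when \(g = f^*\), the convex conjugate of f; at the optimum, \(\nabla g\) is the Brenier map from \(\mu\) to \(\nu\).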
- Parameters
input_dim (int) – input dimensionality of the data, required for network initialization
neural_f (Optional[Module]) – network architecture for potential f
neural_g (Optional[Module]) – network architecture for potential g
optimizer_f (Union[Array, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer for potential f
optimizer_g (Union[Array, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer for potential g
num_train_iters (int) – total number of training iterations
num_inner_iters (int) – number of training iterations of g per iteration of f
valid_freq (int) – frequency with which the model is validated
log_freq (int) – frequency with which training and validation are logged
logging (bool) – whether to return logs
seed (int) – random seed for network initializations
pos_weights (bool) – whether to train the networks with positive weights (True) or with a regularizer (False)
beta (float) – regularization parameter, used when not training with positive weights
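The `pos_weights` and `beta` options correspond to two generic ways of keeping an ICNN convex: constrain the hidden-to-hidden weights to be non-negative, or leave them free and add a penalty on their negative part, weighted by `beta`. The sketch below illustrates both ideas; the function names and the squared-ReLU form of the penalty are assumptions for illustration, not the solver's documented internals.

```python
import jax
import jax.numpy as jnp


def negative_weight_penalty(hidden_weights):
    """Soft constraint: sum of squared negative entries of the hidden-to-hidden weights."""
    return sum(jnp.sum(jax.nn.relu(-U) ** 2) for U in hidden_weights)


def project_nonnegative(hidden_weights):
    """Hard constraint: clip hidden-to-hidden weights to be >= 0, e.g. after each update."""
    return [jnp.maximum(U, 0.0) for U in hidden_weights]


def regularized_loss(base_loss, hidden_weights, pos_weights, beta=1.0):
    """Add the beta-weighted penalty only when weights are not constrained to be positive."""
    if pos_weights:
        # Weights are assumed to be kept non-negative by construction or projection.
        return base_loss
    return base_loss + beta * negative_weight_penalty(hidden_weights)


weights = [jnp.array([[1.0, -2.0], [-1.0, 0.5]])]
penalty = negative_weight_penalty(weights)  # (-2)^2 + (-1)^2 = 5
```

The hard constraint guarantees convexity at every step, while the penalty only encourages it, with `beta` trading off convexity against the transport objective.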
Methods
get_step_fn(train[, to_optimize]) – Create a one-step training and evaluation function.
setup(rng, neural_f, neural_g, input_dim, ...) – Set up all components required to train the network.
Return the Kantorovich dual potentials from the trained potentials.
train_neuraldual(trainloader_source, ...) – Implementation of the training and validation script.
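The `num_train_iters`, `num_inner_iters`, `valid_freq` and `log_freq` parameters, together with the step functions built by `get_step_fn`, suggest an alternating training schedule. A minimal sketch in plain Python, assuming step callbacks with the hypothetical signature `step(iteration) -> loss` (the solver's actual step functions take different arguments) and 1-based iteration counting; `alternating_schedule` is likewise a hypothetical name:

```python
from typing import Callable, Dict, List, Optional


def alternating_schedule(
    num_train_iters: int,
    num_inner_iters: int,
    valid_freq: int,
    log_freq: int,
    step_g: Callable[[int], float],
    step_f: Callable[[int], float],
    validate: Callable[[int], float],
) -> Dict[str, List]:
    """Run num_inner_iters updates of g for every single update of f."""
    logs: Dict[str, List] = {"train": [], "valid": []}
    for it in range(1, num_train_iters + 1):
        loss_g: Optional[float] = None
        # Inner loop: several updates of g for the current f.
        for _ in range(num_inner_iters):
            loss_g = step_g(it)
        # Outer step: one update of f.
        loss_f = step_f(it)
        if it % log_freq == 0:
            logs["train"].append((it, loss_f, loss_g))
        if it % valid_freq == 0:
            logs["valid"].append((it, validate(it)))
    return logs
```

Several g updates per f update keep g close to its best response to the current f before f changes again; the collected logs mirror what the `logging` option is described as returning.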