class ott.neural.methods.neuraldual.W2NeuralDual(dim_data, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=20000, num_inner_iters=1, back_and_forth=None, valid_freq=1000, log_freq=1000, logging=False, rng=None, pos_weights=True, beta=1.0, conjugate_solver=FenchelConjugateLBFGS(gtol=1e-05, max_iter=20, max_linesearch_iter=20, linesearch_type='backtracking', linesearch_init='increase', increase_factor=1.5), amortization_loss='regression', parallel_updates=True)

Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.

Learn the Wasserstein-2 optimal transport between two measures \(\alpha\) and \(\beta\) in \(n\)-dimensional Euclidean space, denoted source and target, respectively. This is achieved by parameterizing a Kantorovich potential \(f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}\) associated with the \(\alpha\) measure with an ICNN or a PotentialMLP, where \(\nabla f\) transports source to target samples. This potential is learned by optimizing the dual form associated with the negative inner product cost

\[\text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] - \mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],\]

where \(f^\star(y) := -\inf_{x\in\mathbb{R}^n} \big(f(x)-\langle x, y\rangle\big)\) is the convex conjugate. \(\nabla f^\star\) transports from target to source samples and provides the inverse optimal transport map from \(\beta\) to \(\alpha\). This solver estimates the conjugate \(f^\star\) with a neural approximation \(g\) that is fine-tuned with a FenchelConjugateSolver, a combination further described in [Amos, 2023].
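For context, this objective is, up to constant terms, the classical Kantorovich dual with the squared Euclidean cost: substituting \(f(x) = \tfrac{1}{2}\|x\|^2 - u(x)\) into that dual yields

\[\frac{1}{2} W_2^2(\alpha, \beta) = \frac{1}{2}\,\mathbb{E}_{x\sim\alpha}\big[\|x\|^2\big] + \frac{1}{2}\,\mathbb{E}_{y\sim\beta}\big[\|y\|^2\big] + \sup_{f \text{ convex}} \Big(-\mathbb{E}_{x\sim\alpha}[f(x)] - \mathbb{E}_{y\sim\beta}[f^\star(y)]\Big),\]

so the supremum recovers \(\frac{1}{2} W_2^2(\alpha, \beta)\) up to the (fixed) second moments of \(\alpha\) and \(\beta\), and Brenier's theorem identifies the optimal transport map with \(\nabla f\).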

The BasePotential potentials for neural_f and neural_g can

  1. both provide the values of the potentials \(f\) and \(g\), or

  2. have one of them provide the gradient mapping, e.g., \(\nabla f\) or \(\nabla g\), in which case the potential's value is obtained via the Fenchel conjugate, as discussed in [Amos, 2023].

Whether a network provides the potential's value or its gradient mapping is specified via is_potential.
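For instance, a minimal sketch of the two styles, assuming the ICNN and PotentialMLP classes from ott.neural.networks (the hidden sizes and data dimension here are illustrative):

    from ott.neural.networks import icnn, potentials

    # neural_f models the potential's value f(x) directly, so its output is a
    # scalar; neural_g with is_potential=False instead models the gradient
    # mapping, returning an n-dimensional vector approximating grad g(y),
    # whose potential value is recovered via the Fenchel conjugate.
    neural_f = icnn.ICNN(dim_data=2, dim_hidden=[64, 64, 64])
    neural_g = potentials.PotentialMLP(dim_hidden=[64, 64, 64], is_potential=False)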

Parameters:

  • dim_data (int) – input dimensionality of the data, required for network initialization

  • neural_f (Optional[BasePotential]) – network architecture for potential \(f\).

  • neural_g (Optional[BasePotential]) – network architecture for the conjugate potential \(g\approx f^\star\)

  • optimizer_f (Optional[GradientTransformation]) – optimizer function for potential \(f\)

  • optimizer_g (Optional[GradientTransformation]) – optimizer function for the conjugate potential \(g\)

  • num_train_iters (int) – number of total training iterations

  • num_inner_iters (int) – number of training iterations of \(g\) per iteration of \(f\)

  • back_and_forth (Optional[bool]) – alternate between updating the forward and backward directions, inspired by [Jacobs and Léger, 2020]

  • valid_freq (int) – frequency with which model is validated

  • log_freq (int) – frequency with which training and validation are logged

  • logging (bool) – option to return logs

  • rng (Optional[Array]) – random key used to seed network initializations

  • pos_weights (bool) – option to constrain the networks to positive weights; if False, negative weights are penalized via a regularizer weighted by beta

  • beta (float) – regularization parameter when not training with positive weights

  • conjugate_solver (Optional[FenchelConjugateSolver]) – numerical solver for the Fenchel conjugate.

  • amortization_loss (Literal['objective', 'regression']) – amortization loss for the conjugate \(g\approx f^\star\). Options are ‘objective’ [Makkuva et al., 2020] or ‘regression’ [Amos, 2023].

  • parallel_updates (bool) – update \(f\) and \(g\) simultaneously instead of alternating
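A minimal end-to-end sketch of constructing and training the solver (the loader names are hypothetical; each loader is assumed to be an iterator yielding batches of shape [batch_size, dim_data]):

    import jax
    import optax

    from ott.neural.methods.neuraldual import W2NeuralDual

    solver = W2NeuralDual(
        dim_data=2,
        neural_f=neural_f,  # e.g. the ICNN from the sketch above
        neural_g=neural_g,  # e.g. the gradient-mapping MLP from above
        optimizer_f=optax.adam(learning_rate=1e-4),
        optimizer_g=optax.adam(learning_rate=1e-4),
        num_train_iters=20_000,
        rng=jax.random.PRNGKey(0),
    )
    # Calling the solver runs training and returns the learned dual potentials.
    learned_potentials = solver(
        trainloader_source, trainloader_target,  # hypothetical loaders
        validloader_source, validloader_target,
    )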


get_step_fn(train, to_optimize)

Create a parallel training and evaluation function.

setup(rng, neural_f, neural_g, dim_data, ...)

Set up all components required to train the network.

to_dual_potentials([finetune_g])

Return the Kantorovich dual potentials from the trained potentials.

train_neuraldual_alternating(...[, callback])

Training and validation with alternating updates.

train_neuraldual_parallel(...[, callback])

Training and validation with parallel updates.
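Once trained, the solver can be converted into a DualPotentials object and used to transport samples; a sketch, assuming the transport method of ott.problems.linear.potentials.DualPotentials:

    import jax.numpy as jnp

    potentials = solver.to_dual_potentials()
    x = jnp.zeros((128, 2))                             # illustrative source batch
    y_hat = potentials.transport(x)                     # push source -> target
    x_hat = potentials.transport(y_hat, forward=False)  # inverse map, target -> source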