ott.neural.solvers.neuraldual.W2NeuralDual


class ott.neural.solvers.neuraldual.W2NeuralDual(dim_data, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, num_train_iters=20000, num_inner_iters=1, back_and_forth=None, valid_freq=1000, log_freq=1000, logging=False, rng=None, pos_weights=True, beta=1.0, conjugate_solver=FenchelConjugateLBFGS(gtol=1e-05, max_iter=20, max_linesearch_iter=20, linesearch_type='backtracking', linesearch_init='increase', increase_factor=1.5), amortization_loss='regression', parallel_updates=True)

Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.

Learn the Wasserstein-2 optimal transport between two measures \(\alpha\) and \(\beta\) in \(n\)-dimensional Euclidean space, denoted source and target, respectively. This is achieved by parameterizing a Kantorovich potential \(f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}\) associated with the \(\alpha\) measure with an ICNN or MLP, where \(\nabla f\) transports source to target cells. This potential is learned by optimizing the dual form associated with the negative inner product cost

\[\text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] - \mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],\]

where \(f^\star(y) := -\inf_{x\in\mathbb{R}^n} \left(f(x)-\langle x, y\rangle\right)\) is the convex conjugate. \(\nabla f^\star\) transports target cells to source cells and provides the inverse optimal transport map from \(\beta\) to \(\alpha\). This solver estimates the conjugate \(f^\star\) with a neural approximation \(g\) that is fine-tuned with FenchelConjugateSolver; this combination is described further in [Amos, 2023].
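To make the conjugate concrete, here is a minimal one-dimensional sketch in plain Python. It is not the library's implementation: the quadratic potential, the helper names `f`, `grad_f` and `conjugate`, and the use of plain gradient descent (rather than the L-BFGS solver listed in the signature) are all assumptions chosen for illustration. For \(f(x) = a x^2 / 2\), the closed form is \(f^\star(y) = y^2 / (2a)\), and the minimizer of \(x \mapsto f(x) - xy\) equals \(\nabla f^\star(y)\).

```python
# Toy 1-D illustration of the convex conjugate; hypothetical helpers,
# not part of OTT.

def f(x, a=2.0):
    """Strongly convex potential f(x) = a * x**2 / 2."""
    return 0.5 * a * x * x


def grad_f(x, a=2.0):
    """Gradient of f, which plays the role of the forward map."""
    return a * x


def conjugate(y, a=2.0, lr=0.1, num_steps=200):
    """Approximate f*(y) = -inf_x (f(x) - x * y) by gradient descent on x."""
    x = 0.0
    for _ in range(num_steps):
        # Gradient of the inner objective x -> f(x) - x * y.
        x -= lr * (grad_f(x, a) - y)
    # The minimizer x approximates grad f*(y); the value is x * y - f(x).
    return x * y - f(x, a), x


value, argmin = conjugate(3.0)        # close to 3**2 / (2 * 2) = 2.25 and 1.5
_, x_back = conjugate(grad_f(0.7))    # grad f* inverts grad f, so x_back ~ 0.7
```

The round trip in the last line mirrors the statement above that \(\nabla f^\star\) provides the inverse of the map \(\nabla f\).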

The BaseW2NeuralDual potentials for neural_f and neural_g can

  1. both provide the values of the potentials \(f\) and \(g\), or

  2. one of them can provide the gradient mapping, e.g., \(\nabla f\) or \(\nabla g\), in which case the potential’s value is obtained via the Fenchel conjugate as discussed in [Amos, 2023].

Whether a network provides the potential’s value or the gradient mapping is specified via is_potential.

Parameters:
  • dim_data (int) – input dimensionality of data required for network init

  • neural_f (Optional[BaseW2NeuralDual]) – network architecture for potential \(f\).

  • neural_g (Optional[BaseW2NeuralDual]) – network architecture for the conjugate potential \(g\approx f^\star\)

  • optimizer_f (Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer function for potential \(f\)

  • optimizer_g (Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer function for the conjugate potential \(g\)

  • num_train_iters (int) – number of total training iterations

  • num_inner_iters (int) – number of training iterations of \(g\) per iteration of \(f\)

  • back_and_forth (Optional[bool]) – alternate between updating the forward and backward directions. Inspired by [Jacobs and Léger, 2020]

  • valid_freq (int) – frequency with which the model is validated

  • log_freq (int) – frequency with which training and validation are logged

  • logging (bool) – option to return logs

  • rng (Optional[Array]) – random key used to seed network initialization

  • pos_weights (bool) – whether to train networks with positive weights or, when disabled, with a regularizer

  • beta (float) – regularization parameter when not training with positive weights

  • conjugate_solver (Optional[FenchelConjugateSolver]) – numerical solver for the Fenchel conjugate.

  • amortization_loss (Literal['objective', 'regression']) – amortization loss for the conjugate \(g\approx f^\star\). Options are ‘objective’ [Makkuva et al., 2020] or ‘regression’ [Amos, 2023].

  • parallel_updates (bool) – update \(f\) and \(g\) at the same time
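The two amortization_loss options can be contrasted on a toy one-dimensional problem. The sketch below is an illustration under stated assumptions, not the library's code: `f`, `refine`, `x_tilde` and `x_check` are hypothetical names, the potential is fixed to \(f(x) = x^2\), and gradient descent stands in for the conjugate solver. It follows the losses as described in the cited papers: the ‘objective’ loss evaluates \(f(\tilde{x}) - \langle \tilde{x}, y\rangle\) at the amortized prediction \(\tilde{x}\), while the ‘regression’ loss penalizes the squared distance between \(\tilde{x}\) and the solver's refined solution.

```python
# Toy comparison of the two amortization losses; hypothetical helpers,
# not part of OTT.

def f(x):
    """Fixed convex potential f(x) = x**2."""
    return x * x


def refine(x_init, y, lr=0.25, num_steps=50):
    """Stand-in conjugate solver: minimize x -> f(x) - x * y from x_init."""
    x = x_init
    for _ in range(num_steps):
        # The gradient of f(x) - x * y is 2 * x - y.
        x -= lr * (2.0 * x - y)
    return x


y = 2.0
x_tilde = 0.4                  # amortized prediction, e.g. from grad g(y)
x_check = refine(x_tilde, y)   # refined solution, converges to y / 2

# 'objective': value of the conjugate's inner objective at the prediction.
objective_loss = f(x_tilde) - x_tilde * y

# 'regression': squared distance to the refined solution, treated as a target.
regression_loss = (x_tilde - x_check) ** 2
```

Both losses are smallest when the prediction matches the refined solution; the regression variant needs the solver's output as a target, whereas the objective variant only needs to evaluate \(f\).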

Methods

get_step_fn(train, to_optimize)

Create a parallel training and evaluation function.

setup(rng, neural_f, neural_g, dim_data, ...)

Set up all components required to train the network.

to_dual_potentials([finetune_g])

Return the Kantorovich dual potentials from the trained potentials.

train_neuraldual_alternating(...[, callback])

Training and validation with alternating updates.

train_neuraldual_parallel(...[, callback])

Training and validation with parallel updates.