Source code for ott.solvers.nn.neuraldual

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import (
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import jax
import jax.numpy as jnp
import optax
from flax import core

from ott.geometry import costs
from ott.problems.linear import potentials
from ott.solvers.nn import conjugate_solvers, models

__all__ = ["W2NeuralDual"]

Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]
Callback_t = Callable[[int, potentials.DualPotentials], None]
Conj_t = Optional[conjugate_solvers.FenchelConjugateSolver]


class W2NeuralDual:
  r"""Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.

  Learn the Wasserstein-2 optimal transport between two measures
  :math:`\alpha` and :math:`\beta` in :math:`n`-dimensional Euclidean space,
  denoted source and target, respectively. This is achieved by parameterizing
  a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}`
  associated with the :math:`\alpha` measure with an
  :class:`~ott.solvers.nn.models.ICNN`, :class:`~ott.solvers.nn.models.MLP`,
  or other :class:`~ott.solvers.nn.models.ModelBase`, where :math:`\nabla f`
  transports source to target samples. This potential is learned by
  optimizing the dual form associated with the negative inner product cost

  .. math::

    \text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] -
    \mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],

  where
  :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} \left[f(x)-\langle x, y\rangle\right]`
  is the convex conjugate. :math:`\nabla f^\star` transports from the target
  to source samples and provides the inverse optimal transport map from
  :math:`\beta` to :math:`\alpha`. This solver estimates the conjugate
  :math:`f^\star` with a neural approximation :math:`g` that is fine-tuned
  with :class:`~ott.solvers.nn.conjugate_solvers.FenchelConjugateSolver`,
  a combination further described in :cite:`amos:23`.

  The :class:`~ott.solvers.nn.models.ModelBase` potentials for ``neural_f``
  and ``neural_g`` can

  1. both provide the values of the potentials :math:`f` and :math:`g`, or
  2. one of them can provide the gradient mapping, e.g., :math:`\nabla f` or
     :math:`\nabla g`, where the potential's value can be obtained via the
     Fenchel conjugate as discussed in :cite:`amos:23`.

  The potential's value or gradient mapping is specified via
  :attr:`~ott.solvers.nn.models.ModelBase.is_potential`. A minimal usage
  sketch is given after this class definition.

  Args:
    dim_data: input dimensionality of data required for network init
    neural_f: network architecture for potential :math:`f`.
    neural_g: network architecture for the conjugate potential
      :math:`g\approx f^\star`
    optimizer_f: optimizer function for potential :math:`f`
    optimizer_g: optimizer function for the conjugate potential :math:`g`
    num_train_iters: number of total training iterations
    num_inner_iters: number of training iterations of :math:`g` per iteration
      of :math:`f`
    back_and_forth: alternate between updating the forward and backward
      directions. Inspired from :cite:`jacobs:20`
    valid_freq: frequency with which the model is validated
    log_freq: frequency with which training and validation are logged
    logging: option to return logs
    rng: random key used for seeding network initializations
    pos_weights: option to train networks with positive weights or regularizer
    beta: regularization parameter when not training with positive weights
    conjugate_solver: numerical solver for the Fenchel conjugate.
    amortization_loss: amortization loss for the conjugate
      :math:`g\approx f^\star`. Options are `'objective'` :cite:`makkuva:20`
      or `'regression'` :cite:`amos:23`.
    parallel_updates: update :math:`f` and :math:`g` at the same time
  """

  def __init__(
      self,
      dim_data: int,
      neural_f: Optional[models.ModelBase] = None,
      neural_g: Optional[models.ModelBase] = None,
      optimizer_f: Optional[optax.OptState] = None,
      optimizer_g: Optional[optax.OptState] = None,
      num_train_iters: int = 20000,
      num_inner_iters: int = 1,
      back_and_forth: Optional[bool] = None,
      valid_freq: int = 1000,
      log_freq: int = 1000,
      logging: bool = False,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
      pos_weights: bool = True,
      beta: float = 1.0,
      conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER,
      amortization_loss: Literal["objective", "regression"] = "regression",
      parallel_updates: bool = True,
  ):
    self.num_train_iters = num_train_iters
    self.num_inner_iters = num_inner_iters
    self.back_and_forth = back_and_forth
    self.valid_freq = valid_freq
    self.log_freq = log_freq
    self.logging = logging
    self.pos_weights = pos_weights
    self.beta = beta
    self.parallel_updates = parallel_updates
    self.conjugate_solver = conjugate_solver
    self.amortization_loss = amortization_loss

    # set default optimizers
    if optimizer_f is None:
      optimizer_f = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)
    if optimizer_g is None:
      optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)

    # set default neural architectures
    if neural_f is None:
      neural_f = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
    if neural_g is None:
      neural_g = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
    self.neural_f = neural_f
    self.neural_g = neural_g

    # set optimizer and networks
    self.setup(
        rng,
        neural_f,
        neural_g,
        dim_data,
        optimizer_f,
        optimizer_g,
    )

  def setup(
      self,
      rng: jax.random.PRNGKeyArray,
      neural_f: models.ModelBase,
      neural_g: models.ModelBase,
      dim_data: int,
      optimizer_f: optax.OptState,
      optimizer_g: optax.OptState,
  ) -> None:
    """Setup all components required to train the network."""
    # split random number generator
    rng, rng_f, rng_g = jax.random.split(rng, 3)

    # check setting of network architectures
    warn_str = f"Setting of ICNN and the positive weights setting of the " \
        f"`W2NeuralDual` are not consistent. Proceeding with " \
        f"the `W2NeuralDual` setting, with positive weights " \
        f"being {self.pos_weights}."
    if isinstance(
        neural_f, models.ICNN
    ) and neural_f.pos_weights is not self.pos_weights:
      warnings.warn(warn_str, stacklevel=2)
      neural_f.pos_weights = self.pos_weights

    if isinstance(
        neural_g, models.ICNN
    ) and neural_g.pos_weights is not self.pos_weights:
      warnings.warn(warn_str, stacklevel=2)
      neural_g.pos_weights = self.pos_weights

    self.state_f = neural_f.create_train_state(
        rng_f,
        optimizer_f,
        dim_data,
    )
    self.state_g = neural_g.create_train_state(
        rng_g,
        optimizer_g,
        dim_data,
    )

    # default to using back_and_forth with the non-convex models
    if self.back_and_forth is None:
      self.back_and_forth = isinstance(neural_f, models.MLP)

    if self.num_inner_iters == 1 and self.parallel_updates:
      self.train_step_parallel = self.get_step_fn(
          train=True, to_optimize="both"
      )
      self.valid_step_parallel = self.get_step_fn(
          train=False, to_optimize="both"
      )
      self.train_fn = self.train_neuraldual_parallel
    else:
      if self.parallel_updates:
        warnings.warn(
            "parallel_updates set to True but disabling it "
            "because num_inner_iters>1", stacklevel=2
        )
      if self.back_and_forth:
        raise NotImplementedError(
            "back_and_forth not implemented without parallel updates"
        )
      self.train_step_f = self.get_step_fn(train=True, to_optimize="f")
      self.valid_step_f = self.get_step_fn(train=False, to_optimize="f")
      self.train_step_g = self.get_step_fn(train=True, to_optimize="g")
      self.valid_step_g = self.get_step_fn(train=False, to_optimize="g")
      self.train_fn = self.train_neuraldual_alternating

  def __call__(  # noqa: D102
      self,
      trainloader_source: Iterator[jnp.ndarray],
      trainloader_target: Iterator[jnp.ndarray],
      validloader_source: Iterator[jnp.ndarray],
      validloader_target: Iterator[jnp.ndarray],
      callback: Optional[Callback_t] = None,
  ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials,
                                              Train_t]]:
    logs = self.train_fn(
        trainloader_source,
        trainloader_target,
        validloader_source,
        validloader_target,
        callback=callback,
    )
    res = self.to_dual_potentials()
    return (res, logs) if self.logging else res

  def train_neuraldual_parallel(
      self,
      trainloader_source: Iterator[jnp.ndarray],
      trainloader_target: Iterator[jnp.ndarray],
      validloader_source: Iterator[jnp.ndarray],
      validloader_target: Iterator[jnp.ndarray],
      callback: Optional[Callback_t] = None,
  ) -> Train_t:
    """Training and validation with parallel updates."""
    try:
      from tqdm.auto import tqdm
    except ImportError:
      tqdm = lambda _: _

    # define dict to contain source and target batch
    train_batch, valid_batch = {}, {}

    # set logging dictionaries
    train_logs = {"loss_f": [], "loss_g": [], "w_dist": [], "directions": []}
    valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}

    for step in tqdm(range(self.num_train_iters)):
      update_forward = not self.back_and_forth or step % 2 == 0
      if update_forward:
        train_batch["source"] = jnp.asarray(next(trainloader_source))
        train_batch["target"] = jnp.asarray(next(trainloader_target))
        (self.state_f, self.state_g, loss, loss_f, loss_g,
         w_dist) = self.train_step_parallel(
             self.state_f,
             self.state_g,
             train_batch,
         )
      else:
        train_batch["target"] = jnp.asarray(next(trainloader_source))
        train_batch["source"] = jnp.asarray(next(trainloader_target))
        (self.state_g, self.state_f, loss, loss_f, loss_g,
         w_dist) = self.train_step_parallel(
             self.state_g,
             self.state_f,
             train_batch,
         )

      if self.logging and step % self.log_freq == 0:
        self._update_logs(train_logs, loss_f, loss_g, w_dist)
        train_logs["directions"].append(
            "forward" if update_forward else "backward"
        )

      if callback is not None:
        _ = callback(step, self.to_dual_potentials())

      if not self.pos_weights:
        # Only clip the weights of the f network
        self.state_f = self.state_f.replace(
            params=self._clip_weights_icnn(self.state_f.params)
        )

      # report the loss on a validation dataset periodically
      if step != 0 and step % self.valid_freq == 0:
        # get batch
        valid_batch["source"] = jnp.asarray(next(validloader_source))
        valid_batch["target"] = jnp.asarray(next(validloader_target))

        valid_loss_f, valid_loss_g, valid_w_dist = self.valid_step_parallel(
            self.state_f,
            self.state_g,
            valid_batch,
        )

        if self.logging:
          self._update_logs(
              valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
          )

    return {"train_logs": train_logs, "valid_logs": valid_logs}

  def train_neuraldual_alternating(
      self,
      trainloader_source: Iterator[jnp.ndarray],
      trainloader_target: Iterator[jnp.ndarray],
      validloader_source: Iterator[jnp.ndarray],
      validloader_target: Iterator[jnp.ndarray],
      callback: Optional[Callback_t] = None,
  ) -> Train_t:
    """Training and validation with alternating updates."""
    try:
      from tqdm.auto import tqdm
    except ImportError:
      tqdm = lambda _: _

    # define dict to contain source and target batch
    batch_g, batch_f, valid_batch = {}, {}, {}

    # set logging dictionaries
    train_logs = {"loss_f": [], "loss_g": [], "w_dist": []}
    valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}

    for step in tqdm(range(self.num_train_iters)):
      # execute training steps
      for _ in range(self.num_inner_iters):
        # get train batch for potential g
        batch_g["source"] = jnp.asarray(next(trainloader_source))
        batch_g["target"] = jnp.asarray(next(trainloader_target))

        self.state_g, loss_g, _ = self.train_step_g(
            self.state_f, self.state_g, batch_g
        )

      # get train batch for potential f
      batch_f["source"] = jnp.asarray(next(trainloader_source))
      batch_f["target"] = jnp.asarray(next(trainloader_target))

      self.state_f, loss_f, w_dist = self.train_step_f(
          self.state_f, self.state_g, batch_f
      )

      if not self.pos_weights:
        # Only clip the weights of the f network
        self.state_f = self.state_f.replace(
            params=self._clip_weights_icnn(self.state_f.params)
        )

      if callback is not None:
        callback(step, self.to_dual_potentials())

      if self.logging and step % self.log_freq == 0:
        self._update_logs(train_logs, loss_f, loss_g, w_dist)

      # report the loss on validation dataset periodically
      if step != 0 and step % self.valid_freq == 0:
        # get batch
        valid_batch["source"] = jnp.asarray(next(validloader_source))
        valid_batch["target"] = jnp.asarray(next(validloader_target))

        valid_loss_f, _ = self.valid_step_f(
            self.state_f, self.state_g, valid_batch
        )
        valid_loss_g, valid_w_dist = self.valid_step_g(
            self.state_f, self.state_g, valid_batch
        )

        if self.logging:
          self._update_logs(
              valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
          )

    return {"train_logs": train_logs, "valid_logs": valid_logs}

  def get_step_fn(
      self, train: bool, to_optimize: Literal["f", "g", "parallel", "both"]
  ):
    """Create a parallel training and evaluation function."""

    def loss_fn(params_f, params_g, f_value, g_value, g_gradient, batch):
      """Loss function for both potentials."""
      # get two distributions
      source, target = batch["source"], batch["target"]

      init_source_hat = g_gradient(params_g)(target)

      def g_value_partial(y: jnp.ndarray) -> jnp.ndarray:
        """Lazy way of evaluating g if f's computation needs it."""
        return g_value(params_g)(y)

      f_value_partial = f_value(params_f, g_value_partial)
      if self.conjugate_solver is not None:
        finetune_source_hat = lambda y, x_init: self.conjugate_solver.solve(
            f_value_partial, y, x_init=x_init
        ).grad
        finetune_source_hat = jax.vmap(finetune_source_hat)
        source_hat_detach = jax.lax.stop_gradient(
            finetune_source_hat(target, init_source_hat)
        )
      else:
        source_hat_detach = init_source_hat

      batch_dot = jax.vmap(jnp.dot)

      f_source = f_value_partial(source)
      f_star_target = batch_dot(source_hat_detach, target) - \
          f_value_partial(source_hat_detach)
      dual_source = f_source.mean()
      dual_target = f_star_target.mean()
      dual_loss = dual_source + dual_target

      if self.amortization_loss == "regression":
        amor_loss = ((init_source_hat - source_hat_detach) ** 2).mean()
      elif self.amortization_loss == "objective":
        f_value_parameters_detached = f_value(
            jax.lax.stop_gradient(params_f), g_value_partial
        )
        amor_loss = (
            f_value_parameters_detached(init_source_hat) -
            batch_dot(init_source_hat, target)
        ).mean()
      else:
        raise ValueError("Amortization loss has been misspecified.")

      if to_optimize == "both":
        loss = dual_loss + amor_loss
      elif to_optimize == "f":
        loss = dual_loss
      elif to_optimize == "g":
        loss = amor_loss
      else:
        raise ValueError(
            f"Optimization target {to_optimize} has been misspecified."
        )

      if self.pos_weights:
        # Penalize the weights of both networks, even though one
        # of them will be exactly clipped.
        # Having both here is necessary in case this is being called with
        # the potentials reversed with the back_and_forth.
        loss += self.beta * self._penalize_weights_icnn(params_f) + \
            self.beta * self._penalize_weights_icnn(params_g)

      # compute Wasserstein-2 distance
      C = jnp.mean(jnp.sum(source ** 2, axis=-1)) + \
          jnp.mean(jnp.sum(target ** 2, axis=-1))
      W2_dist = C - 2. * (f_source.mean() + f_star_target.mean())

      return loss, (dual_loss, amor_loss, W2_dist)

    @jax.jit
    def step_fn(state_f, state_g, batch):
      """Step function of either training or validation."""
      grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
      if train:
        # compute loss and gradients
        (loss, (loss_f, loss_g, W2_dist)), (grads_f, grads_g) = grad_fn(
            state_f.params,
            state_g.params,
            state_f.potential_value_fn,
            state_g.potential_value_fn,
            state_g.potential_gradient_fn,
            batch,
        )
        # update state
        if to_optimize == "both":
          return (
              state_f.apply_gradients(grads=grads_f),
              state_g.apply_gradients(grads=grads_g), loss, loss_f, loss_g,
              W2_dist
          )
        if to_optimize == "f":
          return state_f.apply_gradients(grads=grads_f), loss_f, W2_dist
        if to_optimize == "g":
          return state_g.apply_gradients(grads=grads_g), loss_g, W2_dist
        raise ValueError("Optimization target has been misspecified.")

      # compute loss and gradients
      (loss, (loss_f, loss_g, W2_dist)), _ = grad_fn(
          state_f.params,
          state_g.params,
          state_f.potential_value_fn,
          state_g.potential_value_fn,
          state_g.potential_gradient_fn,
          batch,
      )

      # do not update state
      if to_optimize == "both":
        return loss_f, loss_g, W2_dist
      if to_optimize == "f":
        return loss_f, W2_dist
      if to_optimize == "g":
        return loss_g, W2_dist
      raise ValueError("Optimization target has been misspecified.")

    return step_fn

  def to_dual_potentials(
      self, finetune_g: bool = True
  ) -> potentials.DualPotentials:
    r"""Return the Kantorovich dual potentials from the trained potentials.

    Args:
      finetune_g: Run the conjugate solver to fine-tune the prediction.

    Returns:
      A dual potential object.
    """
    f_value = self.state_f.potential_value_fn(self.state_f.params)
    g_value_prediction = self.state_g.potential_value_fn(
        self.state_g.params, f_value
    )

    def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray:
      x_hat = jax.grad(g_value_prediction)(y)
      grad_g_y = jax.lax.stop_gradient(
          self.conjugate_solver.solve(f_value, y, x_init=x_hat).grad
      )
      return -f_value(grad_g_y) + jnp.dot(grad_g_y, y)

    return potentials.DualPotentials(
        f=f_value,
        g=g_value_prediction if not finetune_g or
        self.conjugate_solver is None else g_value_finetuned,
        cost_fn=costs.SqEuclidean(),
        corr=True
    )

  @staticmethod
  def _clip_weights_icnn(params):
    params = params.unfreeze()
    for k in params:
      if k.startswith("w_z"):
        params[k]["kernel"] = jnp.clip(params[k]["kernel"], a_min=0)
    return core.freeze(params)

  @staticmethod
  def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float:
    penalty = 0.0
    for k, param in params.items():
      if k.startswith("w_z"):
        penalty += jnp.linalg.norm(jax.nn.relu(-param["kernel"]))
    return penalty

  @staticmethod
  def _update_logs(
      logs: Dict[str, List[Union[float, str]]],
      loss_f: jnp.ndarray,
      loss_g: jnp.ndarray,
      w_dist: jnp.ndarray,
  ) -> None:
    logs["loss_f"].append(float(loss_f))
    logs["loss_g"].append(float(loss_g))
    logs["w_dist"].append(float(w_dist))
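

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module above): it trains
# the default ICNN potentials on synthetic 2-D Gaussian data and applies the
# learned map through the returned DualPotentials. The `_make_sampler` helper,
# the shifted means, the batch size, and the short `num_train_iters` are
# assumptions chosen for demonstration, not recommended settings.
# ---------------------------------------------------------------------------


def _make_sampler(
    rng: jax.random.PRNGKeyArray,
    shift: float,
    batch_size: int = 128,
) -> Iterator[jnp.ndarray]:
  """Infinite iterator over batches of shifted 2-D Gaussian samples."""
  while True:
    rng, sample_rng = jax.random.split(rng)
    yield shift + jax.random.normal(sample_rng, (batch_size, 2))


def _example() -> None:
  rng_src, rng_tgt, rng_vsrc, rng_vtgt = jax.random.split(
      jax.random.PRNGKey(0), 4
  )
  solver = W2NeuralDual(dim_data=2, num_train_iters=500)
  learned_potentials = solver(
      _make_sampler(rng_src, shift=0.0),
      _make_sampler(rng_tgt, shift=5.0),
      _make_sampler(rng_vsrc, shift=0.0),
      _make_sampler(rng_vtgt, shift=5.0),
  )

  # Push source samples through the learned (approximate) Monge map grad f.
  x = jax.random.normal(jax.random.PRNGKey(1), (16, 2))
  y_hat = learned_potentials.transport(x)
  print(y_hat.shape)  # (16, 2)

  # Non-convex MLP potentials, with `neural_g` modeling the gradient mapping,
  # can be passed instead of the default ICNNs (see the class docstring):
  # solver = W2NeuralDual(
  #     dim_data=2,
  #     neural_f=models.MLP(dim_hidden=[64, 64, 64, 64]),
  #     neural_g=models.MLP(dim_hidden=[64, 64, 64, 64], is_potential=False),
  # )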