Source code for ott.neural.methods.neuraldual

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import (
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import jax
import jax.numpy as jnp

import optax

from ott import utils
from ott.geometry import costs
from ott.neural.networks import icnn, potentials
from ott.neural.networks.layers import conjugate
from ott.problems.linear import potentials as dual_potentials

__all__ = ["W2NeuralDual"]

# Logs returned by the training routines: ``"train_logs"`` and
# ``"valid_logs"`` each map a metric name to the list of logged values.
# Most entries are floats, but the parallel routine also records the update
# direction (``"forward"``/``"backward"``) under ``"directions"``, hence the
# ``Union[float, str]`` element type.
Train_t = Dict[
    Literal["train_logs", "valid_logs"], Dict[str, List[Union[float, str]]]
]
# Optional user callback, called as ``callback(step, dual_potentials)`` after
# every training step; its return value is ignored.
Callback_t = Callable[[int, dual_potentials.DualPotentials], None]


class W2NeuralDual:
    r"""Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.

    Learn the Wasserstein-2 optimal transport between two measures
    :math:`\alpha` and :math:`\beta` in :math:`n`-dimensional Euclidean space,
    denoted source and target, respectively. This is achieved by
    parameterizing a Kantorovich potential
    :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}` associated with the
    :math:`\alpha` measure with an :class:`~ott.neural.networks.icnn.ICNN` or
    a :class:`~ott.neural.networks.potentials.PotentialMLP`, where
    :math:`\nabla f` transports source to target cells. This potential is
    learned by optimizing the dual form associated with the negative inner
    product cost

    .. math::

        \text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] -
        \mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],

    where
    :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} \left(f(x)-\langle x, y\rangle\right)`
    is the convex conjugate. :math:`\nabla f^\star` transports from the target
    to source cells and provides the inverse optimal transport map from
    :math:`\beta` to :math:`\alpha`. This solver estimates the conjugate
    :math:`f^\star` with a neural approximation :math:`g` that is fine-tuned
    with :class:`~ott.neural.networks.layers.conjugate.FenchelConjugateSolver`,
    which is a combination further described in :cite:`amos:23`.

    The :class:`~ott.neural.networks.potentials.BasePotential` potentials for
    ``neural_f`` and ``neural_g`` can

    1. both provide the values of the potentials :math:`f` and :math:`g`, or
    2. one of them can provide the gradient mapping e.g., :math:`\nabla f` or
       :math:`\nabla g`, where the potential's value can be obtained via the
       Fenchel conjugate as discussed in :cite:`amos:23`.

    The potential's value or gradient mapping is specified via
    :attr:`~ott.neural.networks.potentials.BasePotential.is_potential`.

    Args:
      dim_data: input dimensionality of data required for network init.
      neural_f: network architecture for potential :math:`f`. If ``None``, an
        :class:`~ott.neural.networks.icnn.ICNN` with hidden layers
        ``[64, 64, 64, 64]`` is used.
      neural_g: network architecture for the conjugate potential
        :math:`g\approx f^\star`. Same default as ``neural_f``.
      optimizer_f: optimizer for potential :math:`f`. If ``None``, Adam with
        ``learning_rate=1e-4, b1=0.5, b2=0.9, eps=1e-8`` is used.
      optimizer_g: optimizer for the conjugate potential :math:`g`. Same
        default as ``optimizer_f``.
      num_train_iters: number of total training iterations.
      num_inner_iters: number of training iterations of :math:`g` per
        iteration of :math:`f`.
      back_and_forth: alternate between updating the forward and backward
        directions. Inspired from :cite:`jacobs:20`. If ``None``, it is
        enabled iff ``neural_f`` is a
        :class:`~ott.neural.networks.potentials.PotentialMLP`. Only supported
        with parallel updates.
      valid_freq: frequency (in iterations) with which the model is
        validated.
      log_freq: frequency (in iterations) with which training losses are
        logged; validation losses are logged every ``valid_freq`` iterations.
      logging: option to return logs.
      rng: random key used for seeding for network initializations.
      pos_weights: option to train networks with positive weights or
        regularizer. If ``False``, the ``w_z*`` kernels of :math:`f` are
        clipped to be non-negative after each step and negative ``w_z*``
        kernel entries of both networks are penalized in the loss.
      beta: regularization parameter when not training with positive weights.
      conjugate_solver: numerical solver for the Fenchel conjugate. If
        ``None``, the amortized prediction of :math:`g` is used without
        fine-tuning.
      amortization_loss: amortization loss for the conjugate
        :math:`g\approx f^\star`. Options are `'objective'` :cite:`makkuva:20`
        or `'regression'` :cite:`amos:23`.
      parallel_updates: Update :math:`f` and :math:`g` at the same time. Only
        effective when ``num_inner_iters == 1``; otherwise alternating updates
        are used.
    """

    def __init__(
        self,
        dim_data: int,
        neural_f: Optional[potentials.BasePotential] = None,
        neural_g: Optional[potentials.BasePotential] = None,
        optimizer_f: Optional[optax.GradientTransformation] = None,
        optimizer_g: Optional[optax.GradientTransformation] = None,
        num_train_iters: int = 20000,
        num_inner_iters: int = 1,
        back_and_forth: Optional[bool] = None,
        valid_freq: int = 1000,
        log_freq: int = 1000,
        logging: bool = False,
        rng: Optional[jax.Array] = None,
        pos_weights: bool = True,
        beta: float = 1.0,
        conjugate_solver: Optional[
            conjugate.FenchelConjugateSolver
        ] = conjugate.DEFAULT_CONJUGATE_SOLVER,
        amortization_loss: Literal["objective", "regression"] = "regression",
        parallel_updates: bool = True,
    ):
        self.num_train_iters = num_train_iters
        self.num_inner_iters = num_inner_iters
        self.back_and_forth = back_and_forth
        self.valid_freq = valid_freq
        self.log_freq = log_freq
        self.logging = logging
        self.pos_weights = pos_weights
        self.beta = beta
        self.parallel_updates = parallel_updates
        self.conjugate_solver = conjugate_solver
        self.amortization_loss = amortization_loss

        # set default optimizers
        if optimizer_f is None:
            optimizer_f = optax.adam(
                learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8
            )
        if optimizer_g is None:
            optimizer_g = optax.adam(
                learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8
            )

        # set default neural architectures
        if neural_f is None:
            neural_f = icnn.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
        if neural_g is None:
            neural_g = icnn.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
        self.neural_f = neural_f
        self.neural_g = neural_g

        # set optimizer and networks
        self.setup(
            utils.default_prng_key(rng),
            neural_f,
            neural_g,
            dim_data,
            optimizer_f,
            optimizer_g,
        )

    def setup(
        self,
        rng: jax.Array,
        neural_f: potentials.BasePotential,
        neural_g: potentials.BasePotential,
        dim_data: int,
        optimizer_f: optax.GradientTransformation,
        optimizer_g: optax.GradientTransformation,
    ) -> None:
        """Setup all components required to train the network.

        Reconciles the ICNNs' ``pos_weights`` with the solver's setting,
        creates the train states of both potentials and selects the step
        functions and training routine (parallel or alternating).

        Args:
          rng: random key used to initialize both networks.
          neural_f: network for the potential :math:`f`.
          neural_g: network for the conjugate potential :math:`g`.
          dim_data: input dimensionality used to initialize the networks.
          optimizer_f: optimizer for :math:`f`.
          optimizer_g: optimizer for :math:`g`.

        Raises:
          NotImplementedError: if ``back_and_forth`` is enabled while
            alternating (non-parallel) updates are used.
        """
        # split random number generator
        rng, rng_f, rng_g = jax.random.split(rng, 3)

        # check setting of network architectures
        warn_str = (
            f"Setting of ICNN and the positive weights setting of the "
            f"`W2NeuralDual` are not consistent. Proceeding with "
            f"the `W2NeuralDual` setting, with positive weights "
            f"being {self.pos_weights}."
        )
        # NOTE(review): ICNN is presumably a Flax ``nn.Module`` (it builds its
        # own train state); such modules reject attribute assignment after
        # construction, so a modified copy is created with ``clone`` instead
        # of setting ``pos_weights`` in place.
        if (
            isinstance(neural_f, icnn.ICNN)
            and neural_f.pos_weights != self.pos_weights
        ):
            warnings.warn(warn_str, stacklevel=2)
            neural_f = neural_f.clone(pos_weights=self.pos_weights)
        if (
            isinstance(neural_g, icnn.ICNN)
            and neural_g.pos_weights != self.pos_weights
        ):
            warnings.warn(warn_str, stacklevel=2)
            neural_g = neural_g.clone(pos_weights=self.pos_weights)
        # Keep the public attributes in sync with the networks actually used.
        self.neural_f = neural_f
        self.neural_g = neural_g

        self.state_f = neural_f.create_train_state(
            rng_f,
            optimizer_f,
            (1, dim_data),  # also include the batch dimension
        )
        self.state_g = neural_g.create_train_state(
            rng_g,
            optimizer_g,
            (1, dim_data),
        )

        # default to using back_and_forth with the non-convex models
        if self.back_and_forth is None:
            self.back_and_forth = isinstance(neural_f, potentials.PotentialMLP)

        if self.num_inner_iters == 1 and self.parallel_updates:
            self.train_step_parallel = self.get_step_fn(
                train=True, to_optimize="both"
            )
            self.valid_step_parallel = self.get_step_fn(
                train=False, to_optimize="both"
            )
            self.train_fn = self.train_neuraldual_parallel
        else:
            if self.parallel_updates:
                warnings.warn(
                    "parallel_updates set to True but disabling it "
                    "because num_inner_iters>1",
                    stacklevel=2,
                )
            if self.back_and_forth:
                raise NotImplementedError(
                    "back_and_forth not implemented without parallel updates"
                )
            self.train_step_f = self.get_step_fn(train=True, to_optimize="f")
            self.valid_step_f = self.get_step_fn(train=False, to_optimize="f")
            self.train_step_g = self.get_step_fn(train=True, to_optimize="g")
            self.valid_step_g = self.get_step_fn(train=False, to_optimize="g")
            self.train_fn = self.train_neuraldual_alternating

    def __call__(
        self,
        trainloader_source: Iterator[jnp.ndarray],
        trainloader_target: Iterator[jnp.ndarray],
        validloader_source: Iterator[jnp.ndarray],
        validloader_target: Iterator[jnp.ndarray],
        callback: Optional[Callback_t] = None,
    ) -> Union[
        dual_potentials.DualPotentials,
        Tuple[dual_potentials.DualPotentials, Train_t],
    ]:
        """Train the potentials and return the learned dual potentials.

        Args:
          trainloader_source: iterator over training batches of the source.
          trainloader_target: iterator over training batches of the target.
          validloader_source: iterator over validation batches of the source.
          validloader_target: iterator over validation batches of the target.
          callback: optional function called as ``callback(step, potentials)``
            after every training step.

        Returns:
          The dual potentials; if ``logging`` is enabled, a tuple of the dual
          potentials and the training/validation logs.
        """
        logs = self.train_fn(
            trainloader_source,
            trainloader_target,
            validloader_source,
            validloader_target,
            callback=callback,
        )
        res = self.to_dual_potentials()
        return (res, logs) if self.logging else res

    def train_neuraldual_parallel(
        self,
        trainloader_source: Iterator[jnp.ndarray],
        trainloader_target: Iterator[jnp.ndarray],
        validloader_source: Iterator[jnp.ndarray],
        validloader_target: Iterator[jnp.ndarray],
        callback: Optional[Callback_t] = None,
    ) -> Train_t:
        """Training and validation with parallel updates.

        Both potentials are updated in a single step per iteration. With
        ``back_and_forth``, odd iterations swap the source/target batches and
        the roles of the two potentials.

        Returns:
          Dictionary with ``"train_logs"`` and ``"valid_logs"``; values are
          only appended when ``logging`` is enabled.
        """
        try:
            from tqdm.auto import tqdm
        except ImportError:
            tqdm = lambda _: _

        # define dict to contain source and target batch
        train_batch, valid_batch = {}, {}

        # set logging dictionaries
        train_logs = {
            "loss_f": [], "loss_g": [], "w_dist": [], "directions": []
        }
        valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}

        for step in tqdm(range(self.num_train_iters)):
            update_forward = not self.back_and_forth or step % 2 == 0

            if update_forward:
                train_batch["source"] = jnp.asarray(next(trainloader_source))
                train_batch["target"] = jnp.asarray(next(trainloader_target))
                (
                    self.state_f, self.state_g, _, loss_f, loss_g, w_dist
                ) = self.train_step_parallel(
                    self.state_f,
                    self.state_g,
                    train_batch,
                )
            else:
                # Backward direction: swap measures and potentials.
                train_batch["target"] = jnp.asarray(next(trainloader_source))
                train_batch["source"] = jnp.asarray(next(trainloader_target))
                (
                    self.state_g, self.state_f, _, loss_f, loss_g, w_dist
                ) = self.train_step_parallel(
                    self.state_g,
                    self.state_f,
                    train_batch,
                )

            if self.logging and step % self.log_freq == 0:
                self._update_logs(train_logs, loss_f, loss_g, w_dist)
                train_logs["directions"].append(
                    "forward" if update_forward else "backward"
                )

            if callback is not None:
                _ = callback(step, self.to_dual_potentials())

            if not self.pos_weights:
                # Only clip the weights of the f network
                self.state_f = self.state_f.replace(
                    params=self._clip_weights_icnn(self.state_f.params)
                )

            # report the loss on an validation dataset periodically
            if step != 0 and step % self.valid_freq == 0:
                # get batch
                valid_batch["source"] = jnp.asarray(next(validloader_source))
                valid_batch["target"] = jnp.asarray(next(validloader_target))

                valid_loss_f, valid_loss_g, valid_w_dist = (
                    self.valid_step_parallel(
                        self.state_f,
                        self.state_g,
                        valid_batch,
                    )
                )

                if self.logging:
                    self._update_logs(
                        valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
                    )

        return {"train_logs": train_logs, "valid_logs": valid_logs}

    def train_neuraldual_alternating(
        self,
        trainloader_source: Iterator[jnp.ndarray],
        trainloader_target: Iterator[jnp.ndarray],
        validloader_source: Iterator[jnp.ndarray],
        validloader_target: Iterator[jnp.ndarray],
        callback: Optional[Callback_t] = None,
    ) -> Train_t:
        """Training and validation with alternating updates.

        Each iteration performs ``num_inner_iters`` updates of :math:`g`
        (amortization loss) followed by one update of :math:`f` (dual loss),
        each on freshly drawn batches.

        Returns:
          Dictionary with ``"train_logs"`` and ``"valid_logs"``; values are
          only appended when ``logging`` is enabled.
        """
        try:
            from tqdm.auto import tqdm
        except ImportError:
            tqdm = lambda _: _

        # define dict to contain source and target batch
        batch_g, batch_f, valid_batch = {}, {}, {}

        # set logging dictionaries
        train_logs = {"loss_f": [], "loss_g": [], "w_dist": []}
        valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}

        for step in tqdm(range(self.num_train_iters)):
            # execute training steps
            for _inner in range(self.num_inner_iters):
                # get train batch for potential g
                batch_g["source"] = jnp.asarray(next(trainloader_source))
                batch_g["target"] = jnp.asarray(next(trainloader_target))

                self.state_g, loss_g, _ = self.train_step_g(
                    self.state_f, self.state_g, batch_g
                )

            # get train batch for potential f
            batch_f["source"] = jnp.asarray(next(trainloader_source))
            batch_f["target"] = jnp.asarray(next(trainloader_target))

            self.state_f, loss_f, w_dist = self.train_step_f(
                self.state_f, self.state_g, batch_f
            )
            if not self.pos_weights:
                # Only clip the weights of the f network
                self.state_f = self.state_f.replace(
                    params=self._clip_weights_icnn(self.state_f.params)
                )

            if callback is not None:
                callback(step, self.to_dual_potentials())

            if self.logging and step % self.log_freq == 0:
                self._update_logs(train_logs, loss_f, loss_g, w_dist)

            # report the loss on validation dataset periodically
            if step != 0 and step % self.valid_freq == 0:
                # get batch
                valid_batch["source"] = jnp.asarray(next(validloader_source))
                valid_batch["target"] = jnp.asarray(next(validloader_target))

                valid_loss_f, _ = self.valid_step_f(
                    self.state_f, self.state_g, valid_batch
                )
                valid_loss_g, valid_w_dist = self.valid_step_g(
                    self.state_f, self.state_g, valid_batch
                )

                if self.logging:
                    self._update_logs(
                        valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
                    )

        return {"train_logs": train_logs, "valid_logs": valid_logs}

    def get_step_fn(
        self, train: bool, to_optimize: Literal["f", "g", "parallel", "both"]
    ):
        """Create a jitted training or evaluation step function.

        Args:
          train: if ``True``, the returned function applies gradient updates
            to the optimized potential(s); otherwise it only evaluates losses.
          to_optimize: which loss the step targets: ``"f"`` (dual loss),
            ``"g"`` (amortization loss) or ``"both"`` (their sum).
            ``"parallel"`` is not handled and raises :class:`ValueError` when
            the step function is first traced.

        Returns:
          A function ``step_fn(state_f, state_g, batch)``. In training mode it
          returns ``(state_f, state_g, loss, loss_f, loss_g, w_dist)`` for
          ``"both"`` and ``(state, loss, w_dist)`` for a single potential; in
          evaluation mode ``(loss_f, loss_g, w_dist)`` resp. ``(loss, w_dist)``.
        """

        def loss_fn(params_f, params_g, f_value, g_value, g_gradient, batch):
            """Loss function for both potentials."""
            # get two distributions
            source, target = batch["source"], batch["target"]

            # Amortized prediction (from g) of the conjugate's maximizer.
            init_source_hat = g_gradient(params_g)(target)

            def g_value_partial(y: jnp.ndarray) -> jnp.ndarray:
                """Lazy way of evaluating g if f's computation needs it."""
                return g_value(params_g)(y)

            f_value_partial = f_value(params_f, g_value_partial)
            if self.conjugate_solver is not None:
                # Fine-tune the prediction with the numerical conjugate
                # solver, warm-started at the amortized prediction.
                finetune_source_hat = jax.vmap(
                    lambda y, x_init: self.conjugate_solver.solve(
                        f_value_partial, y, x_init=x_init
                    ).grad
                )
                source_hat = finetune_source_hat(target, init_source_hat)
            else:
                source_hat = init_source_hat
            # Detach in both branches, so the dual loss does not
            # back-propagate into g through the conjugate points. Without it,
            # in the solver-free branch the dual-loss gradient w.r.t. g
            # exactly cancels the "objective" amortization gradient when both
            # losses are optimized jointly, and g never learns.
            source_hat_detach = jax.lax.stop_gradient(source_hat)

            batch_dot = jax.vmap(jnp.dot)

            f_source = f_value_partial(source)
            f_star_target = (
                batch_dot(source_hat_detach, target)
                - f_value_partial(source_hat_detach)
            )
            dual_source = f_source.mean()
            dual_target = f_star_target.mean()
            dual_loss = dual_source + dual_target

            if self.amortization_loss == "regression":
                amor_loss = ((init_source_hat - source_hat_detach) ** 2).mean()
            elif self.amortization_loss == "objective":
                f_value_parameters_detached = f_value(
                    jax.lax.stop_gradient(params_f), g_value_partial
                )
                amor_loss = (
                    f_value_parameters_detached(init_source_hat)
                    - batch_dot(init_source_hat, target)
                ).mean()
            else:
                raise ValueError("Amortization loss has been misspecified.")

            if to_optimize == "both":
                loss = dual_loss + amor_loss
            elif to_optimize == "f":
                loss = dual_loss
            elif to_optimize == "g":
                loss = amor_loss
            else:
                raise ValueError(
                    f"Optimization target {to_optimize} has been misspecified."
                )

            if not self.pos_weights:
                # Penalize the weights of both networks, even though one
                # of them will be exactly clipped.
                # Having both here is necessary in case this is being called
                # with the potentials reversed with the back_and_forth.
                loss += (
                    self.beta * self._penalize_weights_icnn(params_f)
                    + self.beta * self._penalize_weights_icnn(params_g)
                )

            # compute Wasserstein-2 distance
            sq_norms = (
                jnp.mean(jnp.sum(source ** 2, axis=-1))
                + jnp.mean(jnp.sum(target ** 2, axis=-1))
            )
            W2_dist = sq_norms - 2.0 * (f_source.mean() + f_star_target.mean())

            return loss, (dual_loss, amor_loss, W2_dist)

        @jax.jit
        def step_fn(state_f, state_g, batch):
            """Step function of either training or validation."""
            args = (
                state_f.params,
                state_g.params,
                state_f.potential_value_fn,
                state_g.potential_value_fn,
                state_g.potential_gradient_fn,
                batch,
            )

            if not train:
                # Evaluation only: no gradients are needed.
                _, (loss_f, loss_g, W2_dist) = loss_fn(*args)
                # do not update state
                if to_optimize == "both":
                    return loss_f, loss_g, W2_dist
                if to_optimize == "f":
                    return loss_f, W2_dist
                if to_optimize == "g":
                    return loss_g, W2_dist
                raise ValueError("Optimization target has been misspecified.")

            # compute loss and gradients
            grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
            (loss, (loss_f, loss_g, W2_dist)), (grads_f, grads_g) = grad_fn(
                *args
            )

            # update state
            if to_optimize == "both":
                return (
                    state_f.apply_gradients(grads=grads_f),
                    state_g.apply_gradients(grads=grads_g),
                    loss,
                    loss_f,
                    loss_g,
                    W2_dist,
                )
            if to_optimize == "f":
                return state_f.apply_gradients(grads=grads_f), loss_f, W2_dist
            if to_optimize == "g":
                return state_g.apply_gradients(grads=grads_g), loss_g, W2_dist
            raise ValueError("Optimization target has been misspecified.")

        return step_fn

    def to_dual_potentials(
        self, finetune_g: bool = True
    ) -> dual_potentials.DualPotentials:
        r"""Return the Kantorovich dual potentials from the trained potentials.

        Args:
          finetune_g: Run the conjugate solver to fine-tune the prediction.
            Has no effect when ``conjugate_solver`` is ``None``.

        Returns:
          A dual potential object
        """
        f_value = self.state_f.potential_value_fn(self.state_f.params)
        g_value_prediction = self.state_g.potential_value_fn(
            self.state_g.params, f_value
        )

        def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray:
            # Warm-start the conjugate solver at the gradient of g's
            # prediction and evaluate <x*, y> - f(x*) at its solution.
            x_hat = jax.grad(g_value_prediction)(y)
            grad_g_y = jax.lax.stop_gradient(
                self.conjugate_solver.solve(f_value, y, x_init=x_hat).grad
            )
            return -f_value(grad_g_y) + jnp.dot(grad_g_y, y)

        use_finetuned = finetune_g and self.conjugate_solver is not None
        return dual_potentials.DualPotentials(
            f=f_value,
            g=g_value_finetuned if use_finetuned else g_value_prediction,
            cost_fn=costs.SqEuclidean(),
            corr=True,
        )

    @staticmethod
    def _clip_weights_icnn(params):
        """Clip the ``kernel`` of every ``w_z*`` entry to be non-negative.

        ``params`` is modified in place and returned for convenience.
        """
        for k in params:
            if k.startswith("w_z"):
                # Same as ``jnp.clip(..., a_min=0)`` without the deprecated
                # ``a_min`` keyword.
                params[k]["kernel"] = jnp.maximum(params[k]["kernel"], 0.0)
        return params

    @staticmethod
    def _penalize_weights_icnn(
        params: Dict[str, Dict[str, jnp.ndarray]]
    ) -> Union[float, jnp.ndarray]:
        """Sum of the norms of the negative parts of the ``w_z*`` kernels."""
        penalty = 0.0
        for k, param in params.items():
            if k.startswith("w_z"):
                penalty += jnp.linalg.norm(jax.nn.relu(-param["kernel"]))
        return penalty

    @staticmethod
    def _update_logs(
        logs: Dict[str, List[Union[float, str]]],
        loss_f: jnp.ndarray,
        loss_g: jnp.ndarray,
        w_dist: jnp.ndarray,
    ) -> None:
        """Append the scalar losses and W2 estimate to ``logs`` as floats."""
        logs["loss_f"].append(float(loss_f))
        logs["loss_g"].append(float(loss_g))
        logs["w_dist"].append(float(w_dist))