Source code for ott.initializers.nn.initializers

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.core import frozen_dict
from flax.training import train_state

from ott.geometry import geometry
from ott.initializers.linear import initializers

if TYPE_CHECKING:
  from ott.problems.linear import linear_problem

# TODO(michalk8): add initializer for NeuralDual?
__all__ = ["MetaInitializer", "MetaMLP"]


@jax.tree_util.register_pytree_node_class
class MetaInitializer(initializers.DefaultInitializer):
  """Meta OT Initializer with a fixed geometry :cite:`amos:22`.

  This initializer consists of a predictive model that outputs the
  :math:`f` duals to solve the entropy-regularized OT problem given
  input probability weights ``a`` and ``b``, and a given (assumed to be
  fixed) geometry ``geom``.

  The model's parameters are learned using a training set of OT
  instances (multiple pairs of probability weights), that assume the
  **same** geometry ``geom`` is used throughout, both for training and
  evaluation. The meta model defaults to the MLP in
  :class:`~ott.initializers.nn.initializers.MetaMLP` and is trained with
  (possibly batched) problem instances passed into :meth:`update`.

  Args:
    geom: The fixed geometry of the problem instances.
    meta_model: The model to predict the potential :math:`f` from the measures.
      If ``None``, a :class:`~ott.initializers.nn.initializers.MetaMLP` whose
      output size is the first dimension of ``geom.shape`` is used.
    opt: The optimizer to update the parameters. If ``None``, use
      :func:`optax.adam` with :math:`0.001` learning rate.
    rng: The PRNG key to use for initializing the model.
    state: The training state of the model to start from. If ``None``, a
      fresh state is created from ``meta_model``, ``rng`` and ``opt``.

  Examples:
    The following code shows a simple
    example of using ``update`` to train the model, where
    ``a`` and ``b`` are the weights of the measures and
    ``geom`` is the fixed geometry.

    .. code-block:: python

      meta_initializer = init_lib.MetaInitializer(geom)
      while training():
        a, b = sample_batch()
        loss, init_f, meta_initializer.state = meta_initializer.update(
          meta_initializer.state, a=a, b=b
        )
  """

  def __init__(
      self,
      geom: geometry.Geometry,
      meta_model: Optional[nn.Module] = None,
      opt: Optional[optax.GradientTransformation
                   ] = optax.adam(learning_rate=1e-3),  # noqa: B008
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
      state: Optional[train_state.TrainState] = None
  ):
    # Honor the documented contract: ``None`` selects the default optimizer
    # instead of being forwarded to ``TrainState.create``.
    if opt is None:
      opt = optax.adam(learning_rate=1e-3)

    self.geom = geom
    # ``x`` is only present on point-cloud-like geometries; for any other
    # geometry take the dtype from the cost matrix instead.
    # NOTE(review): assumes ``cost_matrix`` is available on geometries that
    # have no ``x`` attribute -- confirm against the geometry API.
    points = getattr(geom, "x", None)
    self.dtype = geom.cost_matrix.dtype if points is None else points.dtype
    self.opt = opt
    self.rng = rng

    na, nb = geom.shape
    self.meta_model = MetaMLP(
        potential_size=na
    ) if meta_model is None else meta_model

    if state is None:
      # Initialize the model's training state from placeholder marginals
      # that only carry the right shapes and dtype.
      a_placeholder = jnp.zeros(na, dtype=self.dtype)
      b_placeholder = jnp.zeros(nb, dtype=self.dtype)
      params = self.meta_model.init(rng, a_placeholder,
                                    b_placeholder)["params"]
      self.state = train_state.TrainState.create(
          apply_fn=self.meta_model.apply, params=params, tx=opt
      )
    else:
      self.state = state

    # Build the jitted update step once; it closes over ``self``.
    self.update_impl = self._get_update_fn()

  def update(
      self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray
  ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]:
    r"""Update the meta model with the dual objective.

    The goal is for the model to match the optimal duals, i.e.,
    :math:`\hat f_\theta \approx f^\star`.
    This can be done by training the predictions of :math:`\hat f_\theta`
    to optimize the dual objective, which :math:`f^\star` also optimizes for.
    The overall learning setup can thus be written as:

    .. math::
      \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\;
        J(\hat f_\theta(a, b); \alpha, \beta),

    where :math:`a,b` are the probabilities of the measures
    :math:`\alpha,\beta`, :math:`\mathcal{D}` is a meta distribution of
    optimal transport problems,

    .. math::
      -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle -
        \varepsilon\left\langle \exp\{f/\varepsilon\},
        K\exp\{g/\varepsilon\} \right\rangle

    is the entropic dual objective, and
    :math:`K_{i,j} := \exp\{-C_{i,j}/\varepsilon\}` is the *Gibbs kernel*.

    Args:
      state: Optimizer state of the meta model.
      a: Probabilities of the :math:`\alpha` measure's atoms. A single
        instance is promoted to a batch of size one.
      b: Probabilities of the :math:`\beta` measure's atoms. A single
        instance is promoted to a batch of size one.

    Returns:
      The training loss, :math:`f`, and updated state.
    """
    return self.update_impl(state, a, b)

  def init_dual_a(
      self,
      ot_prob: "linear_problem.LinearProblem",
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
  ) -> jnp.ndarray:
    """Predict the initial dual variable with the meta model.

    Args:
      ot_prob: The OT problem; its marginals ``a`` and ``b`` may each be a
        single vector or a batch of vectors.
      lse_mode: If ``True``, return the predicted potential, otherwise
        convert it with ``ot_prob.geom.scaling_from_potential``.
      rng: Unused.

    Returns:
      The predicted potential (``lse_mode=True``) or the corresponding
      scaling (``lse_mode=False``).
    """
    del rng
    # Detect if the problem is batched.
    assert ot_prob.a.ndim in (1, 2), "`a` must be 1- or 2-dimensional."
    assert ot_prob.b.ndim in (1, 2), "`b` must be 1- or 2-dimensional."
    vmap_a_val = 0 if ot_prob.a.ndim == 2 else None
    vmap_b_val = 0 if ot_prob.b.ndim == 2 else None

    if vmap_a_val is not None or vmap_b_val is not None:
      # Map over whichever marginal is batched; parameters are shared.
      compute_f_maybe_batch = jax.vmap(
          self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None)
      )
    else:
      compute_f_maybe_batch = self._compute_f

    init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params)
    return init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f)

  def _get_update_fn(self):
    """Return the implementation (and jitted) update function."""
    # Imported lazily; at module level these are only imported for typing.
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    def dual_obj_loss_single(params, a, b):
      # Negative entropic dual objective for a single ``(a, b)`` pair.
      f_pred = self._compute_f(a, b, params)
      # Derive ``g`` from the predicted ``f`` with the geometry's update.
      g_pred = self.geom.update_potential(
          f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0
      )
      # ``log(b)`` is ``-inf`` for zero weights; zero out non-finite entries.
      g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.)

      ot_prob = linear_problem.LinearProblem(geom=self.geom, a=a, b=b)
      dual_obj = sinkhorn.ent_reg_cost(f_pred, g_pred, ot_prob, lse_mode=True)
      loss = -dual_obj
      return loss, f_pred

    def loss_batch(params, a, b):
      # Parameters are shared across the batch; only ``a``/``b`` are mapped.
      loss_fn = functools.partial(dual_obj_loss_single, params=params)
      loss, f_pred = jax.vmap(loss_fn)(a=a, b=b)
      return jnp.mean(loss), f_pred

    @jax.jit
    def update(state, a, b):
      # Promote a single problem instance to a batch of size one.
      a = jnp.atleast_2d(a)
      b = jnp.atleast_2d(b)
      grad_fn = jax.value_and_grad(loss_batch, has_aux=True)
      (loss, init_f), grads = grad_fn(state.params, a, b)
      return loss, init_f, state.apply_gradients(grads=grads)

    return update

  def _compute_f(
      self, a: jnp.ndarray, b: jnp.ndarray,
      params: frozen_dict.FrozenDict[str, jnp.ndarray]
  ) -> jnp.ndarray:
    r"""Predict the optimal :math:`f` potential.

    Args:
      a: Probabilities of the :math:`\alpha` measure's atoms.
      b: Probabilities of the :math:`\beta` measure's atoms.
      params: The parameters of the Meta model.

    Returns:
      The :math:`f` potential.
    """
    return self.meta_model.apply({"params": params}, a, b)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    """Flatten into ``(geom, meta_model, opt)`` and ``rng``/``state`` aux."""
    return [self.geom, self.meta_model, self.opt], {
        "rng": self.rng,
        "state": self.state
    }
class MetaMLP(nn.Module):
  r"""Potential for :class:`~ott.initializers.nn.initializers.MetaInitializer`.

  This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the
  probabilities of the measures to the optimal dual potentials :math:`f`.

  Args:
    potential_size: The dimensionality of :math:`f`.
    num_hidden_units: The number of hidden units in each layer.
    num_hidden_layers: The number of hidden layers.
  """

  potential_size: int
  num_hidden_units: int = 512
  num_hidden_layers: int = 3

  @nn.compact
  def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""Make a prediction.

    Args:
      a: Probabilities of the :math:`\alpha` measure's atoms.
      b: Probabilities of the :math:`\beta` measure's atoms.

    Returns:
      The :math:`f` potential.
    """
    dtype = a.dtype
    # Join along the feature (last) axis: identical to the default axis for
    # 1-D inputs, and keeps any leading batch dimensions intact otherwise.
    z = jnp.concatenate((a, b), axis=-1)
    for _ in range(self.num_hidden_layers):
      z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z))
    return nn.Dense(self.potential_size, dtype=dtype)(z)