# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import (
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from ott import utils
from ott.geometry import costs
from ott.neural.networks import icnn, potentials
from ott.neural.networks.layers import conjugate
from ott.problems.linear import potentials as dual_potentials
# Public API of this module.
__all__ = ["W2NeuralDual"]
# Training output: the "train_logs" and "valid_logs" dictionaries, each mapping
# a metric name to the list of values recorded for it.
Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]
# Callback signature: called with the step index and the current potentials.
Callback_t = Callable[[int, dual_potentials.DualPotentials], None]
# Short aliases for the function types declared in
# ``ott.neural.networks.potentials``.
PotentialValueFn_t = potentials.PotentialValueFn_t
PotentialGradientFn_t = potentials.PotentialGradientFn_t
def _value_fn(
    model: nnx.Module,
    other_value_fn: Optional[Callable] = None,
) -> PotentialValueFn_t:
  """Build a scalar-valued potential function from an NNX model.

  If ``model.is_potential`` is true, the model output itself is the value.
  Otherwise the model is treated as a gradient map and the value is
  reconstructed as ``<grad, x> - other_value_fn(grad)``, where ``grad`` is
  the model output with gradients stopped.

  Args:
    model: Callable module exposing an ``is_potential`` attribute.
    other_value_fn: Value function of the other potential. Required when
      ``model`` is not a potential.

  Returns:
    A function mapping an input array to the potential's value(s).
  """
  if not model.is_potential:
    assert other_value_fn is not None, (
        "The value of a gradient-based potential depends on the other potential."
    )

    def reconstructed_value(x: jnp.ndarray) -> jnp.ndarray:
      # A 1-D input is a single point: add a batch axis, drop it at the end.
      is_single_point = x.ndim == 1
      points = jnp.expand_dims(x, 0) if is_single_point else x
      grads = jax.lax.stop_gradient(model(points))
      neg_other = -other_value_fn(grads)
      out = neg_other + jax.vmap(jnp.dot)(grads, points)
      if is_single_point:
        return out.squeeze(0)
      return out

    return reconstructed_value

  def potential_value(x: jnp.ndarray) -> jnp.ndarray:
    return model(x)

  return potential_value
def _gradient_fn(model: nnx.Module) -> PotentialGradientFn_t:
  """Build the gradient map associated with an NNX model.

  Args:
    model: Callable module exposing an ``is_potential`` attribute.

  Returns:
    For a potential model, the batched gradient ``vmap(grad(model))``;
    otherwise a function that simply evaluates ``model``.
  """
  if not model.is_potential:

    def direct_gradient(x):
      return model(x)

    return direct_gradient

  def scalar_potential(x):
    return model(x)

  return jax.vmap(jax.grad(scalar_potential))
[docs]
class W2NeuralDual:
r"""Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.
Learn the Wasserstein-2 optimal transport between two measures
:math:`\alpha` and :math:`\beta` in :math:`n`-dimensional Euclidean space,
denoted source and target, respectively. This is achieved by parameterizing
a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}`
associated with the :math:`\alpha` measure with an
:class:`~ott.neural.networks.icnn.ICNN` or a
:class:`~ott.neural.networks.potentials.PotentialMLP`, where
:math:`\nabla f` transports source to target cells. This potential is learned
by optimizing the dual form associated with the negative inner product cost
.. math::
\text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] -
\mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],
where :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} f(x)-\langle x, y\rangle`
is the convex conjugate.
:math:`\nabla f^\star` transports from the target
to source cells and provides the inverse optimal
transport map from :math:`\beta` to :math:`\alpha`.
This solver estimates the conjugate :math:`f^\star`
with a neural approximation :math:`g` that is fine-tuned
with :class:`~ott.neural.networks.layers.conjugate.FenchelConjugateSolver`,
which is a combination further described in :cite:`amos:23`.
The potentials for ``neural_f`` and ``neural_g`` can
1. both provide the values of the potentials :math:`f` and :math:`g`, or
2. one of them can provide the gradient mapping e.g., :math:`\nabla f`
or :math:`\nabla g` where the potential's value can be obtained
via the Fenchel conjugate as discussed in :cite:`amos:23`.
The potential's value or gradient mapping is specified via the
``is_potential`` property of the model.
Args:
dim_data: input dimensionality of data required for network init
neural_f: NNX network for potential :math:`f`. Must expose an
``is_potential`` property.
neural_g: NNX network for the conjugate potential
:math:`g\approx f^\star`
optimizer_f: optimizer function for potential :math:`f`
optimizer_g: optimizer function for the conjugate potential :math:`g`
num_train_iters: number of total training iterations
num_inner_iters: number of training iterations of :math:`g` per iteration
of :math:`f`
back_and_forth: alternate between updating the forward and backward
directions. Inspired from :cite:`jacobs:20`
valid_freq: frequency with which model is validated
log_freq: frequency with which training and validation are logged
logging: option to return logs
rng: random key used for seeding for network initializations
conjugate_solver: numerical solver for the Fenchel conjugate.
amortization_loss: amortization loss for the conjugate
:math:`g\approx f^\star`. Options are `'objective'` :cite:`makkuva:20` or
`'regression'` :cite:`amos:23`.
parallel_updates: Update :math:`f` and :math:`g` at the same time
"""
def __init__(
    self,
    dim_data: int,
    neural_f: Optional[nnx.Module] = None,
    neural_g: Optional[nnx.Module] = None,
    optimizer_f: Optional[optax.GradientTransformation] = None,
    optimizer_g: Optional[optax.GradientTransformation] = None,
    num_train_iters: int = 20000,
    num_inner_iters: int = 1,
    back_and_forth: Optional[bool] = None,
    valid_freq: int = 1000,
    log_freq: int = 1000,
    logging: bool = False,
    rng: Optional[jax.Array] = None,
    conjugate_solver: Optional[
        conjugate.FenchelConjugateSolver
    ] = conjugate.DEFAULT_CONJUGATE_SOLVER,
    amortization_loss: Literal["objective", "regression"] = "regression",
    parallel_updates: bool = True,
):
  """Store the configuration, fill in defaults and call :meth:`setup`.

  The optimizers are optax gradient transformations (not optimizer states):
  they are handed to ``nnx.Optimizer`` in :meth:`setup`. Missing optimizers
  default to Adam and missing networks default to an
  :class:`~ott.neural.networks.icnn.ICNN` built for ``dim_data`` inputs.
  """
  # training configuration
  self.num_train_iters = num_train_iters
  self.num_inner_iters = num_inner_iters
  self.back_and_forth = back_and_forth
  self.parallel_updates = parallel_updates

  # logging / validation configuration
  self.valid_freq = valid_freq
  self.log_freq = log_freq
  self.logging = logging

  # conjugate handling
  self.conjugate_solver = conjugate_solver
  self.amortization_loss = amortization_loss

  # default optimizers: one independent Adam instance per potential
  if optimizer_f is None:
    optimizer_f = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)
  if optimizer_g is None:
    optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)

  # default architectures; the split is kept 3-way so that the keys used to
  # initialize the default networks do not change
  rng = utils.default_prng_key(rng)
  _, rng_init_f, rng_init_g = jax.random.split(rng, 3)
  if neural_f is None:
    neural_f = icnn.ICNN(
        input_dim=dim_data,
        dim_hidden=[64, 64, 64, 64],
        rngs=nnx.Rngs(rng_init_f),
    )
  if neural_g is None:
    neural_g = icnn.ICNN(
        input_dim=dim_data,
        dim_hidden=[64, 64, 64, 64],
        rngs=nnx.Rngs(rng_init_g),
    )
  self.neural_f = neural_f
  self.neural_g = neural_g

  # create the optimizers and the step functions
  self.setup(neural_f, neural_g, optimizer_f, optimizer_g)
[docs]
def setup(
    self,
    neural_f: nnx.Module,
    neural_g: nnx.Module,
    optimizer_f: optax.GradientTransformation,
    optimizer_g: optax.GradientTransformation,
) -> None:
  """Create the optimizers and select the training procedure.

  Args:
    neural_f: Network for the potential :math:`f`.
    neural_g: Network for the conjugate potential :math:`g`.
    optimizer_f: Optax gradient transformation used to update ``neural_f``.
    optimizer_g: Optax gradient transformation used to update ``neural_g``.

  Raises:
    NotImplementedError: If ``back_and_forth`` is enabled while the
      alternating (non-parallel) procedure is selected.
  """
  self.opt_f = nnx.Optimizer(neural_f, optimizer_f, wrt=nnx.Param)
  self.opt_g = nnx.Optimizer(neural_g, optimizer_g, wrt=nnx.Param)

  # when unspecified, back-and-forth is enabled only for PotentialMLP models
  if self.back_and_forth is None:
    self.back_and_forth = isinstance(neural_f, potentials.PotentialMLP)

  # parallel updates require exactly one inner iteration
  if self.parallel_updates and self.num_inner_iters == 1:
    self.train_step_parallel = self._get_parallel_step_fn(train=True)
    self.valid_step_parallel = self._get_parallel_step_fn(train=False)
    self.train_fn = self.train_neuraldual_parallel
    return

  if self.parallel_updates:
    warnings.warn(
        "parallel_updates set to True but disabling it "
        "because num_inner_iters>1",
        stacklevel=2
    )
  if self.back_and_forth:
    raise NotImplementedError(
        "back_and_forth not implemented without parallel updates"
    )

  self.train_step_f = self._get_alternating_step_fn(
      train=True, to_optimize="f"
  )
  self.valid_step_f = self._get_alternating_step_fn(
      train=False, to_optimize="f"
  )
  self.train_step_g = self._get_alternating_step_fn(
      train=True, to_optimize="g"
  )
  self.valid_step_g = self._get_alternating_step_fn(
      train=False, to_optimize="g"
  )
  self.train_fn = self.train_neuraldual_alternating
def __call__(
    self,
    trainloader_source: Iterator[jnp.ndarray],
    trainloader_target: Iterator[jnp.ndarray],
    validloader_source: Iterator[jnp.ndarray],
    validloader_target: Iterator[jnp.ndarray],
    callback: Optional[Callback_t] = None,
) -> Union[dual_potentials.DualPotentials,
           Tuple[dual_potentials.DualPotentials, Train_t]]:
  """Run the selected training procedure and return the dual potentials.

  Args:
    trainloader_source: Iterator over training batches of the source.
    trainloader_target: Iterator over training batches of the target.
    validloader_source: Iterator over validation batches of the source.
    validloader_target: Iterator over validation batches of the target.
    callback: Optional function forwarded to the training procedure.

  Returns:
    The dual potentials, together with the training logs when
    ``self.logging`` is set.
  """
  history = self.train_fn(
      trainloader_source,
      trainloader_target,
      validloader_source,
      validloader_target,
      callback=callback,
  )
  learned = self.to_dual_potentials()
  if self.logging:
    return learned, history
  return learned
# ---- loss computation (shared by both parallel and alternating) ---------
def _compute_losses(
    self,
    model_f: nnx.Module,
    model_g: nnx.Module,
    batch: Dict[str, jnp.ndarray],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Evaluate the dual loss, the amortization loss and the W2 estimate.

  Args:
    model_f: Model of the potential :math:`f`.
    model_g: Model of the conjugate potential :math:`g`.
    batch: Dictionary holding the ``"source"`` and ``"target"`` arrays.

  Returns:
    The tuple ``(dual_loss, amor_loss, W2_dist)``.

  Raises:
    ValueError: If ``self.amortization_loss`` is neither ``"regression"``
      nor ``"objective"``.
  """
  xs = batch["source"]
  ys = batch["target"]

  # amortized prediction of the conjugate's maximizer for each target point
  xs_pred = _gradient_fn(model_g)(ys)

  def g_value(y: jnp.ndarray) -> jnp.ndarray:
    return _value_fn(model_g)(y)

  f_value = _value_fn(model_f, g_value)

  if self.conjugate_solver is None:
    # no fine-tuning: use the prediction as is
    xs_star = xs_pred
  else:

    def refine(y, x_init):
      return self.conjugate_solver.solve(f_value, y, x_init=x_init).grad

    # fine-tune each prediction; no gradient flows through the solver
    xs_star = jax.lax.stop_gradient(jax.vmap(refine)(ys, xs_pred))

  rowwise_dot = jax.vmap(jnp.dot)

  f_at_source = f_value(xs)
  conj_at_target = rowwise_dot(xs_star, ys) - f_value(xs_star)
  dual_loss = f_at_source.mean() + conj_at_target.mean()

  mode = self.amortization_loss
  if mode == "regression":
    amor_loss = ((xs_pred - xs_star) ** 2).mean()
  elif mode == "objective":
    # block gradients w.r.t. the state of f, but not w.r.t. its inputs
    graphdef, state = nnx.split(model_f)
    frozen_f = nnx.merge(graphdef, jax.lax.stop_gradient(state))
    frozen_f_value = _value_fn(frozen_f, g_value)
    amor_loss = (
        frozen_f_value(xs_pred) - rowwise_dot(xs_pred, ys)
    ).mean()
  else:
    raise ValueError("Amortization loss has been misspecified.")

  # Wasserstein-2 estimate: mean squared norms minus twice the dual value
  mean_sq_norms = (
      jnp.mean(jnp.sum(xs ** 2, axis=-1)) +
      jnp.mean(jnp.sum(ys ** 2, axis=-1))
  )
  w2_dist = mean_sq_norms - 2.0 * (
      f_at_source.mean() + conj_at_target.mean()
  )
  return dual_loss, amor_loss, w2_dist
# ---- parallel step functions -------------------------------------------
def _get_parallel_step_fn(self, train: bool):
  """Create the jitted step function used with parallel updates.

  Args:
    train: Whether to build the training step or the validation step.

  Returns:
    If ``train`` is true, a function
    ``(model_f, model_g, opt_f, opt_g, batch) -> (loss, dual_loss,
    amor_loss, W2_dist)`` that updates both models in place. Otherwise a
    function ``(model_f, model_g, batch) -> (dual_loss, amor_loss,
    W2_dist)``.
  """
  if not train:

    @nnx.jit
    def valid_step(model_f, model_g, batch):
      dual_loss, amor_loss, w2_dist = self._compute_losses(
          model_f, model_g, batch
      )
      return dual_loss, amor_loss, w2_dist

    return valid_step

  diff_both = (nnx.DiffState(0, nnx.Param), nnx.DiffState(1, nnx.Param))

  @nnx.jit
  def train_step(model_f, model_g, opt_f, opt_g, batch):

    def loss_fn_both(model_f, model_g):
      dual_loss, amor_loss, w2_dist = self._compute_losses(
          model_f, model_g, batch
      )
      # The individual terms are returned as auxiliary outputs so that they
      # are logged without a second evaluation of the losses (and of the
      # conjugate solver), and so that ``loss == dual_loss + amor_loss``.
      return dual_loss + amor_loss, (dual_loss, amor_loss, w2_dist)

    # differentiate w.r.t. the parameters of both models
    (loss, (dual_loss, amor_loss, w2_dist)), (grads_f, grads_g) = (
        nnx.value_and_grad(loss_fn_both, argnums=diff_both, has_aux=True)(
            model_f, model_g
        )
    )
    opt_f.update(model_f, grads_f)
    opt_g.update(model_g, grads_g)
    # all returned values were evaluated before the optimizer update
    return loss, dual_loss, amor_loss, w2_dist

  return train_step
# ---- alternating step functions ----------------------------------------
def _get_alternating_step_fn(
    self, train: bool, to_optimize: Literal["f", "g"]
):
  """Create the jitted step function used with alternating updates.

  Args:
    train: Whether to build a training step or a validation step.
    to_optimize: ``"f"`` selects the dual loss (and, when training, updates
      ``model_f``); any other value selects the amortization loss (and,
      when training, updates ``model_g``).

  Returns:
    If ``train`` is true, a function
    ``(model_f, model_g, opt, batch) -> (loss, W2_dist)`` that updates the
    selected model in place. Otherwise a function
    ``(model_f, model_g, batch) -> (loss, W2_dist)``.
  """
  if not train:

    @nnx.jit
    def valid_step(model_f, model_g, batch):
      dual_loss, amor_loss, w2_dist = self._compute_losses(
          model_f, model_g, batch
      )
      if to_optimize == "f":
        return dual_loss, w2_dist
      return amor_loss, w2_dist

    return valid_step

  # In both training steps the W2 estimate is an auxiliary output of the
  # differentiated function, so the losses are evaluated once per step
  # (before the optimizer update) rather than recomputed for logging.
  if to_optimize == "f":
    diff_f = nnx.DiffState(0, nnx.Param)

    @nnx.jit
    def train_step_f(model_f, model_g, opt_f, batch):

      def loss_fn(model_f, model_g):
        dual_loss, _, w2_dist = self._compute_losses(model_f, model_g, batch)
        return dual_loss, w2_dist

      (dual_loss, w2_dist), grads = nnx.value_and_grad(
          loss_fn, argnums=diff_f, has_aux=True
      )(model_f, model_g)
      opt_f.update(model_f, grads)
      return dual_loss, w2_dist

    return train_step_f

  diff_g = nnx.DiffState(1, nnx.Param)

  @nnx.jit
  def train_step_g(model_f, model_g, opt_g, batch):

    def loss_fn(model_f, model_g):
      _, amor_loss, w2_dist = self._compute_losses(model_f, model_g, batch)
      return amor_loss, w2_dist

    (amor_loss, w2_dist), grads = nnx.value_and_grad(
        loss_fn, argnums=diff_g, has_aux=True
    )(model_f, model_g)
    opt_g.update(model_g, grads)
    return amor_loss, w2_dist

  return train_step_g
# ---- training loops ----------------------------------------------------
[docs]
def train_neuraldual_parallel(
    self,
    trainloader_source: Iterator[jnp.ndarray],
    trainloader_target: Iterator[jnp.ndarray],
    validloader_source: Iterator[jnp.ndarray],
    validloader_target: Iterator[jnp.ndarray],
    callback: Optional[Callback_t] = None,
) -> Train_t:
  """Train and validate, updating both potentials in every step.

  With ``back_and_forth`` enabled, odd steps swap the roles of the two
  models and of the two measures.

  Args:
    trainloader_source: Iterator over training batches of the source.
    trainloader_target: Iterator over training batches of the target.
    validloader_source: Iterator over validation batches of the source.
    validloader_target: Iterator over validation batches of the target.
    callback: Optional function called after every step with the step
      index and the current dual potentials.

  Returns:
    Dictionary with the ``"train_logs"`` and ``"valid_logs"``.
  """
  try:
    from tqdm.auto import tqdm as progress
  except ImportError:

    def progress(iterable):
      return iterable

  metric_names = ("loss_f", "loss_g", "w_dist")
  train_logs = {name: [] for name in metric_names}
  train_logs["directions"] = []
  valid_logs = {name: [] for name in metric_names}

  for step in progress(range(self.num_train_iters)):
    is_forward = (not self.back_and_forth) or step % 2 == 0

    # the source loader is always read first, then the target loader
    from_source = jnp.asarray(next(trainloader_source))
    from_target = jnp.asarray(next(trainloader_target))

    if is_forward:
      train_batch = {"source": from_source, "target": from_target}
      step_args = (self.neural_f, self.neural_g, self.opt_f, self.opt_g)
    else:
      # backward direction: swap both the measures and the models
      train_batch = {"source": from_target, "target": from_source}
      step_args = (self.neural_g, self.neural_f, self.opt_g, self.opt_f)

    _, loss_f, loss_g, w_dist = self.train_step_parallel(
        *step_args, train_batch
    )

    if self.logging and step % self.log_freq == 0:
      self._update_logs(train_logs, loss_f, loss_g, w_dist)
      direction = "forward" if is_forward else "backward"
      train_logs["directions"].append(direction)

    if callback is not None:
      callback(step, self.to_dual_potentials())

    # periodic evaluation on validation data (never at step 0)
    if step > 0 and step % self.valid_freq == 0:
      valid_batch = {
          "source": jnp.asarray(next(validloader_source)),
          "target": jnp.asarray(next(validloader_target)),
      }
      valid_metrics = self.valid_step_parallel(
          self.neural_f,
          self.neural_g,
          valid_batch,
      )
      valid_loss_f, valid_loss_g, valid_w_dist = valid_metrics
      if self.logging:
        self._update_logs(
            valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
        )

  return {"train_logs": train_logs, "valid_logs": valid_logs}
[docs]
def train_neuraldual_alternating(
    self,
    trainloader_source: Iterator[jnp.ndarray],
    trainloader_target: Iterator[jnp.ndarray],
    validloader_source: Iterator[jnp.ndarray],
    validloader_target: Iterator[jnp.ndarray],
    callback: Optional[Callback_t] = None,
) -> Train_t:
  """Train and validate, alternating updates of ``g`` and ``f``.

  Every outer step performs ``num_inner_iters`` updates of ``g`` followed
  by a single update of ``f``, each on freshly drawn batches.

  Args:
    trainloader_source: Iterator over training batches of the source.
    trainloader_target: Iterator over training batches of the target.
    validloader_source: Iterator over validation batches of the source.
    validloader_target: Iterator over validation batches of the target.
    callback: Optional function called after every outer step with the
      step index and the current dual potentials.

  Returns:
    Dictionary with the ``"train_logs"`` and ``"valid_logs"``.
  """
  try:
    from tqdm.auto import tqdm as progress
  except ImportError:

    def progress(iterable):
      return iterable

  metric_names = ("loss_f", "loss_g", "w_dist")
  train_logs = {name: [] for name in metric_names}
  valid_logs = {name: [] for name in metric_names}

  def draw_train_batch() -> Dict[str, jnp.ndarray]:
    # source is read before target
    return {
        "source": jnp.asarray(next(trainloader_source)),
        "target": jnp.asarray(next(trainloader_target)),
    }

  for step in progress(range(self.num_train_iters)):
    # inner updates of the conjugate potential g
    for _ in range(self.num_inner_iters):
      loss_g, _ = self.train_step_g(
          self.neural_f, self.neural_g, self.opt_g, draw_train_batch()
      )

    # single update of the potential f
    loss_f, w_dist = self.train_step_f(
        self.neural_f, self.neural_g, self.opt_f, draw_train_batch()
    )

    if callback is not None:
      callback(step, self.to_dual_potentials())

    if self.logging and step % self.log_freq == 0:
      self._update_logs(train_logs, loss_f, loss_g, w_dist)

    # periodic evaluation on validation data (never at step 0)
    if step > 0 and step % self.valid_freq == 0:
      valid_batch = {
          "source": jnp.asarray(next(validloader_source)),
          "target": jnp.asarray(next(validloader_target)),
      }
      valid_loss_f, _ = self.valid_step_f(
          self.neural_f, self.neural_g, valid_batch
      )
      valid_loss_g, valid_w_dist = self.valid_step_g(
          self.neural_f, self.neural_g, valid_batch
      )
      if self.logging:
        self._update_logs(
            valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
        )

  return {"train_logs": train_logs, "valid_logs": valid_logs}
[docs]
def to_dual_potentials(
    self, finetune_g: bool = True
) -> dual_potentials.DualPotentials:
  r"""Wrap the current networks into Kantorovich dual potentials.

  Args:
    finetune_g: Whether to refine the prediction of :math:`g` with the
      conjugate solver. Has no effect when ``conjugate_solver`` is ``None``.

  Returns:
    A dual potential object with the squared Euclidean cost.
  """
  f_value = _value_fn(self.neural_f)
  g_predicted = _value_fn(self.neural_g, f_value)

  def g_refined(y: jnp.ndarray) -> jnp.ndarray:
    # warm-start the solver at the gradient of the predicted g
    x_init = jax.grad(g_predicted)(y)
    solution = self.conjugate_solver.solve(f_value, y, x_init=x_init)
    x_star = jax.lax.stop_gradient(solution.grad)
    return -f_value(x_star) + jnp.dot(x_star, y)

  if finetune_g and self.conjugate_solver is not None:
    g_value = g_refined
  else:
    g_value = g_predicted

  # convert each potential ``h`` to ``0.5 * |x|^2 - h(x)``
  def dual_f(x: jnp.ndarray) -> jnp.ndarray:
    return 0.5 * jnp.sum(x ** 2) - f_value(x)

  def dual_g(x: jnp.ndarray) -> jnp.ndarray:
    return 0.5 * jnp.sum(x ** 2) - g_value(x)

  return dual_potentials.DualPotentials(
      f=dual_f,
      g=dual_g,
      cost_fn=costs.SqEuclidean(),
  )
@staticmethod
def _update_logs(
logs: Dict[str, List[Union[float, str]]],
loss_f: jnp.ndarray,
loss_g: jnp.ndarray,
w_dist: jnp.ndarray,
) -> None:
logs["loss_f"].append(float(loss_f))
logs["loss_g"].append(float(loss_g))
logs["w_dist"].append(float(w_dist))