# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import (
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
import jax
import jax.numpy as jnp
import optax
from flax import core
from ott.geometry import costs
from ott.problems.linear import potentials
from ott.solvers.nn import conjugate_solvers, models
__all__ = ["W2NeuralDual"]
Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]
Callback_t = Callable[[int, potentials.DualPotentials], None]
Conj_t = Optional[conjugate_solvers.FenchelConjugateSolver]
class W2NeuralDual:
r"""Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.
Learn the Wasserstein-2 optimal transport between two measures
:math:`\alpha` and :math:`\beta` in :math:`n`-dimensional Euclidean space,
denoted source and target, respectively. This is achieved by parameterizing
a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}`
associated with the :math:`\alpha` measure with an
:class:`~ott.solvers.nn.models.ICNN`, :class:`~ott.solvers.nn.models.MLP`,
or other :class:`~ott.solvers.nn.models.ModelBase`, where :math:`\nabla f`
transports source to target samples. This potential is learned by optimizing
the dual form associated with the negative inner product cost
.. math::
\text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] -
\mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],
where :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} \{f(x)-\langle x, y\rangle\}`
is the convex conjugate.
:math:`\nabla f^\star` transports from the target
to source samples and provides the inverse optimal
transport map from :math:`\beta` to :math:`\alpha`.
This solver estimates the conjugate :math:`f^\star`
with a neural approximation :math:`g` that is fine-tuned
with :class:`~ott.solvers.nn.conjugate_solvers.FenchelConjugateSolver`,
an approach further described in :cite:`amos:23`.
The :class:`~ott.solvers.nn.models.ModelBase` potentials for
``neural_f`` and ``neural_g`` can
1. both provide the values of the potentials :math:`f` and :math:`g`, or
2. one of them can provide the gradient mapping, e.g., :math:`\nabla f`
or :math:`\nabla g`, where the potential's value can be obtained
via the Fenchel conjugate, as discussed in :cite:`amos:23`.
Whether a model provides the potential's value or its gradient mapping
is specified via :attr:`~ott.solvers.nn.models.ModelBase.is_potential`.
Args:
dim_data: input dimensionality of data required for network init
neural_f: network architecture for potential :math:`f`.
neural_g: network architecture for the conjugate potential
:math:`g\approx f^\star`
optimizer_f: optimizer function for potential :math:`f`
optimizer_g: optimizer function for the conjugate potential :math:`g`
num_train_iters: number of total training iterations
num_inner_iters: number of training iterations of :math:`g` per iteration
of :math:`f`
back_and_forth: alternate between updating the forward and backward
directions. Inspired by :cite:`jacobs:20`
valid_freq: frequency with which model is validated
log_freq: frequency with which training and validation metrics are logged
logging: option to return logs
rng: random key used for seeding the network initializations
pos_weights: option to enforce positive weights in the ICNNs directly; when
``False``, the weights of :math:`f` are clipped and negative weights are
penalized with the ``beta`` regularizer
beta: regularization parameter used when not training with positive weights
conjugate_solver: numerical solver for the Fenchel conjugate.
amortization_loss: amortization loss for the conjugate
:math:`g\approx f^\star`. Options are `'objective'` :cite:`makkuva:20` or
`'regression'` :cite:`amos:23`.
parallel_updates: update :math:`f` and :math:`g` at the same time
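
A minimal usage sketch follows; the MLP architectures, iteration count, and
the Gaussian ``sampler`` are illustrative assumptions, not package defaults.

.. code-block:: python

    import numpy as np
    import jax.numpy as jnp
    from ott.solvers.nn import models, neuraldual

    def sampler(seed: int, mean: float, batch_size: int = 128):
      # Infinite iterator of (batch_size, 2) Gaussian batches; illustrative only.
      rng = np.random.default_rng(seed)
      while True:
        yield jnp.asarray(rng.normal(loc=mean, size=(batch_size, 2)))

    # f is parameterized as a potential, g as a gradient mapping.
    neural_f = models.MLP(dim_hidden=[64, 64, 64], is_potential=True)
    neural_g = models.MLP(dim_hidden=[64, 64, 64], is_potential=False)
    solver = neuraldual.W2NeuralDual(
        dim_data=2,
        neural_f=neural_f,
        neural_g=neural_g,
        num_train_iters=5_000,
    )
    learned_potentials = solver(
        sampler(0, mean=0.0), sampler(1, mean=2.0),  # training source/target
        sampler(2, mean=0.0), sampler(3, mean=2.0),  # validation source/target
    )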
"""
def __init__(
self,
dim_data: int,
neural_f: Optional[models.ModelBase] = None,
neural_g: Optional[models.ModelBase] = None,
optimizer_f: Optional[optax.GradientTransformation] = None,
optimizer_g: Optional[optax.GradientTransformation] = None,
num_train_iters: int = 20000,
num_inner_iters: int = 1,
back_and_forth: Optional[bool] = None,
valid_freq: int = 1000,
log_freq: int = 1000,
logging: bool = False,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
pos_weights: bool = True,
beta: float = 1.0,
conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER,
amortization_loss: Literal["objective", "regression"] = "regression",
parallel_updates: bool = True,
):
self.num_train_iters = num_train_iters
self.num_inner_iters = num_inner_iters
self.back_and_forth = back_and_forth
self.valid_freq = valid_freq
self.log_freq = log_freq
self.logging = logging
self.pos_weights = pos_weights
self.beta = beta
self.parallel_updates = parallel_updates
self.conjugate_solver = conjugate_solver
self.amortization_loss = amortization_loss
# set default optimizers
if optimizer_f is None:
optimizer_f = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)
if optimizer_g is None:
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)
# set default neural architectures
if neural_f is None:
neural_f = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
if neural_g is None:
neural_g = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
self.neural_f = neural_f
self.neural_g = neural_g
# set optimizer and networks
self.setup(
rng,
neural_f,
neural_g,
dim_data,
optimizer_f,
optimizer_g,
)
def setup(
self,
rng: jax.random.PRNGKeyArray,
neural_f: models.ModelBase,
neural_g: models.ModelBase,
dim_data: int,
optimizer_f: optax.GradientTransformation,
optimizer_g: optax.GradientTransformation,
) -> None:
"""Setup all components required to train the network."""
# split random number generator
rng, rng_f, rng_g = jax.random.split(rng, 3)
# check setting of network architectures
warn_str = f"Setting of ICNN and the positive weights setting of the " \
f"`W2NeuralDual` are not consistent. Proceeding with " \
f"the `W2NeuralDual` setting, with positive weights " \
f"being {self.pos_weights}."
if isinstance(
neural_f, models.ICNN
) and neural_f.pos_weights is not self.pos_weights:
warnings.warn(warn_str, stacklevel=2)
neural_f.pos_weights = self.pos_weights
if isinstance(
neural_g, models.ICNN
) and neural_g.pos_weights is not self.pos_weights:
warnings.warn(warn_str, stacklevel=2)
neural_g.pos_weights = self.pos_weights
self.state_f = neural_f.create_train_state(
rng_f,
optimizer_f,
dim_data,
)
self.state_g = neural_g.create_train_state(
rng_g,
optimizer_g,
dim_data,
)
# default to using back_and_forth with the non-convex models
if self.back_and_forth is None:
self.back_and_forth = isinstance(neural_f, models.MLP)
if self.num_inner_iters == 1 and self.parallel_updates:
self.train_step_parallel = self.get_step_fn(
train=True, to_optimize="both"
)
self.valid_step_parallel = self.get_step_fn(
train=False, to_optimize="both"
)
self.train_fn = self.train_neuraldual_parallel
else:
if self.parallel_updates:
warnings.warn(
"parallel_updates set to True but disabling it "
"because num_inner_iters>1",
stacklevel=2
)
if self.back_and_forth:
raise NotImplementedError(
"back_and_forth not implemented without parallel updates"
)
self.train_step_f = self.get_step_fn(train=True, to_optimize="f")
self.valid_step_f = self.get_step_fn(train=False, to_optimize="f")
self.train_step_g = self.get_step_fn(train=True, to_optimize="g")
self.valid_step_g = self.get_step_fn(train=False, to_optimize="g")
self.train_fn = self.train_neuraldual_alternating
def __call__( # noqa: D102
self,
trainloader_source: Iterator[jnp.ndarray],
trainloader_target: Iterator[jnp.ndarray],
validloader_source: Iterator[jnp.ndarray],
validloader_target: Iterator[jnp.ndarray],
callback: Optional[Callback_t] = None,
) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials,
Train_t]]:
logs = self.train_fn(
trainloader_source,
trainloader_target,
validloader_source,
validloader_target,
callback=callback,
)
res = self.to_dual_potentials()
return (res, logs) if self.logging else res
def train_neuraldual_parallel(
self,
trainloader_source: Iterator[jnp.ndarray],
trainloader_target: Iterator[jnp.ndarray],
validloader_source: Iterator[jnp.ndarray],
validloader_target: Iterator[jnp.ndarray],
callback: Optional[Callback_t] = None,
) -> Train_t:
"""Training and validation with parallel updates."""
try:
from tqdm.auto import tqdm
except ImportError:
tqdm = lambda _: _
# define dict to contain source and target batch
train_batch, valid_batch = {}, {}
# set logging dictionaries
train_logs = {"loss_f": [], "loss_g": [], "w_dist": [], "directions": []}
valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}
for step in tqdm(range(self.num_train_iters)):
update_forward = not self.back_and_forth or step % 2 == 0
if update_forward:
train_batch["source"] = jnp.asarray(next(trainloader_source))
train_batch["target"] = jnp.asarray(next(trainloader_target))
(self.state_f, self.state_g, loss, loss_f, loss_g,
w_dist) = self.train_step_parallel(
self.state_f,
self.state_g,
train_batch,
)
else:
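# Back-and-forth step: swap the source/target batches and the roles of
# f and g so that the backward direction is also updated.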
train_batch["target"] = jnp.asarray(next(trainloader_source))
train_batch["source"] = jnp.asarray(next(trainloader_target))
(self.state_g, self.state_f, loss, loss_f, loss_g,
w_dist) = self.train_step_parallel(
self.state_g,
self.state_f,
train_batch,
)
if self.logging and step % self.log_freq == 0:
self._update_logs(train_logs, loss_f, loss_g, w_dist)
train_logs["directions"].append(
"forward" if update_forward else "backward"
)
if callback is not None:
_ = callback(step, self.to_dual_potentials())
if not self.pos_weights:
# Only clip the weights of the f network
self.state_f = self.state_f.replace(
params=self._clip_weights_icnn(self.state_f.params)
)
# report the loss on a validation dataset periodically
if step != 0 and step % self.valid_freq == 0:
# get batch
valid_batch["source"] = jnp.asarray(next(validloader_source))
valid_batch["target"] = jnp.asarray(next(validloader_target))
valid_loss_f, valid_loss_g, valid_w_dist = self.valid_step_parallel(
self.state_f,
self.state_g,
valid_batch,
)
if self.logging:
self._update_logs(
valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
)
return {"train_logs": train_logs, "valid_logs": valid_logs}
def train_neuraldual_alternating(
self,
trainloader_source: Iterator[jnp.ndarray],
trainloader_target: Iterator[jnp.ndarray],
validloader_source: Iterator[jnp.ndarray],
validloader_target: Iterator[jnp.ndarray],
callback: Optional[Callback_t] = None,
) -> Train_t:
"""Training and validation with alternating updates."""
try:
from tqdm.auto import tqdm
except ImportError:
tqdm = lambda _: _
# define dict to contain source and target batch
batch_g, batch_f, valid_batch = {}, {}, {}
# set logging dictionaries
train_logs = {"loss_f": [], "loss_g": [], "w_dist": []}
valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}
for step in tqdm(range(self.num_train_iters)):
# execute training steps
for _ in range(self.num_inner_iters):
# get train batch for potential g
batch_g["source"] = jnp.asarray(next(trainloader_source))
batch_g["target"] = jnp.asarray(next(trainloader_target))
self.state_g, loss_g, _ = self.train_step_g(
self.state_f, self.state_g, batch_g
)
# get train batch for potential f
batch_f["source"] = jnp.asarray(next(trainloader_source))
batch_f["target"] = jnp.asarray(next(trainloader_target))
self.state_f, loss_f, w_dist = self.train_step_f(
self.state_f, self.state_g, batch_f
)
if not self.pos_weights:
# Only clip the weights of the f network
self.state_f = self.state_f.replace(
params=self._clip_weights_icnn(self.state_f.params)
)
if callback is not None:
callback(step, self.to_dual_potentials())
if self.logging and step % self.log_freq == 0:
self._update_logs(train_logs, loss_f, loss_g, w_dist)
# report the loss on validation dataset periodically
if step != 0 and step % self.valid_freq == 0:
# get batch
valid_batch["source"] = jnp.asarray(next(validloader_source))
valid_batch["target"] = jnp.asarray(next(validloader_target))
valid_loss_f, _ = self.valid_step_f(
self.state_f, self.state_g, valid_batch
)
valid_loss_g, valid_w_dist = self.valid_step_g(
self.state_f, self.state_g, valid_batch
)
if self.logging:
self._update_logs(
valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
)
return {"train_logs": train_logs, "valid_logs": valid_logs}
def get_step_fn(
self, train: bool, to_optimize: Literal["f", "g", "both"]
):
"""Create a parallel training and evaluation function."""
def loss_fn(params_f, params_g, f_value, g_value, g_gradient, batch):
"""Loss function for both potentials."""
# get two distributions
source, target = batch["source"], batch["target"]
init_source_hat = g_gradient(params_g)(target)
def g_value_partial(y: jnp.ndarray) -> jnp.ndarray:
"""Lazy way of evaluating g if f's computation needs it."""
return g_value(params_g)(y)
f_value_partial = f_value(params_f, g_value_partial)
if self.conjugate_solver is not None:
finetune_source_hat = lambda y, x_init: self.conjugate_solver.solve(
f_value_partial, y, x_init=x_init
).grad
finetune_source_hat = jax.vmap(finetune_source_hat)
source_hat_detach = jax.lax.stop_gradient(
finetune_source_hat(target, init_source_hat)
)
else:
source_hat_detach = init_source_hat
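# Dual objective: average f over the source and the conjugate
# f*(y) ~= <x_hat, y> - f(x_hat) over the target, evaluated at the
# (optionally fine-tuned) candidate argmax x_hat computed above.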
batch_dot = jax.vmap(jnp.dot)
f_source = f_value_partial(source)
f_star_target = batch_dot(source_hat_detach,
target) - f_value_partial(source_hat_detach)
dual_source = f_source.mean()
dual_target = f_star_target.mean()
dual_loss = dual_source + dual_target
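# The amortization loss trains g to predict the conjugate argmax:
# 'regression' matches grad g(y) to the fine-tuned solution, while
# 'objective' evaluates f(grad g(y)) - <grad g(y), y>, i.e. the inner
# conjugate objective, with f's parameters detached.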
if self.amortization_loss == "regression":
amor_loss = ((init_source_hat - source_hat_detach) ** 2).mean()
elif self.amortization_loss == "objective":
f_value_parameters_detached = f_value(
jax.lax.stop_gradient(params_f), g_value_partial
)
amor_loss = (
f_value_parameters_detached(init_source_hat) -
batch_dot(init_source_hat, target)
).mean()
else:
raise ValueError("Amortization loss has been misspecified.")
if to_optimize == "both":
loss = dual_loss + amor_loss
elif to_optimize == "f":
loss = dual_loss
elif to_optimize == "g":
loss = amor_loss
else:
raise ValueError(
f"Optimization target {to_optimize} has been misspecified."
)
if not self.pos_weights:
# Penalize the weights of both networks, even though one
# of them will be exactly clipped.
# Having both here is necessary in case this is being called with
# the potentials reversed with the back_and_forth.
loss += self.beta * self._penalize_weights_icnn(params_f) + \
self.beta * self._penalize_weights_icnn(params_g)
# compute Wasserstein-2 distance
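# (with potentials in correlation form):
# W2^2 ~= E||x||^2 + E||y||^2 - 2 * (E[f(x)] + E[f*(y)])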
C = jnp.mean(jnp.sum(source ** 2, axis=-1)) + \
jnp.mean(jnp.sum(target ** 2, axis=-1))
W2_dist = C - 2. * (f_source.mean() + f_star_target.mean())
return loss, (dual_loss, amor_loss, W2_dist)
@jax.jit
def step_fn(state_f, state_g, batch):
"""Step function of either training or validation."""
grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
if train:
# compute loss and gradients
(loss, (loss_f, loss_g, W2_dist)), (grads_f, grads_g) = grad_fn(
state_f.params,
state_g.params,
state_f.potential_value_fn,
state_g.potential_value_fn,
state_g.potential_gradient_fn,
batch,
)
# update state
if to_optimize == "both":
return (
state_f.apply_gradients(grads=grads_f),
state_g.apply_gradients(grads=grads_g), loss, loss_f, loss_g,
W2_dist
)
if to_optimize == "f":
return state_f.apply_gradients(grads=grads_f), loss_f, W2_dist
if to_optimize == "g":
return state_g.apply_gradients(grads=grads_g), loss_g, W2_dist
raise ValueError("Optimization target has been misspecified.")
# compute loss and gradients
(loss, (loss_f, loss_g, W2_dist)), _ = grad_fn(
state_f.params,
state_g.params,
state_f.potential_value_fn,
state_g.potential_value_fn,
state_g.potential_gradient_fn,
batch,
)
# do not update state
if to_optimize == "both":
return loss_f, loss_g, W2_dist
if to_optimize == "f":
return loss_f, W2_dist
if to_optimize == "g":
return loss_g, W2_dist
raise ValueError("Optimization target has been misspecified.")
return step_fn
def to_dual_potentials(
self, finetune_g: bool = True
) -> potentials.DualPotentials:
r"""Return the Kantorovich dual potentials from the trained potentials.
Args:
finetune_g: Run the conjugate solver to fine-tune the prediction.
Returns:
A dual potential object
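
A sketch of how the returned object can be used, assuming ``solver`` is a
trained ``W2NeuralDual`` and ``x``, ``y`` are ``jnp.ndarray`` batches of
source and target points:

.. code-block:: python

  pots = solver.to_dual_potentials()
  y_hat = pots.transport(x)                 # push source points toward beta
  x_hat = pots.transport(y, forward=False)  # inverse map via the conjugate
  w2 = pots.distance(x, y)                  # Wasserstein-2 distance estimate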
"""
f_value = self.state_f.potential_value_fn(self.state_f.params)
g_value_prediction = self.state_g.potential_value_fn(
self.state_g.params, f_value
)
def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray:
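# Warm-start the conjugate solver at the network prediction
# x_hat = grad g(y) and return g(y) = <grad_g_y, y> - f(grad_g_y),
# i.e. the Fenchel conjugate of f evaluated at y.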
x_hat = jax.grad(g_value_prediction)(y)
grad_g_y = jax.lax.stop_gradient(
self.conjugate_solver.solve(f_value, y, x_init=x_hat).grad
)
return -f_value(grad_g_y) + jnp.dot(grad_g_y, y)
return potentials.DualPotentials(
f=f_value,
g=g_value_prediction if not finetune_g or self.conjugate_solver is None
else g_value_finetuned,
cost_fn=costs.SqEuclidean(),
corr=True
)
@staticmethod
def _clip_weights_icnn(params):
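# Clip the ICNN's w_z kernels to be non-negative so the potential
# remains convex in its input.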
params = params.unfreeze()
for k in params:
if k.startswith("w_z"):
params[k]["kernel"] = jnp.clip(params[k]["kernel"], a_min=0)
return core.freeze(params)
@staticmethod
def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float:
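# Soft penalty on the negative parts of the ICNN's w_z kernels, used when
# convexity is not enforced by construction.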
penalty = 0.0
for k, param in params.items():
if k.startswith("w_z"):
penalty += jnp.linalg.norm(jax.nn.relu(-param["kernel"]))
return penalty
@staticmethod
def _update_logs(
logs: Dict[str, List[Union[float, str]]],
loss_f: jnp.ndarray,
loss_g: jnp.ndarray,
w_dist: jnp.ndarray,
) -> None:
logs["loss_f"].append(float(loss_f))
logs["loss_g"].append(float(loss_g))
logs["w_dist"].append(float(w_dist))