# Source code for ott.solvers.nn.neuraldual

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import (
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)

import jax
import jax.numpy as jnp
import optax
from flax import core

from ott.geometry import costs
from ott.problems.linear import potentials
from ott.solvers.nn import conjugate_solvers, models

__all__ = ["W2NeuralDual"]

Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]
Callback_t = Callable[[int, potentials.DualPotentials], None]
Conj_t = Optional[conjugate_solvers.FenchelConjugateSolver]

class W2NeuralDual:
  r"""Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces.

  Learn the Wasserstein-2 optimal transport between two measures
  :math:`\alpha` and :math:`\beta` in :math:`n`-dimensional Euclidean space,
  denoted source and target, respectively. This is achieved by parameterizing
  a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}`
  associated with the :math:`\alpha` measure with an
  :class:`~ott.solvers.nn.models.ICNN`, :class:`~ott.solvers.nn.models.MLP`,
  or other :class:`~ott.solvers.nn.models.ModelBase`, where :math:`\nabla f`
  transports source to target cells. This potential is learned by optimizing
  the dual form associated with the negative inner product cost

  .. math::

    \text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] -
    \mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],

  where
  :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} f(x)-\langle x, y\rangle`
  is the convex conjugate.
  :math:`\nabla f^\star` transports from the target
  to source cells and provides the inverse optimal
  transport map from :math:`\beta` to :math:`\alpha`.
  This solver estimates the conjugate :math:`f^\star`
  with a neural approximation :math:`g` that is fine-tuned
  with :class:`~ott.solvers.nn.conjugate_solvers.FenchelConjugateSolver`,
  which is a combination further described in :cite:`amos:23`.

  The :class:`~ott.solvers.nn.models.ModelBase` potentials for
  ``neural_f`` and ``neural_g`` can

  1. both provide the values of the potentials :math:`f` and :math:`g`, or
  2. one of them can provide the gradient mapping e.g., :math:`\nabla f`
     or :math:`\nabla g` where the potential's value can be obtained
     via the Fenchel conjugate as discussed in :cite:`amos:23`.

  The potential's value or gradient mapping is specified via
  :attr:`~ott.solvers.nn.models.ModelBase.is_potential`.

  Args:
    dim_data: input dimensionality of data required for network init
    neural_f: network architecture for potential :math:`f`.
    neural_g: network architecture for the conjugate potential
      :math:`g\approx f^\star`
    optimizer_f: optimizer function for potential :math:`f`
    optimizer_g: optimizer function for the conjugate potential :math:`g`
    num_train_iters: number of total training iterations
    num_inner_iters: number of training iterations of :math:`g` per iteration
      of :math:`f`
    back_and_forth: alternate between updating the forward and backward
      directions. Inspired from :cite:`jacobs:20`
    valid_freq: frequency with which model is validated
    log_freq: frequency with training and validation are logged
    logging: option to return logs
    rng: random key used for seeding for network initializations
    pos_weights: option to train networks with positive weights or regularizer
    beta: regularization parameter when not training with positive weights
    conjugate_solver: numerical solver for the Fenchel conjugate.
    amortization_loss: amortization loss for the conjugate
      :math:`g\approx f^\star`. Options are ``'objective'`` :cite:`makkuva:20`
      or ``'regression'`` :cite:`amos:23`.
    parallel_updates: Update :math:`f` and :math:`g` at the same time
  """

def __init__(
self,
dim_data: int,
neural_f: Optional[models.ModelBase] = None,
neural_g: Optional[models.ModelBase] = None,
optimizer_f: Optional[optax.OptState] = None,
optimizer_g: Optional[optax.OptState] = None,
num_train_iters: int = 20000,
num_inner_iters: int = 1,
back_and_forth: Optional[bool] = None,
valid_freq: int = 1000,
log_freq: int = 1000,
logging: bool = False,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
pos_weights: bool = True,
beta: float = 1.0,
conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER,
amortization_loss: Literal["objective", "regression"] = "regression",
):
self.num_train_iters = num_train_iters
self.num_inner_iters = num_inner_iters
self.back_and_forth = back_and_forth
self.valid_freq = valid_freq
self.log_freq = log_freq
self.logging = logging
self.pos_weights = pos_weights
self.beta = beta
self.conjugate_solver = conjugate_solver
self.amortization_loss = amortization_loss

# set default optimizers
if optimizer_f is None:
optimizer_f = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)
if optimizer_g is None:
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8)

# set default neural architectures
if neural_f is None:
neural_f = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
if neural_g is None:
neural_g = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64])
self.neural_f = neural_f
self.neural_g = neural_g

# set optimizer and networks
self.setup(
rng,
neural_f,
neural_g,
dim_data,
optimizer_f,
optimizer_g,
)

[docs]  def setup(
self,
rng: jax.random.PRNGKeyArray,
neural_f: models.ModelBase,
neural_g: models.ModelBase,
dim_data: int,
optimizer_f: optax.OptState,
optimizer_g: optax.OptState,
) -> None:
"""Setup all components required to train the network."""
# split random number generator
rng, rng_f, rng_g = jax.random.split(rng, 3)

# check setting of network architectures
warn_str = f"Setting of ICNN and the positive weights setting of the " \
f"W2NeuralDual are not consistent. Proceeding with " \
f"the W2NeuralDual setting, with positive weights " \
f"being {self.pos_weights}."
if isinstance(
neural_f, models.ICNN
) and neural_f.pos_weights is not self.pos_weights:
warnings.warn(warn_str, stacklevel=2)
neural_f.pos_weights = self.pos_weights

if isinstance(
neural_g, models.ICNN
) and neural_g.pos_weights is not self.pos_weights:
warnings.warn(warn_str, stacklevel=2)
neural_g.pos_weights = self.pos_weights

self.state_f = neural_f.create_train_state(
rng_f,
optimizer_f,
dim_data,
)
self.state_g = neural_g.create_train_state(
rng_g,
optimizer_g,
dim_data,
)

# default to using back_and_forth with the non-convex models
if self.back_and_forth is None:
self.back_and_forth = isinstance(neural_f, models.MLP)

if self.num_inner_iters == 1 and self.parallel_updates:
self.train_step_parallel = self.get_step_fn(
train=True, to_optimize="both"
)
self.valid_step_parallel = self.get_step_fn(
train=False, to_optimize="both"
)
self.train_fn = self.train_neuraldual_parallel
else:
warnings.warn(
"parallel_updates set to True but disabling it "
"because num_inner_iters>1",
stacklevel=2
)
if self.back_and_forth:
raise NotImplementedError(
"back_and_forth not implemented without parallel updates"
)
self.train_step_f = self.get_step_fn(train=True, to_optimize="f")
self.valid_step_f = self.get_step_fn(train=False, to_optimize="f")
self.train_step_g = self.get_step_fn(train=True, to_optimize="g")
self.valid_step_g = self.get_step_fn(train=False, to_optimize="g")
self.train_fn = self.train_neuraldual_alternating

def __call__(  # noqa: D102
self,
callback: Optional[Callback_t] = None,
) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials,
Train_t]]:
logs = self.train_fn(
callback=callback,
)
res = self.to_dual_potentials()

return (res, logs) if self.logging else res

[docs]  def train_neuraldual_parallel(
self,
callback: Optional[Callback_t] = None,
) -> Train_t:
"""Training and validation with parallel updates."""
try:
from tqdm.auto import tqdm
except ImportError:
tqdm = lambda _: _
# define dict to contain source and target batch
train_batch, valid_batch = {}, {}

# set logging dictionaries
train_logs = {"loss_f": [], "loss_g": [], "w_dist": [], "directions": []}
valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}

for step in tqdm(range(self.num_train_iters)):
update_forward = not self.back_and_forth or step % 2 == 0
if update_forward:
(self.state_f, self.state_g, loss, loss_f, loss_g,
w_dist) = self.train_step_parallel(
self.state_f,
self.state_g,
train_batch,
)
else:
(self.state_g, self.state_f, loss, loss_f, loss_g,
w_dist) = self.train_step_parallel(
self.state_g,
self.state_f,
train_batch,
)

if self.logging and step % self.log_freq == 0:
self._update_logs(train_logs, loss_f, loss_g, w_dist)
train_logs["directions"].append(
"forward" if update_forward else "backward"
)

if callback is not None:
_ = callback(step, self.to_dual_potentials())

if not self.pos_weights:
# Only clip the weights of the f network
self.state_f = self.state_f.replace(
params=self._clip_weights_icnn(self.state_f.params)
)

# report the loss on an validation dataset periodically
if step != 0 and step % self.valid_freq == 0:
# get batch

valid_loss_f, valid_loss_g, valid_w_dist = self.valid_step_parallel(
self.state_f,
self.state_g,
valid_batch,
)

if self.logging:
self._update_logs(
valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
)

return {"train_logs": train_logs, "valid_logs": valid_logs}

[docs]  def train_neuraldual_alternating(
self,
callback: Optional[Callback_t] = None,
) -> Train_t:
"""Training and validation with alternating updates."""
try:
from tqdm.auto import tqdm
except ImportError:
tqdm = lambda _: _
# define dict to contain source and target batch
batch_g, batch_f, valid_batch = {}, {}, {}

# set logging dictionaries
train_logs = {"loss_f": [], "loss_g": [], "w_dist": []}
valid_logs = {"loss_f": [], "loss_g": [], "w_dist": []}

for step in tqdm(range(self.num_train_iters)):
# execute training steps
for _ in range(self.num_inner_iters):
# get train batch for potential g

self.state_g, loss_g, _ = self.train_step_g(
self.state_f, self.state_g, batch_g
)

# get train batch for potential f

self.state_f, loss_f, w_dist = self.train_step_f(
self.state_f, self.state_g, batch_f
)
if not self.pos_weights:
# Only clip the weights of the f network
self.state_f = self.state_f.replace(
params=self._clip_weights_icnn(self.state_f.params)
)

if callback is not None:
callback(step, self.to_dual_potentials())

if self.logging and step % self.log_freq == 0:
self._update_logs(train_logs, loss_f, loss_g, w_dist)

# report the loss on validation dataset periodically
if step != 0 and step % self.valid_freq == 0:
# get batch

valid_loss_f, _ = self.valid_step_f(
self.state_f, self.state_g, valid_batch
)
valid_loss_g, valid_w_dist = self.valid_step_g(
self.state_f, self.state_g, valid_batch
)

if self.logging:
self._update_logs(
valid_logs, valid_loss_f, valid_loss_g, valid_w_dist
)

return {"train_logs": train_logs, "valid_logs": valid_logs}

[docs]  def get_step_fn(
self, train: bool, to_optimize: Literal["f", "g", "parallel", "both"]
):
"""Create a parallel training and evaluation function."""

def loss_fn(params_f, params_g, f_value, g_value, g_gradient, batch):
"""Loss function for both potentials."""
# get two distributions
source, target = batch["source"], batch["target"]

def g_value_partial(y: jnp.ndarray) -> jnp.ndarray:
"""Lazy way of evaluating g if f's computation needs it."""
return g_value(params_g)(y)

f_value_partial = f_value(params_f, g_value_partial)
if self.conjugate_solver is not None:
finetune_source_hat = lambda y, x_init: self.conjugate_solver.solve(
f_value_partial, y, x_init=x_init
finetune_source_hat = jax.vmap(finetune_source_hat)
finetune_source_hat(target, init_source_hat)
)
else:
source_hat_detach = init_source_hat

batch_dot = jax.vmap(jnp.dot)

f_source = f_value_partial(source)
f_star_target = batch_dot(source_hat_detach,
target) - f_value_partial(source_hat_detach)
dual_source = f_source.mean()
dual_target = f_star_target.mean()
dual_loss = dual_source + dual_target

if self.amortization_loss == "regression":
amor_loss = ((init_source_hat - source_hat_detach) ** 2).mean()
elif self.amortization_loss == "objective":
f_value_parameters_detached = f_value(
)
amor_loss = (
f_value_parameters_detached(init_source_hat) -
batch_dot(init_source_hat, target)
).mean()
else:
raise ValueError("Amortization loss has been misspecified.")

if to_optimize == "both":
loss = dual_loss + amor_loss
elif to_optimize == "f":
loss = dual_loss
elif to_optimize == "g":
loss = amor_loss
else:
raise ValueError(
f"Optimization target {to_optimize} has been misspecified."
)

if self.pos_weights:
# Penalize the weights of both networks, even though one
# of them will be exactly clipped.
# Having both here is necessary in case this is being called with
# the potentials reversed with the back_and_forth.
loss += self.beta * self._penalize_weights_icnn(params_f) + \
self.beta * self._penalize_weights_icnn(params_g)

# compute Wasserstein-2 distance
C = jnp.mean(jnp.sum(source ** 2, axis=-1)) + \
jnp.mean(jnp.sum(target ** 2, axis=-1))
W2_dist = C - 2. * (f_source.mean() + f_star_target.mean())

return loss, (dual_loss, amor_loss, W2_dist)

@jax.jit
def step_fn(state_f, state_g, batch):
"""Step function of either training or validation."""
if train:
state_f.params,
state_g.params,
state_f.potential_value_fn,
state_g.potential_value_fn,
batch,
)
# update state
if to_optimize == "both":
return (
W2_dist
)
if to_optimize == "f":
if to_optimize == "g":
raise ValueError("Optimization target has been misspecified.")

(loss, (loss_f, loss_g, W2_dist)), _ = grad_fn(
state_f.params,
state_g.params,
state_f.potential_value_fn,
state_g.potential_value_fn,
batch,
)

# do not update state
if to_optimize == "both":
return loss_f, loss_g, W2_dist
if to_optimize == "f":
return loss_f, W2_dist
if to_optimize == "g":
return loss_g, W2_dist
raise ValueError("Optimization target has been misspecified.")

return step_fn

[docs]  def to_dual_potentials(
self, finetune_g: bool = True
) -> potentials.DualPotentials:
r"""Return the Kantorovich dual potentials from the trained potentials.

Args:
finetune_g: Run the conjugate solver to fine-tune the prediction.

Returns:
A dual potential object
"""
f_value = self.state_f.potential_value_fn(self.state_f.params)
g_value_prediction = self.state_g.potential_value_fn(
self.state_g.params, f_value
)

def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray:
)

return potentials.DualPotentials(
f=f_value,
g=g_value_prediction if not finetune_g or self.conjugate_solver is None
else g_value_finetuned,
cost_fn=costs.SqEuclidean(),
corr=True
)

@staticmethod
def _clip_weights_icnn(params):
params = params.unfreeze()
for k in params:
if k.startswith("w_z"):
params[k]["kernel"] = jnp.clip(params[k]["kernel"], a_min=0)

return core.freeze(params)

@staticmethod
def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float:
penalty = 0.0
for k, param in params.items():
if k.startswith("w_z"):
penalty += jnp.linalg.norm(jax.nn.relu(-param["kernel"]))
return penalty

@staticmethod
def _update_logs(
logs: Dict[str, List[Union[float, str]]],
loss_f: jnp.ndarray,
loss_g: jnp.ndarray,
w_dist: jnp.ndarray,
) -> None:
logs["loss_f"].append(float(loss_f))
logs["loss_g"].append(float(loss_g))
logs["w_dist"].append(float(w_dist))