# Source code for ott.problems.linear.potentials

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Tuple,
)

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import numpy as np

from ott.problems.linear import linear_problem

if TYPE_CHECKING:
  # Only needed for the string annotation ``"costs.CostFn"``; importing it at
  # runtime is deferred to the methods that actually use it.
  from ott.geometry import costs

# matplotlib is optional: the plotting helpers check ``mpl is None`` and raise
# a helpful error instead of failing at import time.
try:
  import matplotlib as mpl
  import matplotlib.pyplot as plt
except ImportError:
  mpl = plt = None

__all__ = ["DualPotentials", "EntropicPotentials"]

# A dual potential maps a single point (a 1-D array; callers ``jax.vmap`` it
# over a batch of points) to a scalar value.
Potential_t = Callable[[jnp.ndarray], float]

@jtu.register_pytree_node_class
class DualPotentials:
  r"""The Kantorovich dual potential functions :math:`f` and :math:`g`.

  :math:`f` and :math:`g` are a pair of functions, candidates for the dual
  OT Kantorovich problem, supposedly optimal for a given pair of measures.

  Args:
    f: The first dual potential function.
    g: The second dual potential function.
    cost_fn: The cost function used to solve the OT problem.
    corr: Whether the duals solve the problem in distance form, or correlation
      form (as used for instance for ICNNs, see, e.g., top right of p.3 in
      :cite:`makkuva:20`)
  """

  def __init__(
      self,
      f: Potential_t,
      g: Potential_t,
      *,
      cost_fn: "costs.CostFn",
      corr: bool = False
  ):
    self._f = f
    self._g = g
    self.cost_fn = cost_fn
    self._corr = corr

  def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
    r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`.

    Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when
    given the Legendre transform of the dual potentials.

    That OT map can be recovered as :math:`x- (\nabla h^*)\circ \nabla f(x)`,
    where :math:`h^*` is the Legendre transform of :math:`h`. For instance,
    in the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`,
    one has :math:`h^*(\cdot) = \|.\|^2 / 4`, and therefore
    :math:`\nabla h^*(\cdot) = 0.5 \cdot\,`.

    When the dual potentials are solved in correlation form (only in the Sq.
    Euclidean distance case), the maps are :math:`\nabla f` for forward,
    :math:`\nabla g` for backward.

    Args:
      vec: Points to transport, array of shape ``[n, d]``. A single point is
        promoted to a batch of one via :func:`jax.numpy.atleast_2d`.
      forward: Whether to transport the points from source to the target
        distribution or vice-versa.

    Returns:
      The transported points.
    """
    from ott.geometry import costs

    vec = jnp.atleast_2d(vec)
    # ``f`` is the source-side potential and ``g`` the target-side one, so the
    # forward map is driven by ``f`` and the backward map by ``g``.
    grad_potential = self._grad_f if forward else self._grad_g
    if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
      # Correlation form: the potentials are Brenier potentials, so the map is
      # their gradient directly.
      return grad_potential(vec)
    # General translation-invariant cost: T(x) = x - (grad h^*)(grad f(x)).
    return vec - self._grad_h_inv(grad_potential(vec))

  def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
    """Evaluate 2-Wasserstein distance between samples using dual potentials.

    Uses Eq. 5 from :cite:`makkuva:20` when given in ``corr`` form, direct
    estimation by integrating dual function against points when using dual form.

    Args:
      src: Samples from the source distribution, array of shape ``[n, d]``.
      tgt: Samples from the target distribution, array of shape ``[m, d]``.

    Returns:
      Wasserstein distance.
    """
    src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)
    f = jax.vmap(self.f)

    if self._corr:
      # ``<y, grad g(y)> - f(grad g(y))`` is the Fenchel-Young expression for
      # the conjugate f^*(y), exact when grad g(y) is the maximizer.
      grad_g_y = self._grad_g(tgt)
      term1 = -jnp.mean(f(src))
      term2 = -jnp.mean(jnp.sum(tgt * grad_g_y, axis=-1) - f(grad_g_y))

      C = jnp.mean(jnp.sum(src ** 2, axis=-1))
      C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
      return 2. * (term1 + term2) + C

    g = jax.vmap(self.g)
    return jnp.mean(f(src)) + jnp.mean(g(tgt))

  @property
  def f(self) -> Potential_t:
    """The first dual potential function."""
    return self._f

  @property
  def g(self) -> Potential_t:
    """The second dual potential function."""
    return self._g

  @property
  def _grad_f(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Vectorized gradient of the potential function :attr:`f`."""
    return jax.vmap(jax.grad(self.f, argnums=0))

  @property
  def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Vectorized gradient of the potential function :attr:`g`."""
    return jax.vmap(jax.grad(self.g, argnums=0))

  @property
  def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Vectorized gradient of the Legendre transform :math:`h^*` of the cost.

    Raises:
      AssertionError: if :attr:`cost_fn` is not a translation-invariant cost.
    """
    from ott.geometry import costs

    assert isinstance(self.cost_fn, costs.TICost), (
        "Cost must be a TICost and "
        "provide access to Legendre transform of h."
    )
    return jax.vmap(jax.grad(self.cost_fn.h_legendre))

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    # Everything is static auxiliary data; there are no array leaves.
    return [], {
        "f": self._f,
        "g": self._g,
        "cost_fn": self.cost_fn,
        "corr": self._corr
    }

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "DualPotentials":
    return cls(*children, **aux_data)

  def plot_ot_map(
      self,
      source: jnp.ndarray,
      target: jnp.ndarray,
      forward: bool = True,
      ax: Optional["plt.Axes"] = None,
      legend_kwargs: Optional[Dict[str, Any]] = None,
      scatter_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Tuple["plt.Figure", "plt.Axes"]:
    """Plot data and learned optimal transport map.

    Args:
      source: samples from the source measure
      target: samples from the target measure
      forward: use the forward map from the potentials
        if ``True``, otherwise use the inverse map
      ax: axis to add the plot to; a new figure and axis are created when
        ``None``
      legend_kwargs: additional kwargs passed into
        :meth:`~matplotlib.axes.Axes.legend`
      scatter_kwargs: additional kwargs passed into
        :meth:`~matplotlib.axes.Axes.scatter`

    Returns:
      matplotlib figure and axis with the plots

    Raises:
      RuntimeError: if matplotlib is not installed.
    """
    if mpl is None:
      raise RuntimeError("Please install matplotlib first.")

    if scatter_kwargs is None:
      scatter_kwargs = {"alpha": 0.5}
    if legend_kwargs is None:
      legend_kwargs = {
          "ncol": 3,
          "loc": "upper center",
          "bbox_to_anchor": (0.5, -0.05),
          "edgecolor": "k"
      }

    if ax is None:
      fig = plt.figure(facecolor="white")
      # An axis must exist before anything can be scattered onto it.
      ax = fig.add_subplot(111)
    else:
      fig = ax.get_figure()

    # plot the source and target samples
    if forward:
      label_transport = r"$\nabla f(source)$"
      source_color, target_color = "#1A254B", "#A7BED3"
    else:
      label_transport = r"$\nabla g(target)$"
      source_color, target_color = "#A7BED3", "#1A254B"

    ax.scatter(
        source[:, 0],
        source[:, 1],
        color=source_color,
        label="source",
        **scatter_kwargs,
    )
    ax.scatter(
        target[:, 0],
        target[:, 1],
        color=target_color,
        label="target",
        **scatter_kwargs,
    )

    # plot the transported samples
    base_samples = source if forward else target
    transported_samples = self.transport(base_samples, forward=forward)
    ax.scatter(
        transported_samples[:, 0],
        transported_samples[:, 1],
        color="#F2545B",
        label=label_transport,
        **scatter_kwargs,
    )

    # Copy to host once; indexing JAX arrays element by element inside the
    # loop is slow.
    starts = np.asarray(base_samples)[:, :2]
    ends = np.asarray(transported_samples)[:, :2]
    for (x0, y0), (x1, y1) in zip(starts, ends):
      ax.arrow(x0, y0, x1 - x0, y1 - y0, color=[0.5, 0.5, 1], alpha=0.3)

    ax.legend(**legend_kwargs)
    return fig, ax

  def plot_potential(
      self,
      forward: bool = True,
      quantile: float = 0.05,
      ax: Optional["mpl.axes.Axes"] = None,
      x_bounds: Tuple[float, float] = (-6, 6),
      y_bounds: Tuple[float, float] = (-6, 6),
      num_grid: int = 50,
      contourf_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Tuple["mpl.figure.Figure", "mpl.axes.Axes"]:
    r"""Plot the potential.

    Args:
      forward: use the forward map from the potentials
        if ``True``, otherwise use the inverse map
      quantile: quantile to filter the potentials with
      ax: axis to add the plot to
      x_bounds: x-axis bounds of the plot
        :math:`(x_{\text{min}}, x_{\text{max}})`
      y_bounds: y-axis bounds of the plot
        :math:`(y_{\text{min}}, y_{\text{max}})`
      num_grid: number of points to discretize the domain into a grid
        along each dimension
      contourf_kwargs: additional kwargs passed into
        :meth:`~matplotlib.axes.Axes.contourf`

    Returns:
      matplotlib figure and axis with the plots.

    Raises:
      RuntimeError: if matplotlib is not installed.
    """
    if mpl is None:
      raise RuntimeError("Please install matplotlib first.")

    if contourf_kwargs is None:
      contourf_kwargs = {}

    ax_specified = ax is not None
    if not ax_specified:
      fig, ax = plt.subplots(figsize=(6, 6), facecolor="white")
    else:
      fig = ax.get_figure()

    x1 = jnp.linspace(*x_bounds, num=num_grid)
    x2 = jnp.linspace(*y_bounds, num=num_grid)
    X1, X2 = jnp.meshgrid(x1, x2)
    X12flat = jnp.hstack((X1.reshape(-1, 1), X2.reshape(-1, 1)))
    Zflat = jax.vmap(self.f if forward else self.g)(X12flat)
    Zflat = np.asarray(Zflat)
    # Clip extreme values so a few outliers do not flatten the color scale.
    vmin, vmax = np.quantile(Zflat, [quantile, 1. - quantile])
    Zflat = Zflat.clip(vmin, vmax)
    Z = Zflat.reshape(X1.shape)

    CS = ax.contourf(X1, X2, Z, cmap="Blues", **contourf_kwargs)
    ax.set_xlim(*x_bounds)
    ax.set_ylim(*y_bounds)
    fig.colorbar(CS, ax=ax)
    if not ax_specified:
      fig.tight_layout()
    ax.set_title(r"$f$" if forward else r"$g$")
    return fig, ax

@jtu.register_pytree_node_class
class EntropicPotentials(DualPotentials):
  """Dual potential functions from finite samples :cite:`pooladian:21`.

  Args:
    f_xy: The first dual potential vector of shape ``[n,]``.
    g_xy: The second dual potential vector of shape ``[m,]``.
    prob: Linear problem with :class:`~ott.geometry.pointcloud.PointCloud`
      geometry that was used to compute the dual potentials using, e.g.,
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
    f_xx: The first dual potential vector of shape ``[n,]`` used for debiasing
      :cite:`pooladian:22`.
    g_yy: The second dual potential vector of shape ``[m,]`` used for
      debiasing. Debiasing is only active when both ``f_xx`` and ``g_yy``
      are provided.
  """

  def __init__(
      self,
      f_xy: jnp.ndarray,
      g_xy: jnp.ndarray,
      prob: linear_problem.LinearProblem,
      f_xx: Optional[jnp.ndarray] = None,
      g_yy: Optional[jnp.ndarray] = None,
  ):
    # we pass directly the arrays and override the properties
    # since only the properties need to be callable
    super().__init__(f_xy, g_xy, cost_fn=prob.geom.cost_fn, corr=False)
    self._prob = prob
    self._f_xx = f_xx
    self._g_yy = g_yy

  @property
  def f(self) -> Potential_t:
    """The first dual potential, extended to any point of the source space."""
    return self._potential_fn(kind="f")

  @property
  def g(self) -> Potential_t:
    """The second dual potential, extended to any point of the target space."""
    return self._potential_fn(kind="g")

  def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t:
    """Build a callable potential from the stored dual vectors.

    The returned function evaluates, at a point ``x``,
    ``-epsilon * logsumexp((potential - cost(x, support)) / epsilon)`` weighted
    by the marginal of the *other* measure. When :attr:`is_debiased`, the
    analogous self-potential built from ``f_xx`` / ``g_yy`` is subtracted.

    Args:
      kind: ``"f"`` for the first potential, ``"g"`` for the second.

    Returns:
      The potential function.
    """
    from ott.geometry import pointcloud

    def callback(
        x: jnp.ndarray,
        *,
        potential: jnp.ndarray,
        y: jnp.ndarray,
        weights: jnp.ndarray,
        epsilon: float,
    ) -> float:
      x = jnp.atleast_2d(x)
      assert x.shape[-1] == y.shape[-1], (x.shape, y.shape)
      geom = pointcloud.PointCloud(x, y, cost_fn=self.cost_fn)
      cost = geom.cost_matrix
      z = (potential - cost) / epsilon
      lse = -epsilon * jsp.special.logsumexp(z, b=weights, axis=-1)
      return jnp.squeeze(lse)

    assert isinstance(
        self._prob.geom, pointcloud.PointCloud
    ), f"Expected point cloud geometry, found {type(self._prob.geom)}."
    x, y = self._prob.geom.x, self._prob.geom.y
    a, b = self._prob.a, self._prob.b

    if kind == "f":
      # When seeking to evaluate 1st potential function,
      # the 2nd set of potential values and support should be used,
      # see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
      potential, arr, weights = self._g, y, b
    else:
      potential, arr, weights = self._f, x, a

    # ``Partial`` keeps the bound arrays as pytree leaves so the result can be
    # passed through JAX transformations.
    potential_xy = jax.tree_util.Partial(
        callback,
        potential=potential,
        y=arr,
        weights=weights,
        epsilon=self.epsilon,
    )
    if not self.is_debiased:
      return potential_xy

    ep = EntropicPotentials(self._f_xx, self._g_yy, prob=self._prob)
    # switch the order because for kind='f' we require f/x/a in other
    # which is accessed when kind='g'
    potential_other = ep._potential_fn(kind="g" if kind == "f" else "f")

    return lambda x: (potential_xy(x) - potential_other(x))

  @property
  def is_debiased(self) -> bool:
    """Whether the entropic map is debiased."""
    return self._f_xx is not None and self._g_yy is not None

  @property
  def epsilon(self) -> float:
    """Entropy regularizer."""
    return self._prob.geom.epsilon

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    # Order matches ``__init__``'s positional parameters, so the inherited
    # ``tree_unflatten`` (``cls(*children, **aux_data)``) rebuilds the object.
    return [self._f, self._g, self._prob, self._f_xx, self._g_yy], {}