# Source code for ott.tools.sinkhorn_divergence

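```# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.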
from types import MappingProxyType
from typing import Any, List, Mapping, NamedTuple, Optional, Tuple, Type

import jax.numpy as jnp

from ott.geometry import costs, geometry, pointcloud, segment
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import acceleration, sinkhorn

# Public names of this module (what ``from ... import *`` exports).
__all__ = [
    "sinkhorn_divergence",
    "segment_sinkhorn_divergence",
    "SinkhornDivergenceOutput",
]

class SinkhornDivergenceOutput(NamedTuple):
  """Result of a Sinkhorn divergence computation.

  Every tuple-valued field holds three entries, ordered as the ``x/y``,
  ``x/x`` and ``y/y`` terms of the divergence. When the ``y/y`` term is not
  computed, its geometry entry is ``None``.

  Attributes:
    divergence: value of the Sinkhorn divergence.
    potentials: one ``[f, g]`` pair of dual potentials per term.
    geoms: geometry used for each term.
    errors: error trace reported by the solver for each term.
    converged: convergence flag reported by the solver for each term.
    a: weights of the first measure.
    b: weights of the second measure.
  """
  divergence: float
  potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]]
  geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry]
  errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
                Optional[jnp.ndarray]]
  converged: Tuple[bool, bool, bool]
  a: jnp.ndarray
  b: jnp.ndarray

  def to_dual_potentials(self) -> "potentials.EntropicPotentials":
    """Return dual estimators :cite:`pooladian:22`, eq. 8.

    A linear problem is built on the ``x/y`` geometry with weights ``a`` and
    ``b``. The ``x/y`` potentials are combined with ``f`` of the ``x/x`` term
    and ``g`` of the ``y/y`` term.

    Returns:
      The entropic dual potentials.
    """
    pots_xy, pots_xx, pots_yy = self.potentials
    f_xy, g_xy = pots_xy
    f_xx, _ = pots_xx
    _, g_yy = pots_yy

    geom_xy, *_ = self.geoms
    problem_xy = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)
    return potentials.EntropicPotentials(
        f_xy, g_xy, problem_xy, f_xx=f_xx, g_yy=g_yy
    )

def sinkhorn_divergence(
    geom: Type[geometry.Geometry],
    *args: Any,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
    sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}),
    static_b: bool = False,
    share_epsilon: bool = True,
    symmetric_sinkhorn: bool = True,
    **kwargs: Any,
) -> SinkhornDivergenceOutput:
  """Compute the Sinkhorn divergence between two weighted measures.

  The geometries for the ``x/y``, ``x/x`` and (optionally) ``y/y`` terms are
  built by ``geom.prepare_divergences``. The divergence is then evaluated by
  solving one Sinkhorn problem per term.

  Args:
    geom: Geometry class used to build the geometries of each term.
    args: Positional arguments forwarded to
      :meth:`~ott.geometry.geometry.Geometry.prepare_divergences`; their
      meaning is specific to each geometry.
    a: Weights of the input points; uniform when ``None``. Its total mass must
      match that of ``b`` to converge.
    b: Weights of the target points; uniform when ``None``. Its total mass
      must match that of ``a`` to converge.
    sinkhorn_kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`, which is run twice when
      ``static_b = True`` and three times otherwise.
    static_b: If True, the divergence of measure ``b`` against itself is
      **not** computed.
    share_epsilon: If True, the ``x/x`` and ``y/y`` geometries reuse the
      epsilon regularizer of the ``x/y`` geometry (the first geometry).
      Enabled by default because, in the default setting, epsilon is a
      function of the mean of the cost matrix.
    symmetric_sinkhorn: If True, use the Sinkhorn updates of Eq. 25 in
      :cite:`feydy:19` for the symmetric ``x/x`` and ``y/y`` terms.
    kwargs: Keyword arguments forwarded to
      :meth:`~ott.geometry.geometry.Geometry.prepare_divergences`; their
      meaning is specific to each geometry.

  Returns:
    A :class:`SinkhornDivergenceOutput` with the divergence value and, for
    each term, the potentials, geometry, errors and convergence flag, plus the
    weights ``a`` and ``b``.
  """
  geoms = geom.prepare_divergences(*args, static_b=static_b, **kwargs)
  # Pad with ``None`` so that a missing ``y/y`` geometry still unpacks.
  geom_xy, geom_x, geom_y = (geoms + (None,) * 3)[:3]
  num_a, num_b = geom_xy.shape

  if share_epsilon:
    # Only real geometries get the epsilon of ``geom_xy``; ``None`` is kept.
    geom_x, geom_y = (
        g.copy_epsilon(geom_xy) if isinstance(g, geometry.Geometry) else g
        for g in (geom_x, geom_y)
    )

  if a is None:
    a = jnp.ones(num_a) / num_a
  if b is None:
    b = jnp.ones(num_b) / num_b

  return _sinkhorn_divergence(
      geom_xy,
      geom_x,
      geom_y,
      a=a,
      b=b,
      symmetric_sinkhorn=symmetric_sinkhorn,
      **sinkhorn_kwargs
  )

def _sinkhorn_divergence(
    geometry_xy: geometry.Geometry,
    geometry_xx: geometry.Geometry,
    geometry_yy: Optional[geometry.Geometry],
    a: jnp.ndarray,
    b: jnp.ndarray,
    symmetric_sinkhorn: bool,
    **kwargs: Any,
) -> SinkhornDivergenceOutput:
  """Compute the (unbalanced) Sinkhorn divergence for the wrapper function.

  The divergence is computed as::

    OT(a, b) - 0.5 * (OT(a, a) + OT(b, b))
      + 0.5 * eps * (sum(a) - sum(b)) ** 2

  Here ``OT`` is the regularized OT cost returned by Sinkhorn and ``eps`` is
  the epsilon of ``geometry_xy``. The last term is the correction depending
  on the total masses of each measure, as defined in :cite:`sejourne:19`,
  eq. 15.

  Args:
    geometry_xy: a Cost object able to apply kernels with a certain epsilon,
      between the views X and Y.
    geometry_xx: a Cost object able to apply kernels with a certain epsilon,
      between elements of the view X.
    geometry_yy: a Cost object able to apply kernels with a certain epsilon,
      between elements of the view Y. If ``None``, the ``y/y`` term is not
      solved and its regularized OT cost is taken to be 0.
    a: jnp.ndarray<float>[n]: the weight of each input point. The sum of
      all elements of ``b`` must match that of ``a`` to converge.
    b: jnp.ndarray<float>[m]: the weight of each target point. The sum of
      all elements of ``b`` must match that of ``a`` to converge.
    symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
      symmetric terms comparing x/x and y/y. Here these are run as parallel
      dual updates with a momentum of 0.5 and no Anderson acceleration.
    kwargs: Keyword arguments to :func:`~ott.solvers.linear.sinkhorn.solve`.

  Returns:
    SinkhornDivergenceOutput named tuple.
  """
  # When computing a Sinkhorn divergence, the (x,y) terms and (x,x) / (y,y)
  # terms are computed independently. The user-provided kwargs are used as-is
  # for the (x,y) term and are also reused for the (x,x) / (y,y) terms.
  # When ``symmetric_sinkhorn`` is set, the update scheme of the symmetric
  # terms is overridden with a simpler choice (parallel_dual_updates +
  # momentum 0.5, no Anderson acceleration) that is known to work well in
  # such settings.
  # Since symmetric terms are computed assuming a = b, the linear systems
  # arising in implicit differentiation (if used) of the potentials computed
  # for the symmetric parts should be marked as symmetric.
  kwargs_symmetric = kwargs.copy()
  if symmetric_sinkhorn:
    kwargs_symmetric.update(
        # Eq. 25 of feydy:19 averages each potential with its Sinkhorn update.
        # This needs both potentials to be updated simultaneously (parallel
        # updates) with a momentum of 0.5; alternating updates would not
        # implement it.
        parallel_dual_updates=True,
        momentum=acceleration.Momentum(start=0, value=0.5),
        anderson=None,
    )
  implicit_diff = kwargs.get("implicit_diff", None)
  if implicit_diff is not None:
    kwargs_symmetric["implicit_diff"] = implicit_diff.replace(symmetric=True)

  out_xy = sinkhorn.solve(geometry_xy, a, b, **kwargs)
  out_xx = sinkhorn.solve(geometry_xx, a, a, **kwargs_symmetric)
  if geometry_yy is None:
    # The y/y term is skipped: use a placeholder output with zero cost and an
    # empty error trace.
    out_yy = sinkhorn.SinkhornOutput(errors=jnp.array([]), reg_ot_cost=0.0)
  else:
    out_yy = sinkhorn.solve(geometry_yy, b, b, **kwargs_symmetric)

  div = (
      out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
      0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
  )
  out = (out_xy, out_xx, out_yy)
  return SinkhornDivergenceOutput(
      div, tuple([s.f, s.g] for s in out),
      (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
      tuple(s.converged for s in out), a, b
  )

[docs]def segment_sinkhorn_divergence(
x: jnp.ndarray,
y: jnp.ndarray,
num_segments: Optional[int] = None,
max_measure_size: Optional[int] = None,
cost_fn: Optional[costs.CostFn] = None,
segment_ids_x: Optional[jnp.ndarray] = None,
segment_ids_y: Optional[jnp.ndarray] = None,
indices_are_sorted: bool = False,
num_per_segment_x: Optional[Tuple[int, ...]] = None,
num_per_segment_y: Optional[Tuple[int, ...]] = None,
weights_x: Optional[jnp.ndarray] = None,
weights_y: Optional[jnp.ndarray] = None,
sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}),
static_b: bool = False,
share_epsilon: bool = True,
symmetric_sinkhorn: bool = False,
**kwargs: Any
) -> jnp.ndarray:
"""Compute Sinkhorn divergence between subsets of vectors in `x` and `y`.

Helper function designed to compute Sinkhorn divergences between several point
clouds of varying size, in parallel, using padding for efficiency.
In practice, The inputs `x` and `y` (and their weight vectors `weights_x` and
`weights_y`) are assumed to be large weighted point clouds, that describe
points taken from multiple measures. To extract several subsets of points, we
provide two interfaces. The first interface assumes that a vector of id's is
passed, describing for each point of `x` (resp. `y`) to which measure the
point belongs to. The second interface assumes that `x` and `y` were simply
formed by concatenating several measures contiguously, and that only indices
that segment these groups are needed to recover them.

For both interfaces, both `x` and `y` should contain the same total number of
segments. Each segment will be padded as necessary, all segments rearranged as
a tensor, and `vmap` used to evaluate Sinkhorn divergences in parallel.

Args:
x: Array of input points, of shape `[num_x, feature]`.
Multiple segments are held in this single array.
y: Array of target points, of shape `[num_y, feature]`.
num_segments: Number of segments contained in `x` and `y`.
Providing this is required for JIT compilation to work,
max_measure_size: Total size of measures after padding. Should ideally be
set to an upper bound on points clouds processed with the segment
interface. Should also be smaller than total length of `x` or `y`.
Providing this is required for JIT compilation to work.
cost_fn: Cost function,
defaults to :class:`~ott.geometry.costs.SqEuclidean`.
segment_ids_x: **1st interface** The segment ID for which each row of `x`
belongs. This is a similar interface to :func:`jax.ops.segment_sum`.
segment_ids_y: **1st interface** The segment ID for which each row of `y`
belongs.
indices_are_sorted: **1st interface** Whether `segment_ids_x` and
`segment_ids_y` are sorted.
num_per_segment_x: **2nd interface** Number of points in each segment in
`x`. For example, [100, 20, 30] would imply that `x` is segmented into
three arrays of length ``, ``, and `` respectively.
num_per_segment_y: **2nd interface** Number of points in each segment in
`y`.
weights_x: Weights of each input points, arranged in the same segmented
order as `x`.
weights_y: Weights of each input points, arranged in the same segmented
order as `y`.
sinkhorn_kwargs: Optionally a dict containing the keywords arguments for
calls to the `sinkhorn` function, called three times to evaluate for each
segment the Sinkhorn regularized OT cost between `x`/`y`, `x`/`x`, and
`y`/`y` (except when `static_b` is `True`, in which case `y`/`y` is not
evaluated)
static_b: if True, divergence of measure b against itself is NOT computed
share_epsilon: if True, enforces that the same epsilon regularizer is shared
for all 2 or 3 terms of the Sinkhorn divergence. In that case, the epsilon
will be by default that used when comparing x to y (contained in the first
geometry). This flag is set to True by default, because in the default
setting, the epsilon regularization is a function of the mean of the cost
matrix.
symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
symmetric terms comparing x/x and y/y.
kwargs: keywords arguments passed to form
:class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the
subsets of points and masses selected in `x` and `y`, this could be for
instance entropy regularization float, scheduler or normalization.

Returns:
An array of Sinkhorn divergences for each segment.
"""
dim = x.shape
if cost_fn is None:
else:

def eval_fn(
) -> float:
return sinkhorn_divergence(
pointcloud.PointCloud,
sinkhorn_kwargs=sinkhorn_kwargs,
static_b=static_b,
share_epsilon=share_epsilon,
symmetric_sinkhorn=symmetric_sinkhorn,
cost_fn=cost_fn,
**kwargs
).divergence

return segment._segment_interface(
x,
y,
eval_fn,
num_segments=num_segments,
max_measure_size=max_measure_size,
segment_ids_x=segment_ids_x,
segment_ids_y=segment_ids_y,
indices_are_sorted=indices_are_sorted,
num_per_segment_x=num_per_segment_x,
num_per_segment_y=num_per_segment_y,
weights_x=weights_x,
weights_y=weights_y,