# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from types import MappingProxyType
from typing import Any, Mapping, Optional, Tuple, Type, Union
import jax
import jax.numpy as jnp
from ott import utils
from ott.geometry import costs, geometry, pointcloud, segment
from ott.problems.linear import linear_problem, potentials
from ott.solvers import linear
from ott.solvers.linear import acceleration, sinkhorn, sinkhorn_lr
# Public API of this module. ``sinkdiv`` is the documented point-cloud
# convenience wrapper, so it is exported next to the generic entry points.
__all__ = [
    "sinkdiv",
    "sinkhorn_divergence",
    "segment_sinkhorn_divergence",
    "SinkhornDivergenceOutput",
]

# Pair of dual potential vectors ``(f, g)`` taken from a solver output.
Potentials = Tuple[jnp.ndarray, jnp.ndarray]
# Triplet of low-rank factors ``(q, r, g)`` taken from a low-rank solver output.
Factors = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
[docs]
@utils.register_pytree_node
class SinkhornDivergenceOutput:  # noqa: D101
  r"""Container for everything produced by :func:`sinkhorn_divergence`.

  An instance stores both the problem definition and the solutions of the two
  or three regularized OT problems that are solved when evaluating a Sinkhorn
  divergence between two probability distributions.

  Args:
    divergence: value of the Sinkhorn divergence.
    geoms: three geometries describing the Sinkhorn divergence, of respective
      sizes ``[n, m], [n, n], [m, m]`` if their cost or kernel matrices were
      instantiated.
    a: first ``[n,]`` vector of marginal weights.
    b: second ``[m,]`` vector of marginal weights.
    potentials: three pairs of dual potential vectors, of sizes
      ``[n,], [m,]``, ``[n,], [n,]``, ``[m,], [m,]``, returned when the call
      to the :func:`~ott.solvers.linear.solve` solver to compute the divergence
      relies on a vanilla :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`
      solver. :obj:`None` otherwise.
    factors: three triplets of matrices, of sizes
      ``([n, rank], [m, rank], [rank,])``, ``([n, rank], [n, rank], [rank,])``
      and ``([m, rank], [m, rank], [rank,])``, returned when the call
      to the :func:`~ott.solvers.linear.solve` solver to compute the divergence
      relies on a low-rank
      :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solver.
      :obj:`None` otherwise.
    errors: triplet gathering the ``errors`` attribute of the solver output of
      each of the three problems run to compute the divergence.
    converged: triplet of booleans indicating the convergence of each of the
      three problems run to compute the divergence.
    n_iters: number of iterations keeping track of compute effort needed to
      complete each of the three terms in the divergence.
  """
  divergence: float
  geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry]
  a: jnp.ndarray
  b: jnp.ndarray
  potentials: Optional[Tuple[Potentials, Potentials, Potentials]]
  factors: Optional[Tuple[Factors, Factors, Factors]]
  errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
                Optional[jnp.ndarray]]
  converged: Tuple[bool, bool, bool]
  n_iters: Tuple[int, int, int]

  def to_dual_potentials(self) -> "potentials.EntropicPotentials":
    """Build dual potential functions, :cite:`pooladian:22`.

    The vectors stored in ``potentials`` are wrapped in a
    :class:`~ott.problems.linear.potentials.EntropicPotentials` object, which
    provides approximations to the optimal dual potential functions of the
    dual OT problem defined on the geometry stored in ``geoms[0]``. These
    correspond to Equation 8 in :cite:`pooladian:22`.

    Returns:
      The entropic potentials, including the debiasing terms taken from the
      two symmetric problems.
    """
    assert not self.is_low_rank, \
      "Dual potentials not available: divergence computed with low-rank solver."
    pots_xy, pots_xx, pots_yy = self.potentials
    f_xy, g_xy = pots_xy
    # Only one potential of each symmetric problem is needed for debiasing.
    f_x, _ = pots_xx
    _, g_y = pots_yy
    geom_xy = self.geoms[0]
    prob_xy = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)
    return potentials.EntropicPotentials(
        f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y
    )

  @property
  def is_low_rank(self) -> bool:
    """Whether low-rank factors, rather than potentials, are stored."""
    return self.factors is not None

  def tree_flatten(self):  # noqa: D102
    # The first six fields are children, in declaration order.
    children = [
        self.divergence, self.geoms, self.a, self.b, self.potentials,
        self.factors
    ]
    # Solver diagnostics travel as auxiliary data, keyed by field name.
    aux_data = {
        "errors": self.errors,
        "n_iters": self.n_iters,
        "converged": self.converged,
    }
    return children, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    # Children fill the leading positional fields, aux data the keyword ones.
    return cls(*children, **aux_data)
[docs]
def sinkdiv(
    x: jnp.ndarray,
    y: jnp.ndarray,
    *,
    cost_fn: Optional[costs.CostFn] = None,
    epsilon: Optional[float] = None,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, SinkhornDivergenceOutput]:
  """Compute the :term:`Sinkhorn divergence` between two point clouds.

  Thin convenience layer over
  :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` that fixes the
  geometry type to :class:`~ott.geometry.pointcloud.PointCloud`, so that two
  point clouds can be compared with any ground cost
  :class:`~ott.geometry.costs.CostFn`. All remaining options are documented in
  :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`.

  Args:
    x: Array of input points, of shape `[num_x, feature]`.
    y: Array of target points, of shape `[num_y, feature]`.
    cost_fn: cost function of interest.
    epsilon: entropic regularization.
    kwargs: keywords arguments passed on to the generic
      :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` function. Of
      notable interest are ``a`` and ``b`` weight vectors, ``static_b`` and
      ``offset_static_b`` which can be used to bypass the computations of the
      transport problem between points stored in ``y`` (possibly with weights
      ``b``) and themselves, and ``solve_kwargs`` to parameterize the linear
      OT solver.

  Returns:
    The Sinkhorn divergence value, and output object detailing computations.
  """
  # Arguments describing the point-cloud geometry; everything else is
  # forwarded untouched through ``kwargs``.
  geometry_kwargs = {"x": x, "y": y, "cost_fn": cost_fn, "epsilon": epsilon}
  return sinkhorn_divergence(
      pointcloud.PointCloud, **geometry_kwargs, **kwargs
  )
[docs]
def sinkhorn_divergence(
    geom: Type[geometry.Geometry],
    *args: Any,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
    solve_kwargs: Mapping[str, Any] = MappingProxyType({}),
    static_b: bool = False,
    offset_static_b: Optional[float] = None,
    share_epsilon: bool = True,
    symmetric_sinkhorn: bool = True,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, SinkhornDivergenceOutput]:
  r"""Compute :term:`Sinkhorn divergence` between two measures.

  The :term:`Sinkhorn divergence` between two measures :math:`\mu` and
  :math:`\nu` is obtained from three :class:`~ott.geometry.Geometry` objects,
  describing pairwise costs within points in :math:`\mu,\nu`,
  :math:`\mu,\mu`, and :math:`\nu,\nu` respectively.

  This is the most general interface: the caller names the type of
  :class:`~ott.geometry.Geometry` used to compare the measures, followed by
  the arguments from which the three instances are generated through that
  type's :meth:`~ott.geometry.geometry.Geometry.prepare_divergences` method.

  Args:
    geom: Type of the geometry.
    args: Positional arguments to
      :meth:`~ott.geometry.geometry.Geometry.prepare_divergences` that are
      specific to each geometry.
    a: the weight of each input point. Uniform weights are used when
      :obj:`None`.
    b: the weight of each target point. Uniform weights are used when
      :obj:`None`.
    solve_kwargs: keywords arguments for
      :func:`~ott.solvers.linear.solve` that is called either twice
      if ``static_b == True`` or three times when ``static_b == False``.
    static_b: if :obj:`True`, divergence of the second measure
      (with weights ``b``) to itself is **not** recomputed.
    offset_static_b: when ``static_b`` is :obj:`True`, use that value to offset
      computation. Useful when the value of the divergence of the second
      measure to itself is precomputed and not expected to change.
    share_epsilon: if True, enforces that the same epsilon regularizer is
      shared for all 2 or 3 terms of the Sinkhorn divergence. In that case,
      the epsilon will be by default that used when comparing x to y
      (contained in the first geometry). This flag is set to True by default,
      because in the default setting, the epsilon regularization is a function
      of the std of the entries in the cost matrix.
    symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
      symmetric terms comparing x/x and y/y.
    kwargs: keywords arguments to the generic class. This is specific to each
      geometry.

  Returns:
    The Sinkhorn divergence value, and output object detailing computations.
  """
  prepared = geom.prepare_divergences(*args, static_b=static_b, **kwargs)
  # Pad with ``None`` so that three names can always be unpacked, even when
  # fewer geometries were prepared.
  geom_xy, geom_x, geom_y, *_ = prepared + (None,) * 3
  num_a, num_b = geom_xy.shape

  if share_epsilon:
    # Align the symmetric geometries on the epsilon of the (x, y) one; missing
    # geometries are passed through unchanged.
    geom_x, geom_y = (
        g.copy_epsilon(geom_xy) if isinstance(g, geometry.Geometry) else g
        for g in (geom_x, geom_y)
    )

  # Fall back on uniform weights when none are supplied.
  if a is None:
    a = jnp.ones(num_a) / num_a
  if b is None:
    b = jnp.ones(num_b) / num_b

  out = _sinkhorn_divergence(
      geom_xy,
      geom_x,
      geom_y,
      a=a,
      b=b,
      symmetric_sinkhorn=symmetric_sinkhorn,
      offset_yy=offset_static_b,
      **solve_kwargs
  )
  return out.divergence, out
def _sinkhorn_divergence(
    geometry_xy: geometry.Geometry,
    geometry_xx: geometry.Geometry,
    geometry_yy: Optional[geometry.Geometry],
    a: jnp.ndarray,
    b: jnp.ndarray,
    symmetric_sinkhorn: bool,
    offset_yy: Optional[float],
    **kwargs: Any,
) -> SinkhornDivergenceOutput:
  """Compute the (unbalanced) Sinkhorn divergence for the wrapper function.

  The value includes a correction depending on the total masses of each
  measure, as defined in :cite:`sejourne:19`, eq. 15, and the computation is
  extended to also accept
  :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers, as advocated in
  :cite:`scetbon:22b`.

  Args:
    geometry_xy: a Cost object able to apply kernels with a certain epsilon,
      between the views X and Y.
    geometry_xx: a Cost object able to apply kernels with a certain epsilon,
      between elements of the view X.
    geometry_yy: a Cost object able to apply kernels with a certain epsilon,
      between elements of the view Y. When :obj:`None`, that problem is not
      solved and a placeholder output is used instead.
    a: array of size ``[n]``, the weight of each input point. The sum of
      all elements of ``b`` must match that of ``a`` to converge.
    b: array of size ``[m]``, the weight of each target point. The sum of
      all elements of ``b`` must match that of ``a`` to converge.
    symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
      symmetric terms comparing x/x and y/y.
    offset_yy: when available, regularized OT cost precomputed on
      ``geometry_yy`` cost when transporting weight vector ``b`` onto itself.
    kwargs: Keyword arguments to :func:`~ott.solvers.linear.solve`.

  Returns:
    The output object
  """
  # A strictly positive ``rank`` keyword selects the low-rank code path.
  is_low_rank = kwargs.get("rank", -1) > 0
  kwargs_symmetric = dict(kwargs)

  if symmetric_sinkhorn and not is_low_rank:
    # When computing a Sinkhorn divergence, the (x,y) terms and (x,x) / (y,y)
    # terms are computed independently. The user might want to pass some
    # kwargs to parameterize the solver's behavior, but those should
    # only apply to the (x,y) part.
    #
    # When using the Sinkhorn solver, for the (x,x) / (y,y) part, we fall back
    # on a simpler choice (parallel_dual_updates + momentum 0.5) that is known
    # to work well in such settings.
    kwargs_symmetric["parallel_dual_updates"] = True
    kwargs_symmetric["momentum"] = acceleration.Momentum(start=0, value=0.5)
    kwargs_symmetric["anderson"] = None

    # Since symmetric terms are computed assuming a = b, the linear systems
    # arising in implicit differentiation (if used) of the potentials computed
    # for the symmetric parts should be marked as symmetric.
    # NOTE(review): nesting of this step under the symmetric branch is
    # inferred from the comment block above - confirm against the repository.
    implicit_diff = kwargs.get("implicit_diff")
    if implicit_diff is not None:
      kwargs_symmetric["implicit_diff"] = implicit_diff.replace(symmetric=True)

  out_xy = linear.solve(geometry_xy, a=a, b=b, **kwargs)
  out_xx = linear.solve(geometry_xx, a=a, b=a, **kwargs_symmetric)
  if geometry_yy is not None:
    out_yy = linear.solve(geometry_yy, a=b, b=b, **kwargs_symmetric)
  else:
    # Create dummy output, corresponds to scenario where static_b is True.
    out_yy = _empty_output(is_low_rank, offset_yy)
  outputs = (out_xy, out_xx, out_yy)

  # Epsilon only scales the mass correction; no gradient flows through it.
  eps = jax.lax.stop_gradient(geometry_xy.epsilon)
  self_terms = out_xx.reg_ot_cost + out_yy.reg_ot_cost
  mass_gap = jnp.sum(a) - jnp.sum(b)
  div = out_xy.reg_ot_cost - 0.5 * self_terms + 0.5 * eps * mass_gap ** 2

  # Exactly one of ``pots`` / ``factors`` is populated, depending on solver.
  if is_low_rank:
    pots = None
    factors = tuple((out.q, out.r, out.g) for out in outputs)
  else:
    pots = tuple((out.f, out.g) for out in outputs)
    factors = None

  return SinkhornDivergenceOutput(
      divergence=div,
      geoms=(geometry_xy, geometry_xx, geometry_yy),
      a=a,
      b=b,
      potentials=pots,
      factors=factors,
      errors=tuple(out.errors for out in outputs),
      converged=tuple(out.converged for out in outputs),
      n_iters=tuple(out.n_iters for out in outputs),
  )
[docs]
def segment_sinkhorn_divergence(
    x: jnp.ndarray,
    y: jnp.ndarray,
    num_segments: Optional[int] = None,
    max_measure_size: Optional[int] = None,
    cost_fn: Optional[costs.CostFn] = None,
    segment_ids_x: Optional[jnp.ndarray] = None,
    segment_ids_y: Optional[jnp.ndarray] = None,
    indices_are_sorted: bool = False,
    num_per_segment_x: Optional[Tuple[int, ...]] = None,
    num_per_segment_y: Optional[Tuple[int, ...]] = None,
    weights_x: Optional[jnp.ndarray] = None,
    weights_y: Optional[jnp.ndarray] = None,
    solve_kwargs: Mapping[str, Any] = MappingProxyType({}),
    static_b: bool = False,
    share_epsilon: bool = True,
    symmetric_sinkhorn: bool = False,
    **kwargs: Any
) -> jnp.ndarray:
  """Compute Sinkhorn divergence between subsets of vectors in `x` and `y`.

  Helper designed to evaluate Sinkhorn divergences between several point
  clouds of varying size, in parallel, using padding for efficiency.

  The inputs `x` and `y` (and their weight vectors `weights_x` and
  `weights_y`) are assumed to be large weighted point clouds that describe
  points taken from multiple measures. Two interfaces are offered to extract
  the subsets of points. With the first one, a vector of ids is passed,
  describing for each point of `x` (resp. `y`) the measure it belongs to. With
  the second one, `x` and `y` are assumed to be formed by concatenating
  several measures contiguously, so that only the sizes of these groups are
  needed to recover them.

  For both interfaces, `x` and `y` should contain the same total number of
  segments. Each segment will be padded as necessary, all segments rearranged
  as a tensor, and `vmap` used to evaluate Sinkhorn divergences in parallel.

  Args:
    x: Array of input points, of shape `[num_x, feature]`.
      Multiple segments are held in this single array.
    y: Array of target points, of shape `[num_y, feature]`.
    num_segments: Number of segments contained in `x` and `y`.
      Providing this is required for JIT compilation to work,
      see also :func:`~ott.geometry.segment.segment_point_cloud`.
    max_measure_size: Total size of measures after padding. Should ideally be
      set to an upper bound on points clouds processed with the segment
      interface. Should also be smaller than total length of `x` or `y`.
      Providing this is required for JIT compilation to work.
    cost_fn: Cost function,
      defaults to :class:`~ott.geometry.costs.SqEuclidean`.
    segment_ids_x: **1st interface** The segment ID to which each row of `x`
      belongs. This is a similar interface to :func:`jax.ops.segment_sum`.
    segment_ids_y: **1st interface** The segment ID to which each row of `y`
      belongs.
    indices_are_sorted: **1st interface** Whether `segment_ids_x` and
      `segment_ids_y` are sorted.
    num_per_segment_x: **2nd interface** Number of points in each segment in
      `x`. For example, [100, 20, 30] would imply that `x` is segmented into
      three arrays of length `[100]`, `[20]`, and `[30]` respectively.
    num_per_segment_y: **2nd interface** Number of points in each segment in
      `y`.
    weights_x: Weights of each input point, arranged in the same segmented
      order as `x`.
    weights_y: Weights of each target point, arranged in the same segmented
      order as `y`.
    solve_kwargs: Optionally a dict containing the keywords arguments
      forwarded, through :func:`sinkhorn_divergence`, to the linear solver
      :func:`~ott.solvers.linear.solve`, called three times to evaluate for
      each segment the Sinkhorn regularized OT cost between `x`/`y`, `x`/`x`,
      and `y`/`y` (except when `static_b` is `True`, in which case `y`/`y` is
      not evaluated).
    static_b: if True, divergence of measure b against itself is NOT computed.
    share_epsilon: if True, enforces that the same epsilon regularizer is
      shared for all 2 or 3 terms of the Sinkhorn divergence. In that case,
      the epsilon will be by default that used when comparing x to y
      (contained in the first geometry). This flag is set to True by default,
      because in the default setting, the epsilon regularization is a function
      of the entries of the cost matrix.
    symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
      symmetric terms comparing x/x and y/y.
    kwargs: keywords arguments passed to form
      :class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the
      subsets of points and masses selected in `x` and `y`, this could be for
      instance entropy regularization float, scheduler or normalization.

  Returns:
    An array of Sinkhorn divergence values for each segment.
  """
  # The padding vector comes from the user's cost function when there is one,
  # and from the base cost class otherwise.
  padder_source = costs.CostFn if cost_fn is None else cost_fn
  padding_vector = padder_source._padder(dim=x.shape[1])

  def eval_fn(
      padded_x: jnp.ndarray,
      padded_y: jnp.ndarray,
      padded_weight_x: jnp.ndarray,
      padded_weight_y: jnp.ndarray,
  ) -> float:
    # Evaluate one padded pair of segments; only the scalar value is kept.
    return sinkhorn_divergence(
        pointcloud.PointCloud,
        padded_x,
        padded_y,
        a=padded_weight_x,
        b=padded_weight_y,
        solve_kwargs=solve_kwargs,
        static_b=static_b,
        share_epsilon=share_epsilon,
        symmetric_sinkhorn=symmetric_sinkhorn,
        cost_fn=cost_fn,
        **kwargs,
    )[0]

  return segment._segment_interface(
      x,
      y,
      eval_fn,
      num_segments=num_segments,
      max_measure_size=max_measure_size,
      segment_ids_x=segment_ids_x,
      segment_ids_y=segment_ids_y,
      indices_are_sorted=indices_are_sorted,
      num_per_segment_x=num_per_segment_x,
      num_per_segment_y=num_per_segment_y,
      weights_x=weights_x,
      weights_y=weights_y,
      padding_vector=padding_vector
  )
def _empty_output(
    is_low_rank: bool,
    offset_yy: Optional[float] = None
) -> Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]:
  """Build a placeholder solver output for a problem that was not solved.

  Args:
    is_low_rank: whether to mimic the output of the low-rank solver rather
      than that of the vanilla Sinkhorn solver.
    offset_yy: value stored as ``reg_ot_cost`` of the placeholder; ``0.0`` is
      used when :obj:`None`.

  Returns:
    A solver output of the requested kind, carrying no potentials or factors,
    a single ``-inf`` error entry and the requested ``reg_ot_cost``.
  """
  # Both placeholder kinds share the same regularized OT cost.
  reg_ot_cost = 0.0 if offset_yy is None else offset_yy

  if not is_low_rank:
    return sinkhorn.SinkhornOutput(
        potentials=(None, None),
        errors=jnp.array([-jnp.inf]),
        reg_ot_cost=reg_ot_cost,
        threshold=0.0,
        inner_iterations=0,
    )

  return sinkhorn_lr.LRSinkhornOutput(
      q=None,
      r=None,
      g=None,
      ot_prob=None,
      epsilon=None,
      inner_iterations=0,
      converged=True,
      costs=jnp.array([-jnp.inf]),
      errors=jnp.array([-jnp.inf]),
      reg_ot_cost=reg_ot_cost,
  )