# Source code for ott.solvers.linear.discrete_barycenter

```# Copyright OTT-JAX
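#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.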
import functools
from typing import NamedTuple, Optional, Sequence

import jax
import jax.numpy as jnp

from ott.geometry import geometry
from ott.math import fixed_point_loop
from ott.problems.linear import barycenter_problem
from ott.solvers.linear import sinkhorn

__all__ = ["SinkhornBarycenterOutput", "FixedBarycenter"]

class SinkhornBarycenterOutput(NamedTuple):
  """Result of the discrete (fixed-support) Sinkhorn barycenter solver.

  In log-sum-exp mode ``f`` and ``g`` hold dual potentials; in kernel mode
  they hold the corresponding scalings.
  """

  # Potentials/scalings on the side of the input histograms, stacked along
  # the leading axis (one entry per input histogram).
  f: jnp.ndarray
  # Potentials/scalings on the side of the barycenter, stacked along the
  # leading axis (one entry per input histogram).
  g: jnp.ndarray
  # The barycenter itself, returned as a histogram (already exponentiated
  # when the solver ran in log-sum-exp mode).
  histogram: jnp.ndarray
  # Marginal errors recorded once every ``inner_iterations`` iterations.
  # Slots that were never reached keep their initial value ``-1``; slots for
  # which the error was not evaluated hold ``inf``.
  errors: jnp.ndarray

@jax.tree_util.register_pytree_node_class
class FixedBarycenter:
  """A Wasserstein barycenter solver for histograms on a common geometry.

  This solver uses a variant of the
  :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm proposed in
  :cite:`janati:20a` to compute the barycenter of various measures supported on
  the same (common to all) geometry. The geometry is assumed to be either
  symmetric, or to describe costs between a set of points and another. In that
  case all reference measures have support on the first measure, whereas the
  barycenter is supported on the second.

  Args:
    threshold: convergence threshold. The algorithm stops when the marginal
      violations of all transport plans computed for that barycenter go below
      that threshold.
    norm_error: norm used to compute marginal deviation.
    inner_iterations: number of iterations run before recomputing errors.
      Must be an integer: it is used as an integer divisor to size and index
      the error buffer, and is passed as a static argument to the jitted
      solver.
    min_iterations: number of iterations run without checking whether
      termination criterion is true.
    max_iterations: maximal number of iterations.
    lse_mode: sets computations in kernel (``False``) or log-sum-exp mode.
    debiased: uses debiasing correction to avoid blur due to entropic
      regularization.
  """

  def __init__(
      self,
      threshold: float = 1e-2,
      norm_error: int = 1,
      inner_iterations: int = 10,
      min_iterations: int = 0,
      max_iterations: int = 2000,
      lse_mode: bool = True,
      debiased: bool = False
  ):
    self.threshold = threshold
    self.norm_error = norm_error
    self.inner_iterations = inner_iterations
    self.min_iterations = min_iterations
    self.max_iterations = max_iterations
    self.lse_mode = lse_mode
    self.debiased = debiased

  def __call__(
      self,
      fixed_bp: barycenter_problem.FixedBarycenterProblem,
      dual_initialization: Optional[jnp.ndarray] = None,
  ) -> SinkhornBarycenterOutput:
    """Solve barycenter problem, possibly using clever initialization.

    Args:
      fixed_bp: Fixed barycenter problem.
      dual_initialization: Initial value for the g_v potential/scalings,
        one for each of the histograms described in ``fixed_bp``. If ``None``,
        use initialization from :cite:`cuturi:15`, eq. 3.6.

    Returns:
      The potentials/scalings, the barycenter histogram and the recorded
      errors.

    Raises:
      ValueError: If ``debiased`` is set but the geometry is not symmetric.
    """
    geom = fixed_bp.geom
    # Validate first so that an invalid configuration fails before any
    # array computation is performed.
    if self.debiased and not geom.is_symmetric:
      raise ValueError("Geometry must be symmetric to use debiased option.")

    a = fixed_bp.a
    weights = fixed_bp.weights
    num_a, num_b = geom.shape

    if dual_initialization is None:
      # initialization strategy from :cite:`cuturi:15`, (3.6).
      dual_initialization = geom.apply_cost(a.T, axis=0).T
      # Center the initialization around its weighted average across
      # histograms.
      dual_initialization -= jnp.average(
          dual_initialization, weights=weights, axis=0
      )[jnp.newaxis, :]

    # The jitted solver expects a (hashable) sequence of norms.
    norm_error = (self.norm_error,)
    return _discrete_barycenter(
        geom, a, weights, dual_initialization, self.threshold, norm_error,
        self.inner_iterations, self.min_iterations, self.max_iterations,
        self.lse_mode, self.debiased, num_a, num_b
    )

  def tree_flatten(self):
    """Flatten the solver: ``threshold`` is the only child, the rest is aux."""
    aux = vars(self).copy()
    aux.pop("threshold")
    return [
        self.threshold,
    ], aux

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Rebuild the solver from the output of :meth:`tree_flatten`."""
    return cls(**aux_data, threshold=children[0])

def _aggregate_barycenter(
    kernel_f_u: jnp.ndarray, weights: jnp.ndarray, d: jnp.ndarray,
    lse_mode: bool, debiased: bool
) -> jnp.ndarray:
  """Combine per-measure kernel applications into a barycenter estimate.

  Args:
    kernel_f_u: kernel applied to each measure's potential/scaling, stacked
      along the leading axis.
    weights: barycentric weights, one per measure.
    d: debiasing term; only used when ``debiased`` is ``True``.
    lse_mode: whether quantities live in the log domain.
    debiased: whether to apply the debiasing term ``d``.

  Returns:
    The weighted average of ``kernel_f_u`` (plus ``d`` if debiased) in
    log-sum-exp mode, or its weighted geometric mean (times ``d`` if
    debiased) in kernel mode.
  """
  if lse_mode:
    b = jnp.average(kernel_f_u, weights=weights, axis=0)
    return b + d if debiased else b
  b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)
  return b * d if debiased else b


@functools.partial(jax.jit, static_argnums=(5, 6, 7, 8, 9, 10, 11, 12))
def _discrete_barycenter(
    geom: geometry.Geometry, a: jnp.ndarray, weights: jnp.ndarray,
    dual_initialization: jnp.ndarray, threshold: float,
    norm_error: Sequence[int], inner_iterations: int, min_iterations: int,
    max_iterations: int, lse_mode: bool, debiased: bool, num_a: int, num_b: int
) -> SinkhornBarycenterOutput:
  """Jit'able function to compute discrete barycenters.

  All arguments from ``norm_error`` onwards are static under ``jax.jit``.

  Args:
    geom: geometry shared by all measures.
    a: input histograms, stacked along the leading axis.
    weights: barycentric weights, one per histogram.
    dual_initialization: initial potentials for ``g_v``; converted to
      scalings when ``lse_mode`` is ``False``.
    threshold: the loop stops once the last recorded error is not above it.
    norm_error: norms used by the marginal error; its length sets the number
      of columns of the returned ``errors`` array.
    inner_iterations: number of iterations between two error evaluations.
    min_iterations: iterations during which the error is not evaluated.
    max_iterations: maximal number of iterations.
    lse_mode: run in log-sum-exp (``True``) or kernel (``False``) mode.
    debiased: whether to apply the debiasing correction.
    num_a: size of the zero vector used in the log-domain debiasing update.
    num_b: size of the debiasing term ``d``.

  Returns:
    The final potentials/scalings, the barycenter histogram and the errors.
  """
  if lse_mode:
    f_u = jnp.zeros_like(a)
    g_v = dual_initialization

    def update_one(f, g, marginal, iteration):
      # NOTE(review): ``iteration`` is accepted but not forwarded to the
      # geometry here, whereas the debiasing update below does forward it.
      # Confirm whether this is intended.
      del iteration
      return geom.update_potential(f, g, jnp.log(marginal), axis=1)

    def apply_one(f_, g_, eps_):
      return geom.apply_lse_kernel(f_, g_, eps_, vec=None, axis=0)[0]
  else:
    f_u = jnp.ones_like(a)
    g_v = geom.scaling_from_potential(dual_initialization)

    def update_one(f, g, marginal, iteration):
      # Same remark as in the log-sum-exp branch regarding ``iteration``.
      del f, iteration
      return geom.update_scaling(g, marginal, axis=1)

    def apply_one(f_, g_, eps_):
      del g_
      return geom.apply_kernel(f_, eps_, axis=0)

  # Run the single-measure update/apply over all measures at once; the
  # iteration counter and epsilon are shared across measures.
  parallel_update = jax.vmap(update_one, in_axes=[0, 0, 0, None])
  parallel_apply = jax.vmap(apply_one, in_axes=[0, 0, None])

  # d below is as described in https://arxiv.org/abs/2006.02575. Note that
  # d should be considered to be equal to eps log(d) with those notations
  # if running in log-sum-exp mode.
  d = jnp.zeros((num_b,)) if lse_mode else jnp.ones((num_b,))

  errors_fn = jax.vmap(
      functools.partial(
          sinkhorn.marginal_error,
          geom=geom,
          axis=1,
          norm_error=norm_error,
          lse_mode=lse_mode
      ),
      in_axes=[0, 0, 0]
  )
  # One row per error evaluation; -1 marks rows that were never written.
  errors = -jnp.ones((max_iterations // inner_iterations + 1, len(norm_error)))

  const = (geom, a, weights)

  def cond_fn(iteration, const, state):  # pylint: disable=unused-argument
    errors = state[0]
    # Always run the first iteration, then continue while the most recently
    # recorded error is above the threshold.
    return jnp.logical_or(
        iteration == 0, errors[iteration // inner_iterations - 1, 0] > threshold
    )

  def body_fn(iteration, const, state, compute_error):
    geom, a, weights = const
    errors, d, f_u, g_v = state

    eps = geom._epsilon.at(iteration)  # pylint: disable=protected-access
    f_u = parallel_update(f_u, g_v, a, iteration)
    # kernel_f_u stands for K times potential u if running in scaling mode,
    # eps log K exp f / eps in lse mode.
    kernel_f_u = parallel_apply(f_u, g_v, eps)
    # b below is the running estimate for the barycenter if running in scaling
    # mode, eps log b if running in lse mode. When debiasing, the previous
    # value of d is already folded into b.
    b = _aggregate_barycenter(kernel_f_u, weights, d, lse_mode, debiased)

    if debiased:
      # Refresh d from the debiased estimate b and the previous d.
      if lse_mode:
        d = 0.5 * (
            d + geom.update_potential(
                jnp.zeros((num_a,)), d, b / eps, iteration=iteration, axis=0
            )
        )
      else:
        d = jnp.sqrt(d * geom.update_scaling(d, b, iteration=iteration, axis=0))

    if lse_mode:
      g_v = b[jnp.newaxis, :] - kernel_f_u
    else:
      g_v = b[jnp.newaxis, :] / kernel_f_u

    # re-compute error if compute_error is True, else set to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        jnp.mean(errors_fn(f_u, g_v, a)), jnp.inf
    )

    errors = errors.at[iteration // inner_iterations, :].set(err)
    return errors, d, f_u, g_v

  state = (errors, d, f_u, g_v)

  state = fixed_point_loop.fixpoint_iter_backprop(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, const,
      state
  )

  errors, d, f_u, g_v = state
  # Recompute the barycenter from the final potentials/scalings.
  kernel_f_u = parallel_apply(f_u, g_v, geom.epsilon)
  b = _aggregate_barycenter(kernel_f_u, weights, d, lse_mode, debiased)
  if lse_mode:
    # b is eps log(barycenter) in lse mode; map it back to a histogram.
    b = jnp.exp(b / geom.epsilon)
  return SinkhornBarycenterOutput(f_u, g_v, b, errors)
```