# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Several cost/norm functions for relevant vector types."""
import abc
import functools
import math
from typing import Any, Callable, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from ott.math import fixed_point_loop, matrix_square_root
# Public API of this module; subclass cost functions of `CostFn`.
__all__ = [
"PNormP", "SqPNorm", "Euclidean", "SqEuclidean", "Cosine", "Bures",
"UnbalancedBures"
]
@jax.tree_util.register_pytree_node_class
class CostFn(abc.ABC):
  """A generic cost function, taking two vectors as input.

  Cost functions evaluate a function on a pair of inputs. For convenience,
  that function is split into two norms -- evaluated on each input separately
  -- followed by a pairwise cost that involves both inputs, as in:

  ``c(x,y) = norm(x) + norm(y) + pairwise(x,y)``

  If the :attr:`norm` function is not implemented, that value is handled as a
  0, and only :attr:`pairwise` is used.
  """

  # No norm function is created by default; subclasses may set a callable
  # evaluated on each input separately.
  norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None

  @abc.abstractmethod
  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute the part of the cost that involves both inputs jointly."""

  def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
    """Barycentric operator.

    Args:
      weights: Convex set of weights.
      xs: Points.

    Returns:
      The barycenter of `xs` using `weights` coefficients.
    """
    raise NotImplementedError("Barycenter is not yet implemented.")

  @classmethod
  def _padder(cls, dim: int) -> jnp.ndarray:
    """Create a padding vector of adequate dimension, well-suited to a cost.

    Args:
      dim: Dimensionality of the data.

    Returns:
      The padding vector, of shape ``[1, dim]``.
    """
    return jnp.zeros((1, dim))

  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Evaluate full cost: pairwise term plus, when defined, the two norms."""
    cost = self.pairwise(x, y)
    # A missing norm is treated as contributing 0 to the cost.
    if self.norm is None:
      return cost
    return cost + self.norm(x) + self.norm(y)

  def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute matrix of all costs (including norms) for vectors in x / y.

    Args:
      x: [num_a, d] jnp.ndarray
      y: [num_b, d] jnp.ndarray

    Returns:
      [num_a, num_b] matrix of cost evaluations.
    """
    return jax.vmap(lambda x_: jax.vmap(lambda y_: self(x_, y_))(y))(x)

  def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute matrix of all pairwise-costs (no norms) for vectors in x / y.

    Args:
      x: [num_a, d] jnp.ndarray
      y: [num_b, d] jnp.ndarray

    Returns:
      [num_a, num_b] matrix of pairwise cost evaluations.
    """
    return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x)

  def tree_flatten(self):  # noqa: D102
    # Stateless by default: no children, no auxiliary data.
    return (), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del aux_data
    return cls(*children)
@jax.tree_util.register_pytree_node_class
class TICost(CostFn):
  """A class for translation invariant (TI) costs.

  Such costs are defined using a function :math:`h`, mapping vectors to
  real-values, to be used as:

  .. math::
    c(x,y) = h(z), z := x-y.

  If that cost function is used to form an Entropic map using the
  :cite:`brenier:91` theorem, then the user should ensure :math:`h` is
  strictly convex, as well as provide the Legendre transform of :math:`h`,
  whose gradient is necessarily the inverse of the gradient of :math:`h`.
  """

  @abc.abstractmethod
  def h(self, z: jnp.ndarray) -> float:
    """TI function acting on difference of :math:`x-y` to output cost."""

  def h_legendre(self, z: jnp.ndarray) -> float:
    """Legendre transform of :func:`h` when it is convex."""
    raise NotImplementedError("`h_legendre` not implemented.")

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute cost as evaluation of :func:`h` on :math:`x-y`."""
    return self.h(x - y)
@jax.tree_util.register_pytree_node_class
class SqPNorm(TICost):
  """Squared p-norm of the difference of two vectors.

  For details on the derivation of the Legendre transform of the norm, see
  e.g. the reference :cite:`boyd:04`, p.93/94.

  Args:
    p: Power of the p-norm, must be ``>= 1.0``.
  """

  def __init__(self, p: float):
    super().__init__()
    assert p >= 1.0, "p parameter in sq. p-norm should be >= 1.0"
    self.p = p
    # Conjugate (Hölder) exponent q, with 1/p + 1/q = 1; infinite when p = 1.
    self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf

  def h(self, z: jnp.ndarray) -> float:
    """Half of the squared p-norm of ``z``."""
    return 0.5 * jnp.linalg.norm(z, self.p) ** 2

  def h_legendre(self, z: jnp.ndarray) -> float:
    """Legendre transform of :func:`h`: half squared q-norm, q conjugate."""
    return 0.5 * jnp.linalg.norm(z, self.q) ** 2

  def tree_flatten(self):  # noqa: D102
    # `p` fully determines the cost; store it as auxiliary (static) data.
    return (), (self.p,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del children
    return cls(aux_data[0])
@jax.tree_util.register_pytree_node_class
class PNormP(TICost):
  """p-norm to the power p (divided by p) of the difference of two vectors.

  Args:
    p: Power of the p-norm, a finite float larger than 1.0.
  """

  def __init__(self, p: float):
    super().__init__()
    assert p >= 1.0, "p parameter in p-norm should be larger than 1.0"
    assert p < jnp.inf, "p parameter in p-norm should be finite"
    self.p = p
    # Conjugate (Hölder) exponent q, with 1/p + 1/q = 1; infinite when p = 1.
    self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf

  def h(self, z: jnp.ndarray) -> float:
    """p-norm of ``z`` raised to power p, divided by p."""
    return jnp.linalg.norm(z, self.p) ** self.p / self.p

  def h_legendre(self, z: jnp.ndarray) -> float:
    """Legendre transform of :func:`h`: q-norm to power q, divided by q."""
    # For p = 1 the transform is an indicator function, not representable here.
    assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`"
    return jnp.linalg.norm(z, self.q) ** self.q / self.q

  def tree_flatten(self):  # noqa: D102
    # `p` fully determines the cost; store it as auxiliary (static) data.
    return (), (self.p,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del children
    return cls(aux_data[0])
@jax.tree_util.register_pytree_node_class
class Euclidean(CostFn):
  """Euclidean distance.

  Note that the Euclidean distance is not cast as a
  :class:`~ott.geometry.costs.TICost`, since this would correspond to
  :math:`h` being :func:`jax.numpy.linalg.norm`, whose gradient is not
  invertible, because the function is not strictly convex (it is linear on
  rays).
  """

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute Euclidean norm of the difference of two vectors."""
    return jnp.linalg.norm(x - y)
@jax.tree_util.register_pytree_node_class
class SqEuclidean(TICost):
  """Squared Euclidean distance.

  Decomposed as ``|x|^2 + |y|^2 - 2 <x, y>`` via :meth:`norm` and
  :meth:`pairwise`; also available in TI form via :meth:`h`.
  """

  def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
    """Compute squared Euclidean norm for vector."""
    return jnp.sum(x ** 2, axis=-1)

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute minus twice the dot-product between vectors."""
    return -2. * jnp.vdot(x, y)

  def h(self, z: jnp.ndarray) -> float:
    """TI function: squared Euclidean norm of the difference."""
    return jnp.sum(z ** 2)

  def h_legendre(self, z: jnp.ndarray) -> float:
    """Legendre transform of :func:`h`: quarter of the squared norm."""
    return 0.25 * jnp.sum(z ** 2)

  def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
    """Output barycenter of vectors when using squared-Euclidean distance."""
    return jnp.average(xs, weights=weights, axis=0)
@jax.tree_util.register_pytree_node_class
class Cosine(CostFn):
  """Cosine distance cost function.

  Args:
    ridge: Ridge regularization added to the denominator, guarding against
      division by zero for zero-norm inputs.
  """

  def __init__(self, ridge: float = 1e-8):
    super().__init__()
    self._ridge = ridge

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Cosine distance between vectors, denominator regularized with ridge."""
    ridge = self._ridge
    x_norm = jnp.linalg.norm(x, axis=-1)
    y_norm = jnp.linalg.norm(y, axis=-1)
    cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
    cosine_distance = 1.0 - cosine_similarity
    # similarity is in [-1, 1], clip because of numerical imprecisions
    return jnp.clip(cosine_distance, 0., 2.)

  @classmethod
  def _padder(cls, dim: int) -> jnp.ndarray:
    """Pad with ones; a zero vector has no direction for the cosine cost."""
    return jnp.ones((1, dim))

  def tree_flatten(self):  # noqa: D102
    # Fix: the inherited flatten dropped `ridge`, so a pytree round-trip
    # silently reset a custom ridge to the default value.
    return (), (self._ridge,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del children
    return cls(aux_data[0])
@jax.tree_util.register_pytree_node_class
class Bures(CostFn):
  """Bures distance between a pair of (mean, cov matrix) raveled as vectors.

  Args:
    dimension: Dimensionality of the data.
    kwargs: Keyword arguments for :func:`ott.math.matrix_square_root.sqrtm`.
  """

  def __init__(self, dimension: int, **kwargs: Any):
    super().__init__()
    self._dimension = dimension
    self._sqrtm_kw = kwargs

  def norm(self, x: jnp.ndarray) -> jnp.ndarray:
    """Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance."""
    mean, cov = x_to_means_and_covs(x, self._dimension)
    norm = jnp.sum(mean ** 2, axis=-1)
    norm += jnp.trace(cov, axis1=-2, axis2=-1)
    return norm

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute - 2 x Bures dot-product."""
    mean_x, cov_x = x_to_means_and_covs(x, self._dimension)
    mean_y, cov_y = x_to_means_and_covs(y, self._dimension)
    mean_dot_prod = jnp.vdot(mean_x, mean_y)
    # `sqrtm` returns (sqrt, inverse sqrt, errors); only the sqrt is needed.
    sq_x = matrix_square_root.sqrtm(cov_x, self._dimension, **self._sqrtm_kw)[0]
    sq_x_y_sq_x = jnp.matmul(sq_x, jnp.matmul(cov_y, sq_x))
    sq__sq_x_y_sq_x = matrix_square_root.sqrtm(
        sq_x_y_sq_x, self._dimension, **self._sqrtm_kw
    )[0]
    return -2 * (mean_dot_prod + jnp.trace(sq__sq_x_y_sq_x, axis1=-2, axis2=-1))

  def covariance_fixpoint_iter(
      self,
      covs: jnp.ndarray,
      weights: jnp.ndarray,
      tolerance: float = 1e-4,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Iterate fix-point updates to compute barycenter of Gaussians.

    Args:
      covs: [batch, d^2] covariance matrices
      weights: simplicial weights (nonnegative, sum to 1)
      tolerance: tolerance of the overall fixed-point procedure
      kwargs: parameters passed on to the sqrtm (Newton-Schulz) algorithm to
        compute matrix square roots.

    Returns:
      a covariance matrix, the weighted Bures average of the covs matrices.
    """

    @functools.partial(jax.vmap, in_axes=[None, 0, 0])
    def scale_covariances(
        cov_sqrt: jnp.ndarray, cov: jnp.ndarray, weight: jnp.ndarray
    ) -> jnp.ndarray:
      """Rescale covariance in barycenter step."""
      return weight * matrix_square_root.sqrtm_only((cov_sqrt @ cov) @ cov_sqrt,
                                                    **kwargs)

    def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool:
      """Continue while the last update moved by more than `tolerance`."""
      del iteration, constants
      _, diff = state
      return diff > tolerance

    def body_fn(
        iteration: int, constants: Tuple[Any, ...],
        state: Tuple[jnp.ndarray, float], compute_error: bool
    ) -> Tuple[jnp.ndarray, float]:
      """Perform one fixed-point update of the barycenter covariance."""
      del iteration, constants, compute_error
      cov, _ = state
      cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **kwargs)
      scaled_cov = jnp.linalg.matrix_power(
          jnp.sum(scale_covariances(cov_sqrt, covs, weights), axis=0), 2
      )
      next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt
      # Mean squared entrywise change, used as the convergence criterion.
      diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape))
      return next_cov, diff

    def init_state() -> Tuple[jnp.ndarray, float]:
      """Start from the identity covariance, with infinite initial error."""
      cov_init = jnp.eye(self._dimension)
      diff = jnp.inf
      return cov_init, diff

    # TODO(marcocuturi): ideally the integer parameters below should be passed
    # by user, if one wants more fine grained control. This could clash with
    # the parameters passed on to :func:`ott.math.matrix_square_root.sqrtm` by
    # the barycenter call. At the moment, only `tolerance` can be used to
    # control computational effort.
    cov, _ = fixed_point_loop.fixpoint_iter(
        cond_fn=cond_fn,
        body_fn=body_fn,
        min_iterations=1,
        max_iterations=500,
        inner_iterations=1,
        constants=(),
        state=init_state()
    )
    return cov

  def barycenter(
      self, weights: jnp.ndarray, xs: jnp.ndarray, **kwargs: Any
  ) -> jnp.ndarray:
    """Compute the Bures barycenter of weighted Gaussian distributions.

    Implements the fixed point approach proposed in :cite:`alvarez-esteban:16`
    for the computation of the mean and the covariance of the barycenter of
    weighted Gaussian distributions.

    Args:
      weights: The barycentric weights.
      xs: The points to be used in the computation of the barycenter, where
        each point is described by a concatenation of the mean and the
        covariance (raveled).
      kwargs: Passed on to :meth:`covariance_fixpoint_iter`, and by extension
        to :func:`ott.math.matrix_square_root.sqrtm`. Note that `tolerance`
        is used for the fixed-point iteration of the barycenter, whereas
        `threshold` will apply to the fixed-point iteration of Newton-Schulz
        iterations.

    Returns:
      A concatenation of the mean and the raveled covariance of the
      barycenter.
    """
    # Ensure that barycentric weights sum to 1.
    weights = weights / jnp.sum(weights)
    mus, covs = x_to_means_and_covs(xs, self._dimension)
    # The barycentric mean is simply the weighted average of the means.
    mu_bary = jnp.sum(weights[:, None] * mus, axis=0)
    cov_bary = self.covariance_fixpoint_iter(
        covs=covs, weights=weights, **kwargs
    )
    barycenter = mean_and_cov_to_x(mu_bary, cov_bary, self._dimension)
    return barycenter

  @classmethod
  def _padder(cls, dim: int) -> jnp.ndarray:
    """Pad with concatenated zero mean and raveled identity covariance."""
    # `dim` equals d * (1 + d); recover d via the quadratic formula.
    dimension = int((-1 + math.sqrt(1 + 4 * dim)) / 2)
    padding = mean_and_cov_to_x(
        jnp.zeros((dimension,)), jnp.eye(dimension), dimension
    )
    return padding[jnp.newaxis, :]

  def tree_flatten(self):  # noqa: D102
    return (), (self._dimension, self._sqrtm_kw)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del children
    return cls(aux_data[0], **aux_data[1])
@jax.tree_util.register_pytree_node_class
class UnbalancedBures(CostFn):
  """Unbalanced Bures distance between two triplets of `(mass, mean, cov)`.

  This cost uses the notation defined in :cite:`janati:20`, eq. 37, 39, 40.

  Args:
    dimension: Dimensionality of the data.
    sigma: Entropic regularization.
    gamma: KL-divergence regularization for the marginals.
    kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`.
  """

  def __init__(
      self,
      dimension: int,
      *,
      sigma: float = 1.0,
      gamma: float = 1.0,
      **kwargs: Any,
  ):
    super().__init__()
    self._dimension = dimension
    self._sigma = sigma
    self._gamma = gamma
    self._sqrtm_kw = kwargs

  def norm(self, x: jnp.ndarray) -> jnp.ndarray:
    """Compute norm of Gaussian for unbalanced Bures.

    Args:
      x: Array of shape ``[n_points + n_points + n_dim ** 2,]``, potentially
        batched, corresponding to the raveled mass, means and the covariance
        matrix.

    Returns:
      The norm, array of shape ``[]`` or ``[batch,]`` in the batched case.
    """
    return self._gamma * x[..., 0]

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Compute dot-product for unbalanced Bures.

    Args:
      x: Array of shape ``[n_points + n_points + n_dim ** 2,]``
        corresponding to the raveled mass, means and the covariance matrix.
      y: Array of shape ``[n_points + n_points + n_dim ** 2,]``
        corresponding to the raveled mass, means and the covariance matrix.

    Returns:
      The cost.
    """
    # Sets a few constants
    gam = self._gamma
    sig2 = self._sigma ** 2
    lam = sig2 + gam / 2.0
    tau = gam / (2.0 * lam)
    # Extracts mass, mean vector, covariance matrices
    mass_x, mass_y = x[0], y[0]
    mean_x, cov_x = x_to_means_and_covs(x[1:], self._dimension)
    mean_y, cov_y = x_to_means_and_covs(y[1:], self._dimension)
    diff_means = mean_x - mean_y
    # Identity matrix of suitable size
    iden = jnp.eye(self._dimension, dtype=x.dtype)
    # Creates matrices needed in the computation
    tilde_a = 0.5 * gam * (iden - lam * jnp.linalg.inv(cov_x + lam * iden))
    tilde_b = 0.5 * gam * (iden - lam * jnp.linalg.inv(cov_y + lam * iden))
    tilde_a_b = jnp.matmul(tilde_a, tilde_b)
    c_mat = matrix_square_root.sqrtm(
        1 / tau * tilde_a_b + 0.25 * (sig2 ** 2) * iden, **self._sqrtm_kw
    )[0]
    c_mat -= 0.5 * sig2 * iden
    # Computes log determinants (their sign should be >0).
    sldet_c, ldet_c = jnp.linalg.slogdet(c_mat)
    sldet_t_ab, ldet_t_ab = jnp.linalg.slogdet(tilde_a_b)
    sldet_ab, ldet_ab = jnp.linalg.slogdet(jnp.matmul(cov_x, cov_y))
    sldet_c_ab, ldet_c_ab = jnp.linalg.slogdet(c_mat - 2.0 * tilde_a_b / gam)
    # Gathers all these results to compute log total mass of transport
    log_m_pi = (0.5 * self._dimension * sig2 / (gam + sig2)) * jnp.log(sig2)
    log_m_pi += (1.0 / (tau + 1.0)) * (
        jnp.log(mass_x) + jnp.log(mass_y) + ldet_c + 0.5 *
        (tau * ldet_t_ab - ldet_ab)
    )
    log_m_pi += -jnp.sum(
        diff_means * jnp.linalg.solve(cov_x + cov_y + lam * iden, diff_means)
    ) / (2.0 * (tau + 1.0))
    log_m_pi += -0.5 * ldet_c_ab
    # If all four logdet signs are 1, output value, NaN otherwise.
    # Bug fix: the original summed `sldet_t_ab` twice and never checked
    # `sldet_ab`, so a negative-determinant `cov_x @ cov_y` went undetected.
    pos_signs = (sldet_c + sldet_c_ab + sldet_ab + sldet_t_ab) == 4
    return jax.lax.cond(
        pos_signs, lambda: 2 * sig2 * mass_x * mass_y - 2 *
        (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan
    )

  def tree_flatten(self):  # noqa: D102
    return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del children
    dim, sigma, gamma, kwargs = aux_data
    return cls(dim, sigma=sigma, gamma=gamma, **kwargs)
def x_to_means_and_covs(x: jnp.ndarray,
                        dimension: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Extract means and covariance matrices of Gaussians from raveled vector.

  Args:
    x: [num_gaussians, dimension * (1 + dimension)] array of concatenated
      means and covariances (raveled); a single raveled vector of size
      ``dimension * (1 + dimension)`` is also accepted.
    dimension: Dimension of the Gaussians.

  Returns:
    means: [num_gaussians, dimension] array that holds the means.
    covariances: [num_gaussians, dimension, dimension] array that holds the
      covariance matrices. Singleton leading axes are squeezed out.
  """
  x = jnp.atleast_2d(x)
  means = x[:, :dimension]
  covariances = jnp.reshape(
      x[:, dimension:dimension + dimension ** 2], (-1, dimension, dimension)
  )
  return jnp.squeeze(means), jnp.squeeze(covariances)
def mean_and_cov_to_x(
    mean: jnp.ndarray, covariance: jnp.ndarray, dimension: int
) -> jnp.ndarray:
  """Ravel a Gaussian's mean and covariance matrix to d(1 + d) vector."""
  raveled_cov = jnp.reshape(covariance, (dimension * dimension,))
  return jnp.concatenate((mean, raveled_cov))