# Source code for ott.problems.linear.barycenter_problem

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import costs, geometry, segment

# Public API of this module: the free- and fixed-support barycenter problems.
__all__ = ["FreeBarycenterProblem", "FixedBarycenterProblem"]

@jax.tree_util.register_pytree_node_class
class FreeBarycenterProblem:
  """Free Wasserstein barycenter problem :cite:`cuturi:14`.

  Args:
    y: Array of shape ``[num_total_points, ndim]`` merging the points of all
      measures. Alternatively, an already segmented array of shape
      ``[num_measures, max_measure_size, ndim]`` can be passed.
      See also :func:`~ott.geometry.segment.segment_point_cloud`.
    b: Array of shape ``[num_total_points,]`` containing the weights of all
      the points within the measures that define the barycenter problem.
      As with ``y``, a pre-segmented array of weights of shape
      ``[num_measures, max_measure_size]`` can be passed.
      If ``y`` is already pre-segmented, this array must always be specified.
    weights: Array of shape ``[num_measures,]`` containing the weights of the
      measures.
    cost_fn: Cost function used. If `None`,
      use the :class:`~ott.geometry.costs.SqEuclidean` cost.
    epsilon: Epsilon regularization used to solve reg-OT problems.
    kwargs: Keyword arguments for
      :func:`~ott.geometry.segment.segment_point_cloud`. Only used when ``y``
      is not already segmented. When passing ``segment_ids``, 2 arguments
      must be specified for jitting to work:

        - ``num_segments`` - the total number of measures.
        - ``max_measure_size`` - maximum of support sizes of these measures.

  Raises:
    ValueError: If ``y`` is pre-segmented (3-D) but ``b`` is not given.
  """

  def __init__(
      self,
      y: jnp.ndarray,
      b: Optional[jnp.ndarray] = None,
      weights: Optional[jnp.ndarray] = None,
      cost_fn: Optional[costs.CostFn] = None,
      epsilon: Optional[float] = None,
      **kwargs: Any,
  ):
    self._y = y
    if y.ndim == 3 and b is None:
      raise ValueError("Specify weights if `y` is already segmented.")
    self._b = b
    self._weights = weights
    self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
    self.epsilon = epsilon
    self._kwargs = kwargs

    if self._is_segmented:
      # y: (num_measures, max_measure_size, ndim)
      # b: (num_measures, max_measure_size)
      assert self._y.shape[:2] == self._b.shape, (
          f"Shape of `b` {self._b.shape} does not match the leading "
          f"dimensions of the segmented `y` {self._y.shape[:2]}."
      )
    else:
      # y: (num_total_points, ndim)
      # b: (num_total_points,)
      assert self._b is None or self._y.shape[0] == self._b.shape[0], (
          f"`b` has {self._b.shape[0]} entries but `y` has "
          f"{self._y.shape[0]} points."
      )

  @property
  def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Tuple of arrays containing the segmented measures and weights.

    - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``.
    - Segmented weights of shape ``[num_measures, max_measure_size]``.
    """
    if self._is_segmented:
      y, b = self._y, self._b
    else:
      y, b = segment.segment_point_cloud(
          x=self._y,
          a=self._b,
          padding_vector=self.cost_fn._padder(self.ndim),
          **self._kwargs
      )
    return y, b

  @property
  def flattened_y(self) -> jnp.ndarray:
    """Points as a 2-D array of shape ``[num_points, ndim]``.

    If ``y`` was passed pre-segmented, ``num_points`` equals
    ``num_measures * max_measure_size`` (padding rows included); otherwise
    ``y`` is returned as passed, i.e. ``num_points = num_total_points``.
    """
    if self._is_segmented:
      return self._y.reshape((-1, self._y.shape[-1]))
    return self._y

  @property
  def flattened_b(self) -> Optional[jnp.ndarray]:
    """Point weights raveled to 1-D, or ``None`` if ``b`` was not passed.

    The shape is ``[num_measures * max_measure_size,]`` if ``b`` was passed
    pre-segmented, and ``[num_total_points,]`` otherwise.
    """
    return None if self._b is None else self._b.ravel()

  @property
  def num_measures(self) -> int:
    """Number of measures."""
    return self.segmented_y_b[0].shape[0]

  @property
  def max_measure_size(self) -> int:
    """Maximum number of points across all measures."""
    return self.segmented_y_b[0].shape[1]

  @property
  def ndim(self) -> int:
    """Number of dimensions of ``y``."""
    return self._y.shape[-1]

  @property
  def weights(self) -> jnp.ndarray:
    """Barycenter weights of shape ``[num_measures,]`` that sum to 1."""
    # Computed once: for non-segmented ``y`` this triggers a segmentation.
    num_measures = self.num_measures
    if self._weights is None:
      return jnp.ones((num_measures,)) / num_measures
    # Check that the number of measures coincides with the weights' size.
    assert self._weights.shape[0] == num_measures, (
        f"Expected {num_measures} measure weights, "
        f"got {self._weights.shape[0]}."
    )
    # By default, we assume that weights sum to 1, and enforce this if needed.
    return self._weights / jnp.sum(self._weights)

  @property
  def _is_segmented(self) -> bool:
    return self._y.ndim == 3

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return ([self._y, self._b, self._weights], {
        "cost_fn": self.cost_fn,
        "epsilon": self.epsilon,
        **self._kwargs,
    })

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "FreeBarycenterProblem":
    # Bypass ``__init__``: JAX may unflatten with placeholder leaves
    # (e.g. ``object()``, ``None`` or ``in_axes`` ints), on which the
    # ``ndim``/shape validation done in ``__init__`` would fail.
    y, b, weights = children
    # Copy before popping so the dict stored in the treedef is not mutated.
    kwargs = dict(aux_data)
    obj = object.__new__(cls)
    obj._y = y
    obj._b = b
    obj._weights = weights
    obj.cost_fn = kwargs.pop("cost_fn")
    obj.epsilon = kwargs.pop("epsilon")
    # Whatever remains are the ``segment_point_cloud`` keyword arguments.
    obj._kwargs = kwargs
    return obj


@jax.tree_util.register_pytree_node_class
class FixedBarycenterProblem:
  """Fixed-support Wasserstein barycenter problem.

  Args:
    geom: Geometry object.
    a: batch of histograms of shape ``[batch, num_a]`` where ``num_a`` matches
      the first value of the :attr:`~ott.geometry.Geometry.shape` attribute of
      ``geom``.
    weights: ``[batch,]`` positive weights summing to :math:`1`. Uniform by
      default.
  """

  def __init__(
      self,
      geom: geometry.Geometry,
      a: jnp.ndarray,
      weights: Optional[jnp.ndarray] = None,
  ):
    self.geom = geom
    self.a = a
    self._weights = weights

  @property
  def num_measures(self) -> int:
    """Number of measures, i.e. the leading dimension of ``a``."""
    return self.a.shape[0]

  @property
  def weights(self) -> jnp.ndarray:
    """Barycenter weights of shape ``[num_measures,]`` that sum to :math:`1`."""
    num_measures = self.num_measures
    if self._weights is None:
      return jnp.ones((num_measures,)) / num_measures
    # Check that the number of measures coincides with the weights' size.
    assert self._weights.shape[0] == num_measures, (
        f"Expected {num_measures} measure weights, "
        f"got {self._weights.shape[0]}."
    )
    # By default, we assume that weights sum to 1, and enforce this if needed.
    return self._weights / jnp.sum(self._weights)

  def tree_flatten(self) -> Tuple[Sequence[Any], None]:  # noqa: D102
    return [self.geom, self.a, self._weights], None

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "FixedBarycenterProblem":
    # ``tree_flatten`` stores no auxiliary data.
    del aux_data
    geom, a, weights = children
    return cls(geom=geom, a=a, weights=weights)