ott.core.bar_problems.BarycenterProblem

class ott.core.bar_problems.BarycenterProblem(y, b=None, weights=None, cost_fn=None, epsilon=None, debiased=False, **kwargs)

Wasserstein barycenter problem [Cuturi and Doucet, 2014].
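For orientation, the problem can be written as below. This is a sketch of the usual entropic formulation in our own notation, not copied from this page: ν_k = Σ_j b_j δ_{y_j} are the input measures (from y and b), λ_k are the entries of weights, c is cost_fn, ε is epsilon, and the barycenter μ = Σ_i a_i δ_{x_i} is the optimization variable. The last line is the standard Sinkhorn-divergence form of a debiased objective, given only as an assumption about what the (not yet implemented) debiased option refers to.

```latex
\min_{\mu} \; \sum_{k=1}^{K} \lambda_k \, \mathrm{OT}_\varepsilon(\mu, \nu_k),
\qquad \lambda_k \ge 0, \quad \sum_{k=1}^{K} \lambda_k = 1,

\mathrm{OT}_\varepsilon(\mu, \nu) = \min_{P \in U(a, b)} \sum_{i,j} P_{ij} \, c(x_i, y_j)
  + \varepsilon \sum_{i,j} P_{ij} \left( \log P_{ij} - 1 \right),

U(a, b) = \{ P \in \mathbb{R}_{+}^{n \times m} : P \mathbf{1}_m = a, \; P^\top \mathbf{1}_n = b \},

\text{debiased:} \quad \min_{\mu} \; \sum_{k=1}^{K} \lambda_k \, \mathrm{OT}_\varepsilon(\mu, \nu_k)
  - \tfrac{1}{2} \, \mathrm{OT}_\varepsilon(\mu, \mu).
```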

Parameters
  • y (ndarray) – Array of shape [num_total_points, ndim] concatenating the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().

  • b (Optional[ndarray]) – Array of shape [num_total_points,] containing the weights of all points within the measures that define the barycenter problem. As with y, a segmented array of weights of shape [num_measures, max_measure_size] can be passed instead. If y is already segmented, this array must always be specified.

  • weights (Optional[ndarray]) – Array of shape [num_measures,] containing the weights of the measures.

  • cost_fn (Optional[CostFn]) – Cost function used. If None, use Euclidean cost.

  • epsilon (Optional[float]) – Entropic regularization strength used to solve the regularized OT problems.

  • debiased (bool) – Currently not implemented. Whether the problem is debiased, in the sense that the regularized transport cost of the barycenter to itself is taken into account when computing the gradient. Note that if the debiased option is used, the barycenter size in init_state() needs to be smaller than the maximum measure size for parallelization to operate efficiently.

  • kwargs (Any) –

    Keyword arguments passed to segment_point_cloud(). Only used when y is not already segmented. When passing segment_ids, two additional arguments must be specified for jitting to work:

    • num_segments - the total number of measures.

    • max_measure_size - the maximum support size among these measures.
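To make the two input layouts concrete, the sketch below converts a merged point cloud plus per-point measure indices into the segmented layout. segment_sketch is a hypothetical NumPy helper written for illustration; it is not segment_point_cloud(), and it assumes padded slots are filled with zero points and zero weights (the library's actual padding may differ). It also shows why num_segments and max_measure_size matter: they fix the output shapes, which must be static under jax.jit.

```python
import numpy as np


def segment_sketch(y, b, segment_ids, num_segments, max_measure_size):
    """Hypothetical helper (not OTT's segment_point_cloud).

    y:           [num_total_points, ndim] points of all measures, concatenated.
    b:           [num_total_points,] weight of each point.
    segment_ids: [num_total_points,] index of the measure each point belongs to.
    Returns arrays of shape [num_segments, max_measure_size, ndim] and
    [num_segments, max_measure_size]; padded slots get zero weight.
    """
    ndim = y.shape[1]
    # Output shapes depend only on num_segments and max_measure_size,
    # never on the data: this is what makes them usable under jit.
    seg_y = np.zeros((num_segments, max_measure_size, ndim), dtype=y.dtype)
    seg_b = np.zeros((num_segments, max_measure_size), dtype=b.dtype)
    for k in range(num_segments):
        mask = segment_ids == k
        n_k = int(mask.sum())
        if n_k > max_measure_size:
            raise ValueError(f"measure {k} has {n_k} points > max_measure_size")
        # Points keep their original order within each measure.
        seg_y[k, :n_k] = y[mask]
        seg_b[k, :n_k] = b[mask]
    return seg_y, seg_b


# Three measures with 2, 1 and 2 points in 2-D.
y = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
b = np.array([0.5, 0.5, 1.0, 0.25, 0.75])
segment_ids = np.array([0, 0, 1, 2, 2])
seg_y, seg_b = segment_sketch(y, b, segment_ids, num_segments=3, max_measure_size=2)
print(seg_y.shape)  # (3, 2, 2)
```

A jit-compatible version would replace the Python loop with scatter operations; the loop is kept here only for readability.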

Methods

Attributes

flattened_b

Array of shape [num_measures * (N_1 + N_2 + ...),].

flattened_y

Array of shape [num_measures * (N_1 + N_2 + ...), ndim].

max_measure_size

Maximum number of points across all measures.

ndim

Number of dimensions of y.

num_measures

Number of measures.

segmented_y_b

Tuple of arrays containing the segmented measures and weights.

weights

Barycenter weights of shape [num_measures,] that sum to 1.
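The attributes above can all be read off the segmented arrays. The NumPy sketch below derives them for a toy segmented input with made-up values. Two details are assumptions of the sketch, not statements about OTT: flattening is a plain reshape of the padded arrays (so padded slots are kept), and missing weights default to uniform ones while user-provided weights are rescaled to sum to 1. The helper name barycenter_weights is hypothetical.

```python
import numpy as np


def barycenter_weights(raw_weights, num_measures):
    """Hypothetical helper: uniform weights if None, else rescale to sum to 1."""
    if raw_weights is None:
        return np.full(num_measures, 1.0 / num_measures)
    w = np.asarray(raw_weights, dtype=float)
    if w.shape != (num_measures,):
        raise ValueError("expected one weight per measure")
    return w / w.sum()


# Toy segmented input: 3 measures, at most 2 points each, in 2-D.
seg_y = np.arange(12, dtype=float).reshape(3, 2, 2)
# Measure 1 has a single real point; its second slot is padding (weight 0).
seg_b = np.array([[0.5, 0.5], [1.0, 0.0], [0.25, 0.75]])

# Shape-derived attributes: num_measures, max_measure_size, ndim.
num_measures, max_measure_size, ndim = seg_y.shape

# In this sketch, flattening keeps padded slots, giving
# num_measures * max_measure_size rows.
flattened_y = seg_y.reshape(-1, ndim)
flattened_b = seg_b.ravel()

uniform = barycenter_weights(None, num_measures)
custom = barycenter_weights([1.0, 2.0, 1.0], num_measures)
print(flattened_y.shape)  # (6, 2)
```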