class ott.core.bar_problems.BarycenterProblem(y, b=None, weights=None, cost_fn=None, epsilon=None, debiased=False, **kwargs)

Wasserstein barycenter problem [Cuturi and Doucet, 2014].
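For orientation, the entropy-regularized barycenter objective can be sketched as below. The notation is ours and is only mapped informally onto the constructor arguments: K measures ν_i (K = num_measures) with points y_ij (rows of y) and masses b_ij (entries of b), barycenter weights λ_i (weights), ground cost c (cost_fn), regularization ε (epsilon), and a barycenter μ supported on M points x_k. The exact objective and normalization used by the library may differ in details.

```latex
\begin{gathered}
\nu_i = \sum_{j=1}^{N_i} b_{ij}\,\delta_{y_{ij}}, \qquad
\mu = \sum_{k=1}^{M} a_k\,\delta_{x_k}, \\
\min_{\mu}\; \sum_{i=1}^{K} \lambda_i\,\mathrm{OT}_{\varepsilon}(\mu, \nu_i),
\qquad \lambda_i \ge 0, \quad \sum_{i=1}^{K} \lambda_i = 1, \\
\mathrm{OT}_{\varepsilon}(\mu, \nu_i) = \min_{P \in U(a,\, b_i)}
\sum_{k=1}^{M} \sum_{j=1}^{N_i} P_{kj}\, c(x_k, y_{ij}) - \varepsilon H(P),
\qquad H(P) = -\sum_{k,j} P_{kj}\left(\log P_{kj} - 1\right).
\end{gathered}
```

Here U(a, b_i) is the set of couplings with marginals a and b_i, and max_measure_size corresponds to max_i N_i. The debiased option described below would additionally take the self-transport term OT_ε(μ, μ) into account.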

  • y (ndarray) – Array of shape [num_total_points, ndim] merging the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().

  • b (Optional[ndarray]) – Array of shape [num_total_points,] containing the weights of all points within the measures that define the barycenter problem. As with y, a segmented array of weights of shape [num_measures, max_measure_size] can be passed instead. If y is already segmented, this array must always be specified.

  • weights (Optional[ndarray]) – Array of shape [num_measures,] containing the weights of the measures.

  • cost_fn (Optional[CostFn]) – Cost function used. If None, use Euclidean cost.

  • epsilon (Optional[float]) – Entropic regularization strength used to solve the regularized OT problems.

  • debiased (bool) – Whether the problem is debiased, i.e., whether the regularized transport cost of the barycenter to itself is taken into account when computing the gradient. Currently not implemented. If the debiased option is used, the barycenter size passed to init_state() must be smaller than the maximum measure size for parallelization to operate efficiently.

  • kwargs (Any) –

    Keyword arguments passed to segment_point_cloud(). Only used when y is not already segmented. When passing segment_ids, two further arguments must be specified for jitting to work:

    • num_segments – the total number of measures.

    • max_measure_size – the maximum support size among these measures.
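The segmentation described by y, b and the segment_ids keyword can be illustrated with the sketch below. It uses NumPy and a hypothetical helper, segment_sketch, that is not OTT's segment_point_cloud(); it assumes padded slots receive zero coordinates and zero weight, which may differ from the library's own padding convention. It also shows why num_segments and max_measure_size must be given explicitly: they alone fix the output shapes, and under jax.jit array shapes must be known at trace time, so they cannot be read off traced segment_ids values. (The Python loop here is for readability and is not itself jittable.)

```python
import numpy as np


def segment_sketch(y, b, segment_ids, num_segments, max_measure_size):
    """Hypothetical helper: pad a flat point cloud into per-measure blocks.

    y: [num_total_points, ndim] points of all measures, concatenated.
    b: [num_total_points,] weight of each point.
    segment_ids: [num_total_points,] index of the measure each point belongs to.
    Returns arrays of shape [num_segments, max_measure_size, ndim] and
    [num_segments, max_measure_size]; padding slots hold zero coordinates
    and zero weight, so they carry no mass.
    """
    y = np.asarray(y)
    b = np.asarray(b)
    segment_ids = np.asarray(segment_ids)
    ndim = y.shape[1]
    # Output shapes depend only on num_segments and max_measure_size,
    # never on the values in segment_ids.
    seg_y = np.zeros((num_segments, max_measure_size, ndim), dtype=y.dtype)
    seg_b = np.zeros((num_segments, max_measure_size), dtype=b.dtype)
    for k in range(num_segments):
        idx = np.flatnonzero(segment_ids == k)  # points of measure k
        if idx.size > max_measure_size:
            raise ValueError(f"measure {k} has {idx.size} points > max_measure_size")
        seg_y[k, : idx.size] = y[idx]
        seg_b[k, : idx.size] = b[idx]
    return seg_y, seg_b


# Two measures in 2-D: 3 points in measure 0, 2 points in measure 1.
y = np.arange(10, dtype=float).reshape(5, 2)
b = np.array([1 / 3, 1 / 3, 1 / 3, 0.5, 0.5])
segment_ids = np.array([0, 0, 0, 1, 1])
seg_y, seg_b = segment_sketch(y, b, segment_ids, num_segments=2, max_measure_size=3)
print(seg_y.shape)  # (2, 3, 2)
```

Padding with zero weight keeps every measure's total mass unchanged, so the padded arrays describe the same measures while sharing one fixed shape that can be batched over.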

Array of shape [num_measures * (N_1 + N_2 + ...),].

Array of shape [num_measures * (N_1 + N_2 + ...), ndim].

Maximum number of points across all measures.

Number of dimensions of y.

Number of measures.

Tuple of arrays containing the segmented measures and weights.

Barycenter weights of shape [num_measures,] that sum to 1.
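The size-related properties above can all be read off a segmented point cloud, as the NumPy sketch below illustrates. The helper derived_properties_sketch is hypothetical and not part of OTT. Two behaviours are assumptions made for illustration only: uniform weights when none are given, and rescaling user-supplied weights so that they sum to 1.

```python
import numpy as np


def derived_properties_sketch(seg_y, weights=None):
    """Hypothetical: read size properties off a segmented point cloud of
    shape [num_measures, max_measure_size, ndim] and return barycenter
    weights of shape [num_measures,] that sum to 1."""
    num_measures, max_measure_size, ndim = np.shape(seg_y)
    if weights is None:
        # Assumption: no weights given -> every measure counts equally.
        weights = np.full(num_measures, 1.0 / num_measures)
    else:
        weights = np.asarray(weights, dtype=float)
        if weights.shape != (num_measures,):
            raise ValueError("weights must have shape [num_measures,]")
        # Assumption: rescale user-supplied weights so they sum to 1.
        weights = weights / weights.sum()
    return num_measures, max_measure_size, ndim, weights


seg_y = np.zeros((3, 4, 2))
print(derived_properties_sketch(seg_y)[:3])  # (3, 4, 2)
print(derived_properties_sketch(seg_y, weights=[1, 1, 2])[3])
```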