ott.problems.linear.barycenter_problem.FreeBarycenterProblem


class ott.problems.linear.barycenter_problem.FreeBarycenterProblem(y, b=None, weights=None, cost_fn=None, epsilon=None, **kwargs)

Free Wasserstein barycenter problem [Cuturi and Doucet, 2014].
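For orientation, the objective can be written roughly as follows. This is a sketch in our own notation, not a transcription of the library's formulation. We assume K measures μ_i, each with N_i points y_{ij} and point weights b_{ij} (from y and b), and measure weights λ_i (from weights). The barycenter has m free support points x_k carrying fixed weights a_k. The cost c plays the role of cost_fn (squared Euclidean by default), and ε plays the role of epsilon:

```latex
\mu_i = \sum_{j=1}^{N_i} b_{ij}\,\delta_{y_{ij}}, \qquad
\nu_X = \sum_{k=1}^{m} a_k\,\delta_{x_k}, \qquad
\lambda_i \ge 0, \quad \sum_{i=1}^{K} \lambda_i = 1,

\min_{X = (x_1, \dots, x_m)} \; \sum_{i=1}^{K} \lambda_i \,
  \mathrm{OT}_\varepsilon(\nu_X, \mu_i),

\mathrm{OT}_\varepsilon(\nu_X, \mu_i) =
  \min_{P \in U(a, b_i)} \sum_{k=1}^{m} \sum_{j=1}^{N_i} P_{kj}\, c(x_k, y_{ij})
  - \varepsilon H(P),
\qquad H(P) = -\sum_{k,j} P_{kj} \left( \log P_{kj} - 1 \right).
```

Here U(a, b_i) is the set of couplings with marginals a and b_i. "Free" refers to the support locations x_k being optimized rather than fixed in advance.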

Parameters:
  • y (Array) – Array of shape [num_total_points, ndim] concatenating the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().

  • b (Optional[Array]) – Array of shape [num_total_points,] containing the weights of all points within the measures that define the barycenter problem. As with y, a pre-segmented array of weights of shape [num_measures, max_measure_size] can be passed. If y is pre-segmented, this array must always be specified.

  • weights (Optional[Array]) – Array of shape [num_measures,] containing the weights of the measures.

  • cost_fn (Optional[CostFn]) – Cost function. If None, the SqEuclidean (squared Euclidean) cost is used.

  • epsilon (Optional[float]) – Entropic regularization strength used to solve the regularized OT problems.

  • kwargs (Any) –

    Keyword arguments for segment_point_cloud(). Only used when y is not already segmented. When passing segment_ids, two additional arguments must be specified for jitting to work:

    • num_segments - the total number of measures.

    • max_measure_size - the maximum support size across these measures.
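The shape rules above can be illustrated with a small NumPy sketch. The helpers segment_points and check_inputs are hypothetical and are not part of OTT. They only mimic the documented layout: a flat point cloud plus segment ids becomes a padded [num_measures, max_measure_size, ndim] array. Padding with zero points and zero weights is an assumption of this sketch. It also shows why num_segments and max_measure_size must be supplied up front: they fix the output shape before any data is inspected, which is what a jitted function requires.

```python
import numpy as np


# Hypothetical helper (not the OTT API): pads each measure to max_measure_size.
def segment_points(y, b, segment_ids, num_segments, max_measure_size):
    ndim = y.shape[-1]
    # Output shapes depend only on num_segments and max_measure_size,
    # never on the values in segment_ids (the static-shape requirement).
    y_seg = np.zeros((num_segments, max_measure_size, ndim), dtype=y.dtype)
    b_seg = np.zeros((num_segments, max_measure_size), dtype=b.dtype)
    for k in range(num_segments):
        idx = np.flatnonzero(segment_ids == k)
        # Assumes measure k has at most max_measure_size points; padded
        # slots keep zero coordinates and zero weight (no mass).
        y_seg[k, : len(idx)] = y[idx]
        b_seg[k, : len(idx)] = b[idx]
    return y_seg, b_seg


# Hypothetical check of the parameter rules listed above.
def check_inputs(y, b):
    if y.ndim == 3:  # pre-segmented: [num_measures, max_measure_size, ndim]
        if b is None:
            raise ValueError("b must be specified when y is pre-segmented.")
        assert y.shape[:2] == b.shape
    elif b is not None:  # flat: [num_total_points, ndim]
        assert y.shape[0] == b.shape[0]


# Two measures in 2-D with 3 and 2 points (num_total_points = 5).
y = np.arange(10, dtype=float).reshape(5, 2)
b = np.array([1 / 3, 1 / 3, 1 / 3, 0.5, 0.5])
segment_ids = np.array([0, 0, 0, 1, 1])

check_inputs(y, b)
y_seg, b_seg = segment_points(y, b, segment_ids, num_segments=2, max_measure_size=3)
check_inputs(y_seg, b_seg)
print(y_seg.shape, b_seg.shape)  # (2, 3, 2) (2, 3)
```

The second measure has only two points, so its third slot is padding with weight 0, and each row of b_seg still sums to 1.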


Attributes

flattened_b

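Weights b flattened to one dimension, where N_i is the number of points in measure i: shape [N_1 + N_2 + ...,] (i.e. [num_total_points,]) for flat input, or [num_measures * max_measure_size,] for pre-segmented input.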

flattened_y

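Points y flattened to two dimensions, where N_i is the number of points in measure i: shape [N_1 + N_2 + ..., ndim] (i.e. [num_total_points, ndim]) for flat input, or [num_measures * max_measure_size, ndim] for pre-segmented input.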

max_measure_size

Maximum number of points across all measures.

ndim

Dimension of the points in y, i.e. the size of its last axis.

num_measures

Number of measures.

segmented_y_b

Tuple of arrays containing the segmented measures, of shape [num_measures, max_measure_size, ndim], and their weights, of shape [num_measures, max_measure_size].

weights

Barycenter weights of shape [num_measures,] that sum to 1.
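The attributes above follow directly from the array shapes. The sketch below uses a hypothetical describe function (not the OTT implementation) to derive them from a pre-segmented pair. Two choices are assumptions of this sketch: falling back to uniform weights when none are given, and rescaling user-supplied weights so that they sum to 1.

```python
import numpy as np


# Hypothetical helper (not the OTT implementation): derives the documented
# attributes from a pre-segmented pair (y_seg, b_seg).
def describe(y_seg, b_seg, weights=None):
    num_measures, max_measure_size, ndim = y_seg.shape
    # Flattening keeps padded slots: num_measures * max_measure_size rows.
    flattened_y = y_seg.reshape(-1, ndim)
    flattened_b = b_seg.ravel()
    if weights is None:
        # Assumption: uniform barycenter weights when none are given.
        w = np.full(num_measures, 1.0 / num_measures)
    else:
        w = np.asarray(weights, dtype=float)
        assert w.shape == (num_measures,)
        w = w / w.sum()  # rescale so the barycenter weights sum to 1
    return {
        "num_measures": num_measures,
        "max_measure_size": max_measure_size,
        "ndim": ndim,
        "flattened_y": flattened_y,
        "flattened_b": flattened_b,
        "weights": w,
    }


y_seg = np.zeros((2, 3, 2))
b_seg = np.array([[1 / 3, 1 / 3, 1 / 3], [0.5, 0.5, 0.0]])
info = describe(y_seg, b_seg, weights=[2.0, 6.0])
print(info["flattened_y"].shape)  # (6, 2)
print(info["weights"])  # [0.25 0.75]
```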