ott.core.bar_problems.GWBarycenterProblem

class ott.core.bar_problems.GWBarycenterProblem(y=None, b=None, weights=None, costs=None, y_fused=None, fused_penalty=1.0, gw_loss='sqeucl', scale_cost=1.0, **kwargs)

(Fused) Gromov-Wasserstein barycenter problem [Peyré et al., 2016, Titouan et al., 2019].
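For reference, the unfused problem of Peyré et al., 2016 can be written as below. The notation is ours, not OTT's: λ_s are the per-measure weights (weights), b_s the point weights of measure s, a the barycenter point weights, L the chosen gw_loss, and each coupling T_s is assumed to have shape [n_s, N], with rows indexing the points of measure s and columns the N barycenter points. In the fused case, schematically, the inner objective also gains fused_penalty times a linear term built from y_fused, and the barycenter features are optimized alongside C.

```latex
\min_{C \in \mathbb{R}^{N \times N}} \; \sum_{s=1}^{S} \lambda_s
  \min_{T_s \in U(b_s, a)} \sum_{i,k=1}^{n_s} \sum_{j,l=1}^{N}
  L\big((C_s)_{ik}, C_{jl}\big)\,(T_s)_{ij}\,(T_s)_{kl},
\qquad
U(b_s, a) = \{\, T \in \mathbb{R}_{+}^{n_s \times N} : T\mathbf{1}_N = b_s,\; T^{\top}\mathbf{1}_{n_s} = a \,\}
```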

Parameters
  • y (Optional[ndarray]) – Array of shape [num_total_points, ndim] merging the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().

  • b (Optional[ndarray]) – Array of shape [num_total_points,] containing the weights of all points in the measures that define the barycenter problem. As with y, a segmented array of weights of shape [num_measures, max_measure_size] can be passed. If y is already segmented, this array must be passed.

  • weights (Optional[ndarray]) – Array of shape [num_measures,] containing the weight of each measure in the barycenter problem.

  • costs (Optional[ndarray]) – Alternative to y: an array of shape [num_measures, max_measure_size, max_measure_size] that defines padded cost matrices for each measure, used in the quadratic term. Only one of y and costs can be specified.

  • y_fused (Optional[ndarray]) – Array of shape [num_total_points, ndim_fused] containing the features of the points of all measures, used to define the linear term in the fused case. As with y, it can be passed as a pre-segmented array of shape [num_measures, max_measure_size, ndim_fused].

  • gw_loss (Literal['sqeucl', 'kl']) – Gromov-Wasserstein loss: squared Euclidean ('sqeucl') or Kullback-Leibler ('kl').

  • fused_penalty (float) – Multiplier of the linear term. Only used when y_fused is not None.

  • scale_cost (Union[int, float, Literal['mean', 'max_cost']]) – Scaling of cost matrices passed to geometries.

  • kwargs (Any) – Keyword arguments for BarycenterProblem.
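To make the flat versus segmented layouts of y and b concrete, here is a minimal NumPy sketch that splits flat arrays into padded [num_measures, max_measure_size, ...] arrays. The helper segment_flat and its sizes argument are hypothetical illustrations, not OTT's segment_point_cloud(); padding with zero coordinates and zero weight is an assumption of this sketch.

```python
import numpy as np


def segment_flat(y, b, sizes):
    """Split flat y/b into padded [num_measures, max_measure_size, ...] arrays.

    Hypothetical helper for illustration only (not OTT's segment_point_cloud()).
    Padding slots get zero coordinates and zero weight (an assumption).
    """
    y = np.asarray(y, dtype=float)
    b = np.asarray(b, dtype=float)
    if sum(sizes) != y.shape[0] or b.shape[0] != y.shape[0]:
        raise ValueError("sizes must sum to num_total_points")
    num_measures, max_size = len(sizes), max(sizes)
    y_seg = np.zeros((num_measures, max_size, y.shape[1]))
    b_seg = np.zeros((num_measures, max_size))
    start = 0
    for m, n in enumerate(sizes):
        # Copy the n points of measure m; remaining slots stay as padding.
        y_seg[m, :n] = y[start:start + n]
        b_seg[m, :n] = b[start:start + n]
        start += n
    return y_seg, b_seg


# Two measures in 2-D with 3 and 2 points (num_total_points = 5).
y = np.arange(10.0).reshape(5, 2)
b = np.array([0.2, 0.3, 0.5, 0.5, 0.5])
y_seg, b_seg = segment_flat(y, b, sizes=[3, 2])
print(y_seg.shape, b_seg.shape)  # (2, 3, 2) (2, 3)
```

Zero-weight padding lets every measure share the same max_measure_size without padded points contributing mass.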

Methods

update_barycenter(transports, a)

Update the barycenter cost matrix.
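For the 'sqeucl' loss, Peyré et al., 2016 give a closed-form cost update, C = (Σ_s λ_s T_sᵀ C_s T_s) / (a aᵀ), with the division taken elementwise. The NumPy sketch below implements that formula; the function name and the orientation of the transports (shape [n_s, N]) are our assumptions and not necessarily how update_barycenter is implemented. The 'kl' loss has a different closed form that is not shown here.

```python
import numpy as np


def gw_barycenter_update_sqeucl(transports, costs, lambdas, a):
    """Closed-form barycenter cost update for the squared-Euclidean GW loss.

    Computes C = sum_s lambda_s * T_s^T C_s T_s / outer(a, a) (Peyre et al., 2016).
    Assumed shapes (hypothetical, not necessarily OTT's):
      transports[s]: [n_s, N] coupling between measure s and the barycenter
      costs[s]:      [n_s, n_s] cost matrix of measure s
      lambdas:       [num_measures,] measure weights summing to 1
      a:             [N,] strictly positive barycenter point weights
    """
    a = np.asarray(a, dtype=float)
    acc = np.zeros((a.shape[0], a.shape[0]))
    for lam, T, C_s in zip(lambdas, transports, costs):
        acc += lam * (T.T @ C_s @ T)
    # Elementwise division by a_i * a_j.
    return acc / np.outer(a, a)


# Sanity check: one measure coupled to itself by diag(a) gives back its cost.
a = np.array([0.25, 0.25, 0.5])
C = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
C_new = gw_barycenter_update_sqeucl([np.diag(a)], [C], [1.0], a)
print(np.allclose(C_new, C))  # True
```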

update_features(transports, a)

Update the barycenter features in the fused case [Titouan et al., 2019].
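With a squared-Euclidean feature cost, the fused feature update of Titouan et al., 2019 is a barycentric projection, X = diag(1/a) Σ_s λ_s T_sᵀ Y_s. The sketch below assumes that cost and transports of shape [n_s, N]; the function name is hypothetical and this is not presented as OTT's update_features.

```python
import numpy as np


def fused_feature_update(transports, features, lambdas, a):
    """Barycentric feature update X = diag(1/a) * sum_s lambda_s T_s^T Y_s.

    Hypothetical helper assuming a squared-Euclidean feature cost.
    Assumed shapes: transports[s] is [n_s, N], features[s] is
    [n_s, ndim_fused], a is [N,] with strictly positive entries.
    """
    a = np.asarray(a, dtype=float)
    acc = sum(lam * (T.T @ Y) for lam, T, Y in zip(lambdas, transports, features))
    # Divide each barycenter row j by its weight a_j.
    return acc / a[:, None]


# Two measures, each matched point-for-point to a 2-point barycenter.
a = np.array([0.5, 0.5])
T = np.diag(a)
Y1 = np.array([[0.0, 0.0], [2.0, 2.0]])
Y2 = np.array([[2.0, 0.0], [4.0, 2.0]])
X = fused_feature_update([T, T], [Y1, Y2], [0.5, 0.5], a)
# X equals the per-point average (Y1 + Y2) / 2.
```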

Attributes

flattened_b

Array of shape [num_measures * (N_1 + N_2 + ...),].

flattened_y

Array of shape [num_measures * (N_1 + N_2 + ...), ndim].

gw_loss

Gromov-Wasserstein loss.

is_fused

Whether the problem is fused.

max_measure_size

Maximum number of points across all measures.

ndim

Number of dimensions of y.

ndim_fused

Number of dimensions of the fused term.

num_measures

Number of measures.

segmented_y_b

Tuple of arrays containing the segmented measures and weights.

segmented_y_fused

Feature array of shape [num_measures, max_measure_size, ndim_fused] used in the fused case.

weights

Barycenter weights of shape [num_measures,] that sum to 1.
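Several attributes follow directly from the segmented arrays. The sketch below derives num_measures, max_measure_size and ndim from the shape, flattens segmented arrays with a reshape, and normalizes weights. Both the reshape (which gives num_measures * max_measure_size rows for segmented input) and the uniform default when weights is None are assumptions about the behavior, and normalized_weights is a hypothetical helper.

```python
import numpy as np


def normalized_weights(weights, num_measures):
    """Uniform weights if none are given, else rescale to sum to 1 (assumed behavior)."""
    if weights is None:
        return np.full(num_measures, 1.0 / num_measures)
    weights = np.asarray(weights, dtype=float)
    if weights.shape != (num_measures,):
        raise ValueError("expected one weight per measure")
    return weights / weights.sum()


# Hypothetical segmented inputs: 2 measures, max_measure_size 3, ndim 2.
y_seg = np.zeros((2, 3, 2))
b_seg = np.array([[0.2, 0.3, 0.5], [0.5, 0.5, 0.0]])

num_measures, max_measure_size, ndim = y_seg.shape
flat_y = y_seg.reshape(-1, ndim)  # [num_measures * max_measure_size, ndim]
flat_b = b_seg.reshape(-1)        # [num_measures * max_measure_size,]

print(flat_y.shape, flat_b.shape)         # (6, 2) (6,)
print(normalized_weights([1.0, 3.0], 2))  # [0.25 0.75]
```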