ott.problems.quadratic.gw_barycenter.GWBarycenterProblem
- class ott.problems.quadratic.gw_barycenter.GWBarycenterProblem(y=None, b=None, weights=None, costs=None, y_fused=None, fused_penalty=1.0, gw_loss='sqeucl', scale_cost=1.0, **kwargs)
(Fused) Gromov-Wasserstein barycenter problem [Peyré et al., 2016, Titouan et al., 2019].
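For orientation, the objective can be sketched as follows, based on the cited papers rather than on the library's code. There are S measures with cost matrices C_s and marginals b_s (segmented from y or costs, and from b), measure weights λ_s (weights), a fixed barycenter marginal a, and a loss L selected by gw_loss. In the fused case, the features Y_s come from y_fused and the barycenter features X are compared with them under a squared Euclidean cost, scaled by ρ = fused_penalty; this cost choice and the placement of ρ are assumptions. Any entropic regularization used by the solver is omitted.

```latex
\min_{C,\;X} \; \sum_{s=1}^{S} \lambda_s \min_{T_s \in U(a,\,b_s)}
\Big[ \sum_{i,j,k,l} L\big(C_{ij}, (C_s)_{kl}\big)\,(T_s)_{ik}\,(T_s)_{jl}
\;+\; \rho \sum_{i,k} \lVert X_i - (Y_s)_k \rVert_2^2 \,(T_s)_{ik} \Big],
\qquad
U(a, b_s) = \{\, T \ge 0 : T\mathbf{1} = a,\; T^\top \mathbf{1} = b_s \,\}
```

Here L(x, y) = (x - y)^2 for 'sqeucl' and L(x, y) = x log(x / y) - x + y for 'kl'. Without y_fused, the second term and X drop out, and the problem reduces to the plain Gromov-Wasserstein barycenter of Peyré et al. (2016).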
- Parameters:
  - y (Optional[Array]) – Array of shape [num_total_points, ndim] merging the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().
  - b (Optional[Array]) – Array of shape [num_total_points,] containing the weights of all the points within the measures that define the barycenter problem. As with y, a pre-segmented array of weights of shape [num_measures, max_measure_size] can be passed. If y is already pre-segmented, this array must be passed.
  - weights (Optional[Array]) – Array of shape [num_measures,] containing the weight of each measure in the barycenter problem.
  - costs (Optional[Array]) – Alternative to y: an array of shape [num_measures, max_measure_size, max_measure_size] that defines padded cost matrices for each measure, used in the quadratic term. Only one of y and costs can be specified.
  - y_fused (Optional[Array]) – Array of shape [num_total_points, ndim_fused] containing the data of the points of all measures used to define the linear term in the fused case. As with y, it can be specified as a pre-segmented array of shape [num_measures, max_measure_size, ndim_fused].
  - gw_loss (Literal['sqeucl', 'kl']) – Gromov-Wasserstein loss used to compare cost-matrix entries: squared Euclidean ('sqeucl') or Kullback-Leibler ('kl').
  - fused_penalty (float) – Multiplier of the linear term. Only used when y_fused is not None.
  - scale_cost (Union[float, Literal['mean', 'max_cost']]) – Scaling of the cost matrices passed to geometries.
  - kwargs (Any) – Keyword arguments for BarycenterProblem.
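To make the shape conventions for y, b and costs concrete, here is a small NumPy sketch of what segmentation produces: a flattened point cloud is split by measure and zero-padded to max_measure_size, with padded points receiving zero weight. The helpers segment_padded and padded_sq_euclidean_costs are hypothetical and are not OTT's segment_point_cloud(). Deriving per-measure cost matrices from y with a squared Euclidean cost is likewise an assumption made for illustration.

```python
import numpy as np


def segment_padded(y, b, segment_ids, num_measures, max_measure_size):
    """Hypothetical helper (not OTT's segment_point_cloud).

    Pads a flattened point cloud [num_total_points, ndim] into
    [num_measures, max_measure_size, ndim] and its weights into
    [num_measures, max_measure_size]; padded slots get weight 0.
    """
    ndim = y.shape[1]
    y_seg = np.zeros((num_measures, max_measure_size, ndim), dtype=y.dtype)
    b_seg = np.zeros((num_measures, max_measure_size), dtype=b.dtype)
    for s in range(num_measures):
        mask = segment_ids == s
        n_s = int(mask.sum())
        if n_s > max_measure_size:
            raise ValueError(f"measure {s} has {n_s} points > max_measure_size")
        y_seg[s, :n_s] = y[mask]
        b_seg[s, :n_s] = b[mask]
    return y_seg, b_seg


def padded_sq_euclidean_costs(y_seg):
    """Per-measure squared Euclidean costs, [num_measures, M, M].

    Entries involving padded points are meaningless, but they carry zero
    weight, so a transport plan never moves mass through them.
    """
    diff = y_seg[:, :, None, :] - y_seg[:, None, :, :]
    return np.sum(diff ** 2, axis=-1)


# Two measures with 3 and 2 points in 2D, merged into one flattened array.
y = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 2.0]])
b = np.array([0.2, 0.3, 0.5, 0.4, 0.6])
segment_ids = np.array([0, 0, 0, 1, 1])

y_seg, b_seg = segment_padded(y, b, segment_ids, num_measures=2, max_measure_size=3)
costs = padded_sq_euclidean_costs(y_seg)
print(y_seg.shape, b_seg.shape, costs.shape)  # (2, 3, 2) (2, 3) (2, 3, 3)
```

The resulting costs array has exactly the padded layout described for the costs parameter, which is why y and costs are mutually exclusive: both describe the same per-measure geometry.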
Methods
- update_barycenter(transports, a): Update the barycenter cost matrix.
- update_features(transports, a): Update the barycenter features in the fused case [Titouan et al., 2019].
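For gw_loss='sqeucl', the cost-matrix update has a closed form in Peyré et al. (2016). Given couplings T_s whose rows sum to a, the update is C = (sum_s lambda_s T_s C_s T_s^T) / (a a^T), divided elementwise. For 'kl', the same average is taken over log C_s and then exponentiated. In the fused case, the features follow the barycentric projection X_i = (1 / a_i) sum_s lambda_s sum_k (T_s)_ik (Y_s)_k [Titouan et al., 2019]. The NumPy sketch below covers only the squared Euclidean case. The function names and the [num_measures, N, max_measure_size] layout of transports (barycenter side first) are assumptions, not the signatures of update_barycenter or update_features.

```python
import numpy as np


def sqeucl_barycenter_update(transports, costs, weights, a):
    """Hypothetical sketch of the squared-loss GW barycenter cost update.

    transports: [S, N, M] couplings; rows of transports[s] sum to a.
    costs:      [S, M, M] padded cost matrices C_s.
    weights:    [S,] measure weights lambda_s summing to 1.
    a:          [N,] strictly positive barycenter marginal.
    Returns sum_s lambda_s T_s C_s T_s^T / (a a^T), shape [N, N].
    """
    agg = np.einsum("s,sik,skl,sjl->ij", weights, transports, costs, transports)
    return agg / np.outer(a, a)


def sqeucl_feature_update(transports, features, weights, a):
    """Hypothetical sketch of the fused-case feature update.

    features: [S, M, D] padded features Y_s (e.g. a segmented y_fused).
    Returns the barycentric projection sum_s lambda_s T_s Y_s / a, shape [N, D].
    """
    agg = np.einsum("s,sik,skd->id", weights, transports, features)
    return agg / a[:, None]


# Two 2-point measures, each coupled to a 2-point barycenter by the
# diagonal coupling diag(a), so point i of every measure maps to point i.
a = np.array([0.5, 0.5])
transports = np.stack([np.diag(a), np.diag(a)])
costs = np.array([[[0.0, 1.0], [1.0, 0.0]],
                  [[0.0, 3.0], [3.0, 0.0]]])
features = np.array([[[0.0, 0.0], [2.0, 2.0]],
                     [[2.0, 0.0], [4.0, 2.0]]])
weights = np.array([0.5, 0.5])

C = sqeucl_barycenter_update(transports, costs, weights, a)
X = sqeucl_feature_update(transports, features, weights, a)
print(C)  # off-diagonal entries are the weighted average 2.0
print(X)  # each row is the weighted mean of the matched points
```

Because padded columns of each T_s carry zero mass, the padded rows and columns of C_s and Y_s do not contribute to either update.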
Attributes
- Array of shape [num_measures * (N_1 + N_2 + ...),].
- Array of shape [num_measures * (N_1 + N_2 + ...), ndim].
- Gromov-Wasserstein loss.
- Whether the problem is fused.
- Maximum number of points across all measures.
- Number of dimensions of y.
- Number of dimensions of the fused term.
- Number of measures.
- [num_measures,] array containing the number of points per measure.
- Tuple of arrays containing the segmented measures, weights, and number of points.
- Feature array of shape used in the fused case.
- Barycenter weights of shape [num_measures,] that sum to 1.
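Several of these attributes can be read directly off the padded arrays. The sketch below derives them in NumPy using the hypothetical helper describe_segmented. It rests on three assumptions about the library's behavior: padded slots are exactly the zero-weight entries, missing weights default to uniform, and provided weights are renormalized to sum to 1.

```python
import numpy as np


def describe_segmented(y_seg, b_seg, weights=None):
    """Hypothetical helper: derive attribute-like quantities from padded arrays.

    Assumes padded slots are exactly the entries with zero weight, that
    missing measure weights default to uniform, and that given weights
    are renormalized to sum to 1.
    """
    num_measures, max_measure_size, ndim = y_seg.shape
    num_points = np.sum(b_seg > 0, axis=1)  # points per measure, [num_measures,]
    if weights is None:
        weights = np.full(num_measures, 1.0 / num_measures)
    else:
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()  # enforce that weights sum to 1
    return {
        "num_measures": num_measures,
        "max_measure_size": max_measure_size,
        "ndim": ndim,
        "num_points": num_points,
        "weights": weights,
    }


y_seg = np.zeros((2, 3, 2))
b_seg = np.array([[0.2, 0.3, 0.5], [0.4, 0.6, 0.0]])
info = describe_segmented(y_seg, b_seg, weights=[1.0, 3.0])
print(info["num_points"])  # [3 2]
print(info["weights"])     # [0.25 0.75]
```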