ott.problems.quadratic.gw_barycenter.GWBarycenterProblem
- class ott.problems.quadratic.gw_barycenter.GWBarycenterProblem(y=None, b=None, weights=None, costs=None, y_fused=None, fused_penalty=1.0, gw_loss='sqeucl', scale_cost=1.0, **kwargs)
(Fused) Gromov-Wasserstein barycenter problem [Peyré et al., 2016, Titouan et al., 2019].
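For reference, the unfused objective from Peyré et al. (2016) can be written as below. The symbols ($K$ measures with cost matrices $C_k$ and point weights $b_k$, barycenter weights $\lambda_k$, barycenter marginal $p$, loss $L$) are notation chosen here for illustration, not identifiers from the library.

$$
\min_{C \in \mathbb{R}^{N \times N}} \sum_{k=1}^{K} \lambda_k \, \mathrm{GW}_L(C, C_k, p, b_k),
\qquad
\mathrm{GW}_L(C, C_k, p, b_k) = \min_{T \in U(p, b_k)} \sum_{i,j,l,m} L\big(C_{il}, (C_k)_{jm}\big) \, T_{ij} \, T_{lm}
$$

Here $U(p, b_k)$ is the set of couplings with marginals $p$ and $b_k$, $\sum_k \lambda_k = 1$, and $L$ corresponds to gw_loss (squared Euclidean or KL). In the fused case, each term also contains a linear cost between the barycenter features and the features given by y_fused, scaled by fused_penalty, and the minimization additionally runs over the barycenter features.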
- Parameters:
  - y (Optional[Array]) – Array of shape [num_total_points, ndim] merging the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().
  - b (Optional[Array]) – Array of shape [num_total_points,] containing the weights of all the points within the measures that define the barycenter problem. As with y, a pre-segmented array of weights of shape [num_measures, max_measure_size] can be passed. If y is already pre-segmented, this array must be passed.
  - weights (Optional[Array]) – Array of shape [num_measures,] containing the weights of the barycenter problem.
  - costs (Optional[Array]) – Alternative to y: an array of shape [num_measures, max_measure_size, max_measure_size] that defines padded cost matrices for each measure. Used in the quadratic term. Only one of y and costs can be specified.
  - y_fused (Optional[Array]) – Array of shape [num_total_points, ndim_fused] containing the data of the points of all measures used to define the linear term in the fused case. As with y, it can be specified as a pre-segmented array of shape [num_measures, max_measure_size, ndim_fused].
  - gw_loss (Literal['sqeucl', 'kl']) – Gromov-Wasserstein loss.
  - fused_penalty (float) – Multiplier of the linear term. Only used when y_fused is not None.
  - scale_cost (Union[float, Literal['mean', 'max_cost']]) – Scaling of cost matrices passed to geometries.
  - kwargs (Any) – Keyword arguments for BarycenterProblem.
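The padded layout expected for pre-segmented y, b, and costs can be illustrated with a minimal NumPy sketch. The helpers pad_measures and padded_sqeucl_costs are hypothetical and are not segment_point_cloud(). Using squared Euclidean distances for costs is an assumption made only for this example. Padded rows receive zero weight, so they carry no mass.

```python
import numpy as np


def pad_measures(point_clouds, weight_list):
    """Hypothetical helper: stack variable-size measures into padded arrays.

    Returns y_seg of shape [num_measures, max_measure_size, ndim] and
    b_seg of shape [num_measures, max_measure_size]; padded rows get zero weight.
    """
    num_measures = len(point_clouds)
    max_measure_size = max(pc.shape[0] for pc in point_clouds)
    ndim = point_clouds[0].shape[1]
    y_seg = np.zeros((num_measures, max_measure_size, ndim))
    b_seg = np.zeros((num_measures, max_measure_size))
    for k, (pc, w) in enumerate(zip(point_clouds, weight_list)):
        n = pc.shape[0]
        y_seg[k, :n] = pc  # real points first, zeros as padding
        b_seg[k, :n] = w   # padding keeps weight 0
    return y_seg, b_seg


def padded_sqeucl_costs(y_seg):
    """Hypothetical helper: one squared-Euclidean cost matrix per measure.

    Output shape: [num_measures, max_measure_size, max_measure_size].
    Entries involving padded rows are meaningless but carry zero weight.
    """
    diff = y_seg[:, :, None, :] - y_seg[:, None, :, :]
    return np.sum(diff ** 2, axis=-1)


rng = np.random.default_rng(0)
clouds = [rng.normal(size=(3, 2)), rng.normal(size=(5, 2))]
masses = [np.full(3, 1 / 3), np.full(5, 1 / 5)]
y_seg, b_seg = pad_measures(clouds, masses)
costs = padded_sqeucl_costs(y_seg)
print(y_seg.shape, b_seg.shape, costs.shape)  # (2, 5, 2) (2, 5) (2, 5, 5)
```

Padding to max_measure_size gives every measure the same shape, which is what allows a single array of shape [num_measures, max_measure_size, ...] to hold all of them.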
Methods
  - update_barycenter(transports, a) – Update the barycenter cost matrix.
  - update_features(transports, a) – Update the barycenter features in the fused case [Titouan et al., 2019].
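The two updates can be sketched in NumPy using closed-form expressions. For the cost matrix, Peyré et al. (2016) give the squared Euclidean update $C = \sum_k \lambda_k T_k C_k T_k^\top / (p p^\top)$, with elementwise division. For the KL loss, the same expression is applied to $\log C_k$ and followed by an elementwise exponential. For the features, Titouan et al. (2019) give $X = \mathrm{diag}(1/p) \sum_k \lambda_k T_k Y_k$. The function names and extra arguments below are hypothetical. OTT's own methods take only (transports, a) and presumably read the remaining data from the problem instance, so their internals may differ.

```python
import numpy as np


def update_barycenter_cost(transports, a, costs, weights, gw_loss="sqeucl"):
    """Hypothetical sketch of the GW barycenter cost update (Peyre et al., 2016).

    transports: [num_measures, n_bar, max_measure_size], couplings T_k
    a:          [n_bar,], barycenter marginal p (assumed strictly positive)
    costs:      [num_measures, max_measure_size, max_measure_size], matrices C_k
    weights:    [num_measures,], barycenter weights lambda_k summing to 1
    """
    if gw_loss == "sqeucl":
        mats = costs
    elif gw_loss == "kl":
        # Assumes C_k > 0 on the unpadded block; zero entries (e.g. padding)
        # are mapped to log(1) = 0, so they add nothing to the sum.
        mats = np.log(np.where(costs > 0, costs, 1.0))
    else:
        raise ValueError(f"Unknown gw_loss: {gw_loss!r}")
    # sum_k lambda_k * T_k @ M_k @ T_k^T, shape [n_bar, n_bar]
    acc = np.einsum("k,kij,kjl,kml->im", weights, transports, mats, transports)
    acc = acc / np.outer(a, a)  # elementwise division by p p^T
    return acc if gw_loss == "sqeucl" else np.exp(acc)


def update_features(transports, a, y_fused, weights):
    """Hypothetical sketch of the fused feature update (Titouan et al., 2019).

    y_fused: [num_measures, max_measure_size, ndim_fused], features Y_k
    Returns diag(1/a) @ sum_k lambda_k T_k @ Y_k, shape [n_bar, ndim_fused].
    """
    acc = np.einsum("k,kij,kjd->id", weights, transports, y_fused)
    return acc / a[:, None]


# Sanity check with trivial couplings T_k = diag(p): each update then reduces
# to a weighted combination of the inputs.
rng = np.random.default_rng(0)
n = 4
a = np.full(n, 1 / n)
T = np.stack([np.diag(a), np.diag(a)])
C = rng.uniform(0.5, 2.0, size=(2, n, n))
C = (C + C.transpose(0, 2, 1)) / 2  # symmetric, strictly positive costs
lam = np.array([0.3, 0.7])
C_bar = update_barycenter_cost(T, a, C, lam)
Y = rng.normal(size=(2, n, 3))
X_bar = update_features(T, a, Y, lam)
```

With these identity-like couplings, the squared Euclidean update returns $0.3 C_1 + 0.7 C_2$, and the KL update returns the elementwise geometric combination $C_1^{0.3} C_2^{0.7}$.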
Attributes
  - Array of shape [num_measures * (N_1 + N_2 + ...),].
  - Array of shape [num_measures * (N_1 + N_2 + ...), ndim].
  - Gromov-Wasserstein loss.
  - Whether the problem is fused.
  - Maximum number of points across all measures.
  - Number of dimensions of y.
  - Number of dimensions of the fused term.
  - Number of measures.
  - Tuple of arrays containing the segmented measures and weights.
  - Feature array of shape used in the fused case.
  - Barycenter weights of shape [num_measures,] that sum to 1.
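The following sketch relates the attribute descriptions above to the padded layout. It defines a hypothetical container whose properties compute the corresponding quantities from pre-segmented arrays. The class and property names are illustrative, not OTT's. Two behaviours are assumptions of this sketch only: flattening keeps padded entries, and barycenter weights default to uniform.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class SegmentedMeasures:
    """Hypothetical container mirroring the attribute descriptions above."""

    y: np.ndarray                         # [num_measures, max_measure_size, ndim]
    b: np.ndarray                         # [num_measures, max_measure_size]
    y_fused: Optional[np.ndarray] = None  # [num_measures, max_measure_size, ndim_fused]
    weights: Optional[np.ndarray] = None  # [num_measures,]

    @property
    def num_measures(self) -> int:
        return self.y.shape[0]

    @property
    def max_measure_size(self) -> int:
        return self.y.shape[1]

    @property
    def ndim(self) -> int:
        return self.y.shape[2]

    @property
    def is_fused(self) -> bool:
        return self.y_fused is not None

    @property
    def ndim_fused(self) -> Optional[int]:
        return None if self.y_fused is None else self.y_fused.shape[2]

    @property
    def flat_b(self) -> np.ndarray:
        # In this sketch padded entries are kept (with zero weight).
        return self.b.reshape(-1)

    @property
    def flat_y(self) -> np.ndarray:
        return self.y.reshape(-1, self.ndim)

    @property
    def barycenter_weights(self) -> np.ndarray:
        # Assumption: uniform when no weights are given; normalized to sum to 1.
        w = np.ones(self.num_measures) if self.weights is None else self.weights
        return w / w.sum()
```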