ott.problems.quadratic.gw_barycenter.GWBarycenterProblem
- class ott.problems.quadratic.gw_barycenter.GWBarycenterProblem(y=None, b=None, weights=None, costs=None, y_fused=None, fused_penalty=1.0, gw_loss='sqeucl', scale_cost=1.0, **kwargs)
(Fused) Gromov-Wasserstein barycenter problem [Peyré et al., 2016, Titouan et al., 2019].
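Written out, the problem can be sketched as follows. This follows Peyré et al. (2016) for the quadratic term and Titouan et al. (2019) for the fused term, with notation of our own. The symbols are assumptions: λ_s are the `weights`, p is the vector of barycenter point weights (the `a` argument of the methods below), q_s is the weight vector of measure s (from `b`), C^{(s)} is the cost matrix of measure s (from `y` or `costs`), y^{(s)}_k are its fused features (from `y_fused`), and β is `fused_penalty`. L is the squared difference for `gw_loss='sqeucl'` and L(a, b) = a log(a/b) − a + b for `'kl'`. The squared-Euclidean feature cost in the linear term is an assumption, and that term is dropped in the non-fused case.

```latex
\min_{C,\,X}\ \sum_{s=1}^{S} \lambda_s
  \min_{T_s \in U(p,\, q_s)}
  \Big(
    \sum_{i,j,k,l} L\big(C_{ij},\, C^{(s)}_{kl}\big)\, T_{s,ik}\, T_{s,jl}
    \;+\; \beta \sum_{i,k} \big\lVert x_i - y^{(s)}_k \big\rVert_2^2\, T_{s,ik}
  \Big),
\qquad
U(p, q_s) = \{\, T \ge 0 \;:\; T\mathbf{1} = p,\ T^{\top}\mathbf{1} = q_s \,\}
```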
- Parameters
  - y (Optional[Array]) – Array of shape [num_total_points, ndim] merging the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().
  - b (Optional[Array]) – Array of shape [num_total_points,] containing the weights of all the points within the measures that define the barycenter problem. As with y, a pre-segmented array of weights of shape [num_measures, max_measure_size] can be passed. If y is already pre-segmented, this array must be passed.
  - weights (Optional[Array]) – Array of shape [num_measures,] containing the weights of the barycenter problem.
  - costs (Optional[Array]) – Alternative to y: an array of shape [num_measures, max_measure_size, max_measure_size] that defines padded cost matrices for each measure. Used in the quadratic term. Only one of y and costs can be specified.
  - y_fused (Optional[Array]) – Array of shape [num_total_points, ndim_fused] containing the data of the points of all measures used to define the linear term in the fused case. As with y, it can be specified as a pre-segmented array of shape [num_measures, max_measure_size, ndim_fused].
  - gw_loss (Literal['sqeucl', 'kl']) – Gromov-Wasserstein loss.
  - fused_penalty (float) – Multiplier of the linear term. Only used when y_fused is not None.
  - scale_cost (Union[int, float, Literal['mean', 'max_cost']]) – Scaling of cost matrices passed to geometries.
  - kwargs (Any) – Keyword arguments for BarycenterProblem.
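The segmented layout described for y, b and costs can be illustrated with a small NumPy sketch. `segment_sketch` and `padded_sq_euclidean_costs` are hypothetical helpers written for this illustration. They are not segment_point_cloud() or any other OTT function, and the `segment_ids` input is an assumed way of labelling which measure each point belongs to. The sketch pads every measure to max_measure_size, gives padded slots zero weight, and builds per-measure squared-Euclidean cost matrices of the shape expected by costs.

```python
import numpy as np


def segment_sketch(y, b, segment_ids, num_measures):
    """Hypothetical helper: pad a flat point cloud into per-measure blocks.

    y: [num_total_points, ndim] points of all measures, concatenated.
    b: [num_total_points,] weights of those points.
    segment_ids: [num_total_points,] index of the measure each point belongs to.
    Returns zero-padded y_seg [num_measures, max_measure_size, ndim]
    and b_seg [num_measures, max_measure_size].
    """
    y = np.asarray(y, dtype=float)
    b = np.asarray(b, dtype=float)
    segment_ids = np.asarray(segment_ids)
    sizes = np.bincount(segment_ids, minlength=num_measures)
    max_size = int(sizes.max())
    y_seg = np.zeros((num_measures, max_size, y.shape[1]))
    b_seg = np.zeros((num_measures, max_size))
    for s in range(num_measures):
        mask = segment_ids == s
        n = int(mask.sum())
        y_seg[s, :n] = y[mask]
        b_seg[s, :n] = b[mask]  # padded slots keep weight 0
    return y_seg, b_seg


def padded_sq_euclidean_costs(y_seg):
    """Squared-Euclidean cost matrix per measure: [num_measures, max, max]."""
    diff = y_seg[:, :, None, :] - y_seg[:, None, :, :]
    return np.sum(diff ** 2, axis=-1)


# Two measures with 2 and 3 points in 2D.
y = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
b = np.array([0.5, 0.5, 1 / 3, 1 / 3, 1 / 3])
ids = np.array([0, 0, 1, 1, 1])
y_seg, b_seg = segment_sketch(y, b, ids, num_measures=2)
costs = padded_sq_euclidean_costs(y_seg)
print(y_seg.shape, b_seg.shape, costs.shape)  # (2, 3, 2) (2, 3) (2, 3, 3)
```

Padded points still get cost entries, but their zero weight means no transport mass can reach them, so they do not affect the objective.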
Methods
- update_barycenter(transports, a) – Update the barycenter cost matrix.
- update_features(transports, a) – Update the barycenter features in the fused case [Titouan et al., 2019].
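The two updates can be sketched with the closed-form block-coordinate steps from Peyré et al. (2016) and Titouan et al. (2019). This is a NumPy illustration, not OTT's implementation. The assumed shapes are: `transports` is [num_measures, bar_size, max_measure_size] with rows summing to `a`. The extra arguments `costs`, `features` and `lambdas` are hypothetical inputs added for self-containment. For the squared loss, C = (Σ_s λ_s T_s C_s T_sᵀ) / (a aᵀ) elementwise. For KL, the same average is taken over log C_s and then exponentiated. The feature update assumes a squared-Euclidean feature cost, so X = (Σ_s λ_s T_s Y_s) / a.

```python
import numpy as np


def update_barycenter_sketch(transports, costs, a, lambdas, gw_loss="sqeucl"):
    """Closed-form barycenter cost update (sketch after Peyre et al., 2016).

    transports: [num_measures, bar_size, max_measure_size] couplings T_s
    costs: [num_measures, max_measure_size, max_measure_size] matrices C_s
    a: [bar_size,] barycenter point weights p
    lambdas: [num_measures,] measure weights summing to 1
    """
    if gw_loss == "sqeucl":
        terms = costs
    elif gw_loss == "kl":
        terms = np.log(costs)  # assumes strictly positive costs (no zero padding)
    else:
        raise ValueError(f"unknown gw_loss: {gw_loss}")
    # sum_s lambda_s * T_s @ terms_s @ T_s.T, shape [bar_size, bar_size]
    weighted = np.einsum("s,sik,skl,sjl->ij", lambdas, transports, terms, transports)
    c = weighted / np.outer(a, a)
    return np.exp(c) if gw_loss == "kl" else c


def update_features_sketch(transports, features, a, lambdas):
    """Fused-case feature update, assuming a squared-Euclidean feature cost.

    features: [num_measures, max_measure_size, ndim_fused]
    returns: [bar_size, ndim_fused]
    """
    weighted = np.einsum("s,sik,skd->id", lambdas, transports, features)
    return weighted / a[:, None]


# Sanity check with diagonal ("identity-like") couplings T_s = diag(p).
p = np.array([0.2, 0.3, 0.5])
T = np.stack([np.diag(p), np.diag(p)])
C1 = np.array([[1.0, 2.0, 3.0], [2.0, 1.0, 4.0], [3.0, 4.0, 1.0]])
Cs = np.stack([C1, 2.0 * C1])
lambdas = np.array([0.25, 0.75])

C_bar = update_barycenter_sketch(T, Cs, p, lambdas)  # equals 1.75 * C1
C_kl = update_barycenter_sketch(T, Cs, p, lambdas, gw_loss="kl")  # equals 2**0.75 * C1
Y = np.stack([np.eye(3)[:, :2], np.ones((3, 2))])
X = update_features_sketch(T, Y, p, lambdas)  # equals 0.25 * Y[0] + 0.75 * Y[1]
```

With diagonal couplings, the squared-loss update reduces to the λ-weighted arithmetic mean of the C_s, and the KL update reduces to their weighted geometric mean. The feature update does not depend on fused_penalty, because that multiplier scales every measure's linear term equally.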
Attributes
- Array of shape [num_measures * (N_1 + N_2 + ...),].
- Array of shape [num_measures * (N_1 + N_2 + ...), ndim].
- Gromov-Wasserstein loss.
- Whether the problem is fused.
- Maximum number of points across all measures.
- Number of dimensions of y.
- Number of dimensions of the fused term.
- Number of measures.
- Tuple of arrays containing the segmented measures and weights.
- Feature array of shape [num_measures, max_measure_size, ndim_fused] used in the fused case.
- Barycenter weights of shape [num_measures,] that sum to 1.