ott.problems.linear.barycenter_problem.FreeBarycenterProblem
- class ott.problems.linear.barycenter_problem.FreeBarycenterProblem(y, b=None, weights=None, cost_fn=None, epsilon=None, debiased=False, **kwargs)
Free Wasserstein barycenter problem [Cuturi and Doucet, 2014].
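For reference, the free-support barycenter problem is commonly written as below. The notation (K input measures, m barycenter points, lambda_k for the measure weights, and barycenter point weights a held fixed while the support X moves) is ours and is not taken from the class itself. The fixed-point update on the last line is the standard one for squared Euclidean cost; whether a particular OTT solver uses exactly this update is not stated on this page.

```latex
% Inputs: \nu_k = \sum_j b_{kj} \delta_{y_{kj}},\ k = 1, \dots, K;\ \lambda_k \ge 0,\ \sum_k \lambda_k = 1
% Barycenter: \mu_X = \sum_{i=1}^{m} a_i \delta_{x_i}, support X = (x_1, \dots, x_m) free to move
\min_{X \in \mathbb{R}^{m \times d}} \; \sum_{k=1}^{K} \lambda_k \, \mathrm{OT}_{\varepsilon}(\mu_X, \nu_k),
\qquad
\mathrm{OT}_{\varepsilon}(\mu_X, \nu_k) = \min_{P \in U(a, b_k)} \langle P, C_k \rangle - \varepsilon H(P),
\qquad
(C_k)_{ij} = c(x_i, y_{kj}),

% U(a, b) = \{ P \ge 0 : P \mathbf{1} = a,\ P^\top \mathbf{1} = b \},\quad H(P) = -\sum_{ij} P_{ij} (\log P_{ij} - 1)
% With c(x, y) = \|x - y\|_2^2 and optimal plans P_k^\star, zeroing the gradient in X gives the fixed point
X \leftarrow \operatorname{diag}(a)^{-1} \sum_{k=1}^{K} \lambda_k \, P_k^\star Y_k .
```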
- Parameters:
  - y (Array) – Array of shape [num_total_points, ndim] merging the points of all measures. Alternatively, an already segmented array of shape [num_measures, max_measure_size, ndim] can be passed. See also segment_point_cloud().
  - b (Optional[Array]) – Array of shape [num_total_points,] containing the weights of all the points within the measures that define the barycenter problem. As with y, a pre-segmented array of weights of shape [num_measures, max_measure_size] can be passed. If y is already pre-segmented, this array must always be specified.
  - weights (Optional[Array]) – Array of shape [num_measures,] containing the weights of the measures.
  - cost_fn (Optional[CostFn]) – Cost function used. If None, the SqEuclidean cost is used.
  - epsilon (Optional[float]) – Epsilon regularization used to solve the regularized OT problems.
  - debiased (bool) – Currently not implemented. Whether the problem is debiased, in the sense that the regularized transportation cost of the barycenter to itself is taken into account when computing the gradient. Note that if the debiased option is used, the barycenter size needs to be smaller than the maximum measure size for parallelization to operate efficiently.
  - kwargs (Any) – Keyword arguments for segment_point_cloud(). Only used when y is not already segmented. When passing segment_ids, 2 arguments must be specified for jitting to work:
    - num_segments – the total number of measures.
    - max_measure_size – the maximum of the support sizes of these measures.
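To make the segmented layout concrete, here is a minimal NumPy sketch that turns a flat y with per-point segment ids into padded arrays of shape [num_measures, max_measure_size, ndim] and [num_measures, max_measure_size]. segment_flat_point_cloud is a hypothetical helper written for illustration, not OTT's segment_point_cloud(); zero-padding, zero weight for padded slots, and uniform per-measure weights when b is None are assumptions of this sketch.

```python
import numpy as np


def segment_flat_point_cloud(y, b, segment_ids, num_segments, max_measure_size):
    """Hypothetical helper: pad a flat point cloud into per-measure blocks.

    y: [num_total_points, ndim]; b: [num_total_points,] or None;
    segment_ids: [num_total_points,] integers in [0, num_segments).
    Returns [num_segments, max_measure_size, ndim] points and
    [num_segments, max_measure_size] weights; padded slots get weight 0.
    """
    y = np.asarray(y, dtype=float)
    segment_ids = np.asarray(segment_ids)
    ndim = y.shape[1]
    if b is None:
        # Assumption: uniform weights within each measure when b is omitted.
        counts = np.bincount(segment_ids, minlength=num_segments)
        b = 1.0 / counts[segment_ids]
    b = np.asarray(b, dtype=float)

    # Output shapes depend only on num_segments and max_measure_size.
    seg_y = np.zeros((num_segments, max_measure_size, ndim))
    seg_b = np.zeros((num_segments, max_measure_size))
    for k in range(num_segments):
        idx = np.flatnonzero(segment_ids == k)
        if idx.size > max_measure_size:
            raise ValueError("a measure has more points than max_measure_size")
        seg_y[k, :idx.size] = y[idx]  # real points first, zero padding after
        seg_b[k, :idx.size] = b[idx]
    return seg_y, seg_b


# Two measures with 3 and 2 points in the plane.
y = np.array([[0., 0.], [1., 0.], [0., 1.], [2., 2.], [3., 3.]])
segment_ids = np.array([0, 0, 0, 1, 1])
seg_y, seg_b = segment_flat_point_cloud(
    y, None, segment_ids, num_segments=2, max_measure_size=3)
print(seg_y.shape)  # (2, 3, 2)
```

Because the padded shapes are fixed by num_segments and max_measure_size rather than by the values in segment_ids, passing both explicitly lets the shapes stay static, which is what tracing under jax.jit requires.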
Methods
Attributes
- flattened_b – Array of shape [num_measures * (N_1 + N_2 + ...),].
- flattened_y – Array of shape [num_measures * (N_1 + N_2 + ...), ndim].
- max_measure_size – Maximum number of points across all measures.
- ndim – Number of dimensions of y.
- num_measures – Number of measures.
- segmented_y_b – Tuple of arrays containing the segmented measures and weights.
- weights – Barycenter weights of shape [num_measures,] that sum to 1.
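Several of these attributes are views of the padded data. The NumPy sketch below uses a hypothetical pre-segmented input and a hypothetical barycenter_weights helper. It reads the size attributes off the padded shape, flattens by a plain reshape (so padded slots are kept), and produces per-measure weights that sum to 1. Returning uniform weights when none are given and rescaling user-supplied weights are assumptions, not documented behavior of the class.

```python
import numpy as np


def barycenter_weights(num_measures, weights=None):
    """Hypothetical helper: per-measure weights of shape [num_measures,] summing to 1."""
    if weights is None:
        # Assumption: uniform weights when none are supplied.
        return np.full(num_measures, 1.0 / num_measures)
    weights = np.asarray(weights, dtype=float)
    if weights.shape != (num_measures,):
        raise ValueError("weights must have shape [num_measures,]")
    # Assumption: rescale so that the weights sum to 1.
    return weights / weights.sum()


# Hypothetical pre-segmented input: 2 measures, at most 3 points each, ndim = 2.
seg_y = np.array([[[0., 0.], [1., 0.], [0., 1.]],
                  [[2., 2.], [3., 3.], [0., 0.]]])
seg_b = np.array([[1/3, 1/3, 1/3],
                  [0.5, 0.5, 0.0]])  # the padded slot carries weight 0

# Size attributes read off the padded shape.
num_measures, max_measure_size, ndim = seg_y.shape

# Flattening stacks the padded measures one after another.
flat_y = seg_y.reshape(-1, ndim)
flat_b = seg_b.reshape(-1)

weights = barycenter_weights(num_measures, [1.0, 3.0])
print(weights)
```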