ott.geometry.costs.Bures
class ott.geometry.costs.Bures(dimension, **kwargs)
Bures distance between a pair of Gaussians, each represented as its (mean, covariance matrix) pair raveled into a single vector.
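The cost can be made concrete with a small NumPy sketch. It assumes that a Gaussian in dimension d is raveled as its mean followed by its row-major flattened covariance (d(1 + d) entries), and that the cost decomposes as norm(x) + norm(y) + pairwise(x, y), with norm and pairwise following the descriptions in the method list. Under these assumptions the sum is the squared 2-Wasserstein (Bures-Wasserstein) distance between the two Gaussians. The functions psd_sqrt, unravel, norm, pairwise, bures_sq and all_pairs here are hypothetical stand-ins, not the OTT API.

```python
import numpy as np


def psd_sqrt(a):
    """Square root of a symmetric PSD matrix via its eigendecomposition."""
    w, v = np.linalg.eigh(a)
    # Clip tiny negative eigenvalues caused by round-off before the sqrt.
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def unravel(x, d):
    """Split a length d(1 + d) vector into (mean, covariance).

    ASSUMED layout (hypothetical): the first d entries are the mean, the
    next d * d entries are the covariance flattened in row-major order.
    """
    return x[:d], x[d:d + d * d].reshape(d, d)


def norm(x, d):
    # Squared 2-norm of the mean plus the trace of the covariance.
    mean, cov = unravel(x, d)
    return mean @ mean + np.trace(cov)


def pairwise(x, y, d):
    # -2 times the Bures dot-product <m_x, m_y> + tr((Sx^1/2 Sy Sx^1/2)^1/2).
    mean_x, cov_x = unravel(x, d)
    mean_y, cov_y = unravel(y, d)
    sq_x = psd_sqrt(cov_x)
    cross = psd_sqrt(sq_x @ cov_y @ sq_x)
    return -2.0 * (mean_x @ mean_y + np.trace(cross))


def bures_sq(x, y, d):
    # Squared Bures-Wasserstein distance = norm(x) + norm(y) + pairwise(x, y).
    return norm(x, d) + norm(y, d) + pairwise(x, y, d)


def all_pairs(xs, ys, d):
    # Matrix of all costs (norms included) between rows of xs and rows of ys.
    return np.array([[bures_sq(x, y, d) for y in ys] for x in xs])


# 1-D sanity check: N(0, 4) vs N(1, 9); expected (0 - 1)^2 + (2 - 3)^2 = 2.
x = np.array([0.0, 4.0])
y = np.array([1.0, 9.0])
print(bures_sq(x, y, 1))
```

In one dimension the formula reduces to the squared difference of means plus the squared difference of standard deviations, which is what the sanity check exercises.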
Methods
all_pairs(x, y): Compute the matrix of all costs (including norms) between the vectors in x and y.
all_pairs_pairwise(x, y): Compute the matrix of all pairwise costs (without norms) between the vectors in x and y.
barycenter(weights, xs): Compute the Bures barycenter of weighted Gaussian distributions.
covariance_fixpoint_iter(covs, lambdas[, rtol]): Iterate fixed-point updates to compute the barycenter of Gaussians.
means_and_covs_to_x(mean, covariance): Vectorized version of means_and_covs_to_x.
norm(x): Compute the norm of a Gaussian: the squared 2-norm of its mean plus the trace of its covariance.
pairwise(x, y): Compute -2 times the Bures dot-product.
relative_diff(x, y): Monitor the change between two successive estimates of matrices.
scale_covariances(cov_sqrt, cov_i, lambda_i): Vectorized version of scale_covariances.
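The barycenter methods can be illustrated with a NumPy sketch of the standard fixed-point update for Gaussian Wasserstein barycenters (Álvarez-Esteban et al., 2016): S ← S^{-1/2} (Σ_i λ_i (S^{1/2} C_i S^{1/2})^{1/2})² S^{-1/2}, where the barycenter mean is the weighted average of the input means. Whether covariance_fixpoint_iter uses exactly this update, this initialization (the identity) and this stopping rule is an assumption. The functions covariance_fixpoint, gaussian_barycenter and relative_change are hypothetical stand-ins, and relative_change (squared Frobenius change relative to the previous iterate) is just one plausible convergence monitor, not OTT's relative_diff.

```python
import numpy as np


def psd_sqrt(a):
    """Square root of a symmetric PSD matrix via its eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def relative_change(new, old):
    # HYPOTHETICAL monitor: squared Frobenius change relative to `old`.
    return np.sum((new - old) ** 2) / np.sum(old ** 2)


def covariance_fixpoint(covs, lambdas, rtol=1e-10, max_iter=100):
    """Fixed-point iteration for the covariance of a Bures barycenter.

    covs: array of shape (k, d, d); lambdas: k nonnegative weights summing to 1.
    """
    d = covs.shape[-1]
    cov = np.eye(d)  # ASSUMED initialization: identity matrix
    for _ in range(max_iter):
        sq = psd_sqrt(cov)
        inv_sq = np.linalg.inv(sq)
        # Weighted sum of the scaled terms lambda_i * (S^1/2 C_i S^1/2)^1/2.
        scaled = sum(lam * psd_sqrt(sq @ c @ sq) for lam, c in zip(lambdas, covs))
        next_cov = inv_sq @ scaled @ scaled @ inv_sq
        diff = relative_change(next_cov, cov)
        cov = next_cov
        if diff < rtol:
            break
    return cov


def gaussian_barycenter(weights, means, covs):
    # The barycenter mean is the weighted average of the means.
    weights = np.asarray(weights, dtype=float)
    return weights @ means, covariance_fixpoint(covs, weights)


means = np.array([[0.0, 0.0], [2.0, 4.0]])
covs = np.array([np.diag([1.0, 4.0]), np.diag([9.0, 16.0])])
mean, cov = gaussian_barycenter([0.5, 0.5], means, covs)
# For commuting (here diagonal) covariances, the barycenter's standard
# deviations are the weighted averages of the inputs' standard deviations,
# so cov should be close to diag(4, 9) and mean close to [1, 2].
```

The diagonal example is chosen because its answer is known in closed form; for non-commuting covariances the loop typically needs several iterations.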
Extract the mean and covariance matrix from a raveled vector of length d(1 + d).
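A sketch of the raveling convention, assuming the layout [mean (d entries), covariance flattened row-major (d² entries)], which yields the d(1 + d) length stated above. ravel_gaussians and unravel_gaussians are hypothetical helpers for illustration only; OTT's own conversion functions may order or batch the entries differently.

```python
import numpy as np


def ravel_gaussians(means, covs):
    """Pack means (n, d) and covariances (n, d, d) into rows of length d(1 + d).

    ASSUMED layout: each row is [mean, row-major flattened covariance].
    """
    n, d = means.shape
    return np.concatenate([means, covs.reshape(n, d * d)], axis=1)


def unravel_gaussians(x, d):
    """Inverse of ravel_gaussians; a single 1-D vector is treated as one row."""
    x = np.atleast_2d(x)
    means = x[:, :d]
    covs = x[:, d:d + d * d].reshape(-1, d, d)
    return means, covs


means = np.zeros((3, 2))
covs = np.stack([np.eye(2)] * 3)
x = ravel_gaussians(means, covs)  # shape (3, 6), since 2 * (1 + 2) = 6
m, c = unravel_gaussians(x, 2)
```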