ott.geometry.costs.Bures

class ott.geometry.costs.Bures(dimension, **kwargs)

Bures distance between a pair of Gaussians, each represented by its (mean, covariance matrix) pair raveled into a single vector.
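A minimal NumPy sketch of the quantity involved, assuming each vector stores the d-dimensional mean first and the row-major raveled d x d covariance after it. It follows the split suggested by the method summaries of this class: norm(x) + norm(y) + pairwise(x, y) gives the squared Bures-Wasserstein distance |m_x - m_y|^2 + tr(S_x) + tr(S_y) - 2 tr((S_x^(1/2) S_y S_x^(1/2))^(1/2)). The helpers (`sqrtm_psd`, `unpack`, `bures_norm`, `bures_pairwise`, `bures_cost`, `all_pairs`) are hypothetical plain functions, not OTT's JAX implementation, and compute matrix square roots with an eigendecomposition.

```python
import numpy as np


def sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(a)
    # Clip tiny negative eigenvalues caused by round-off before the sqrt.
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def unpack(x: np.ndarray, d: int):
    """Split a vector of length d + d*d into (mean, covariance); layout is assumed."""
    return x[:d], x[d:d + d * d].reshape(d, d)


def bures_norm(x: np.ndarray, d: int) -> float:
    """Squared 'norm' of a Gaussian: |m|^2 + tr(S)."""
    m, s = unpack(x, d)
    return float(m @ m + np.trace(s))


def bures_pairwise(x: np.ndarray, y: np.ndarray, d: int) -> float:
    """-2 times the Bures dot product <m_x, m_y> + tr((S_x^1/2 S_y S_x^1/2)^1/2)."""
    mx, sx = unpack(x, d)
    my, sy = unpack(y, d)
    sx_half = sqrtm_psd(sx)
    cross = sqrtm_psd(sx_half @ sy @ sx_half)
    return float(-2.0 * (mx @ my + np.trace(cross)))


def bures_cost(x: np.ndarray, y: np.ndarray, d: int) -> float:
    """Squared Bures-Wasserstein distance: norm(x) + norm(y) + pairwise(x, y)."""
    return bures_norm(x, d) + bures_norm(y, d) + bures_pairwise(x, y, d)


def all_pairs(xs: np.ndarray, ys: np.ndarray, d: int) -> np.ndarray:
    """Cost matrix C[i, j] = bures_cost(xs[i], ys[j]) for row-stacked vectors."""
    return np.array([[bures_cost(x, y, d) for y in ys] for x in xs])


# 1-D check: N(0, 1) vs N(3, 4) gives (0 - 3)^2 + (1 - 2)^2 = 10.
x = np.array([0.0, 1.0])  # mean 0, variance 1
y = np.array([3.0, 4.0])  # mean 3, variance 4
print(bures_cost(x, y, d=1))  # 10.0
```

For 1-D Gaussians the cost reduces to (m_x - m_y)^2 + (s_x - s_y)^2 with s the standard deviations, which is what the check at the end exercises.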

Methods

all_pairs(x, y)

Compute the matrix of all costs (norms included) between the vectors in x and y.

all_pairs_pairwise(x, y)

Compute the matrix of all pairwise costs (norms excluded) between the vectors in x and y.

barycenter(weights, xs)

Compute the Bures barycenter of weighted Gaussian distributions.

covariance_fixpoint_iter(covs, lambdas[, rtol])

Iterate fixed-point updates to compute the covariance of the barycenter of Gaussians.

norm(x)

Compute the squared norm of a Gaussian: the squared norm of its mean plus the trace of its covariance.

padder(dim)

Pad with concatenated zero means and raveled identity covariance matrix.

pairwise(x, y)

Compute -2 times the Bures dot product of x and y.

relative_diff(x, y)

Monitor the change between two successive matrix estimates.

scale_covariances(cov_sqrt, cov_i, lambda_i)

Vectorized version of the per-Gaussian scaling step used in the covariance fixed-point iteration.
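The barycenter-related methods (barycenter, covariance_fixpoint_iter, scale_covariances, relative_diff) can be illustrated with a NumPy sketch of the standard fixed-point iteration for the covariance of a Wasserstein barycenter of Gaussians, S <- S^(-1/2) (sum_i lambda_i (S^(1/2) S_i S^(1/2))^(1/2))^2 S^(-1/2), with the barycenter mean taken as the weighted average of the means. We assume this is the update the library iterates; the functions below are plain NumPy stand-ins named after the methods they mirror, and the initialization, the stopping rule (relative Frobenius change) and `max_iter` are our own choices, not OTT's defaults.

```python
import numpy as np


def sqrtm_and_inv_sqrtm(a: np.ndarray):
    """Square root and inverse square root of a symmetric positive-definite matrix."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(w)) @ v.T, (v / np.sqrt(w)) @ v.T


def scale_covariances(cov_sqrt: np.ndarray, covs: np.ndarray, lambdas: np.ndarray) -> np.ndarray:
    """Batched lambda_i * (S^1/2 S_i S^1/2)^1/2 for covs stacked as (n, d, d)."""
    inner = cov_sqrt @ covs @ cov_sqrt  # matmul broadcasts over the leading axis
    w, v = np.linalg.eigh(inner)  # batched eigendecomposition
    roots = (v * np.sqrt(np.clip(w, 0.0, None))[:, None, :]) @ np.swapaxes(v, -1, -2)
    return lambdas[:, None, None] * roots


def relative_diff(new: np.ndarray, old: np.ndarray) -> float:
    """Relative Frobenius change between two successive estimates (our choice)."""
    return float(np.linalg.norm(new - old) / np.linalg.norm(old))


def covariance_fixpoint_iter(covs, lambdas, rtol=1e-8, max_iter=500):
    """Fixed-point iteration for the barycenter covariance."""
    covs = np.asarray(covs, dtype=float)
    lambdas = np.asarray(lambdas, dtype=float)
    cov = np.tensordot(lambdas, covs, axes=1)  # start at the weighted mean covariance
    for _ in range(max_iter):
        cov_sqrt, cov_inv_sqrt = sqrtm_and_inv_sqrtm(cov)
        m = scale_covariances(cov_sqrt, covs, lambdas).sum(axis=0)
        new_cov = cov_inv_sqrt @ m @ m @ cov_inv_sqrt
        new_cov = 0.5 * (new_cov + new_cov.T)  # re-symmetrize against round-off
        if relative_diff(new_cov, cov) < rtol:
            return new_cov
        cov = new_cov
    return cov


def barycenter(weights, means, covs):
    """Weighted mean of the means, plus the fixed-point covariance."""
    weights = np.asarray(weights, dtype=float)
    return weights @ np.asarray(means, dtype=float), covariance_fixpoint_iter(covs, weights)


means = np.array([[0.0, 0.0], [2.0, 4.0]])
covs = np.array([np.diag([1.0, 4.0]), np.diag([9.0, 16.0])])
mean, cov = barycenter([0.5, 0.5], means, covs)
print(mean)  # [1. 2.]
print(cov)   # expected close to diag(4, 9)
```

For commuting (for example diagonal) covariances the fixed point has the closed form (sum_i lambda_i S_i^(1/2))^2, here (0.5 diag(1, 2) + 0.5 diag(3, 4))^2 = diag(4, 9), which makes a quick sanity check.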

Parameters
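  • dimension (int) – dimension d of each Gaussian; every input vector has length d + d**2 (the mean followed by the raveled d x d covariance matrix).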

  • kwargs (Any) –
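Given dimension = d, each Gaussian occupies a vector of length d + d**2. The sketch below assumes the layout implied by the class description (mean first, then the covariance raveled in row-major order) and shows packing, unpacking, recovering d from a vector length, and building a zero-mean, identity-covariance vector of the kind padder is described as using. pack, unpack, dimension_from_length and padding_vector are hypothetical helpers, not part of OTT.

```python
import math

import numpy as np


def pack(mean, cov) -> np.ndarray:
    """Ravel a (mean, covariance) pair into one vector of length d + d * d."""
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    d = mean.shape[0]
    if cov.shape != (d, d):
        raise ValueError(f"covariance must have shape ({d}, {d}), got {cov.shape}")
    return np.concatenate([mean, cov.ravel()])  # mean first, then row-major covariance


def unpack(x, dimension: int):
    """Inverse of pack: first `dimension` entries are the mean, the rest the covariance."""
    x = np.asarray(x, dtype=float)
    if x.shape[-1] != dimension + dimension ** 2:
        raise ValueError("vector length must be dimension + dimension ** 2")
    return x[:dimension], x[dimension:].reshape(dimension, dimension)


def dimension_from_length(n: int) -> int:
    """Solve d + d**2 = n for the integer d (raises if n has no such d)."""
    disc = 1 + 4 * n
    r = math.isqrt(disc)
    if r * r != disc:
        raise ValueError(f"{n} is not of the form d + d**2")
    return (r - 1) // 2


def padding_vector(dimension: int) -> np.ndarray:
    """Zero mean concatenated with a raveled identity covariance."""
    return pack(np.zeros(dimension), np.eye(dimension))


x = pack([1.0, 2.0], [[2.0, 0.5], [0.5, 1.0]])
print(dimension_from_length(x.size))  # 2
print(padding_vector(2))  # zeros(2) followed by eye(2).ravel()
```

Recovering d from the total length uses integer arithmetic (math.isqrt) so that lengths not of the form d + d**2 are rejected rather than silently rounded.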