ott.geometry.costs.Bures

class ott.geometry.costs.Bures(dimension, **kwargs)

Bures distance between a pair of (mean, cov matrix) raveled as vectors.
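To illustrate the "raveled as vectors" convention, the sketch below packs a mean and a covariance into one flat vector and evaluates the standard closed-form squared Bures-Wasserstein distance between two Gaussians with NumPy. The helper names (`ravel_gaussian`, `unravel_gaussian`, `sqrtm_psd`, `bures_sq`) and the layout (mean first, then the row-major covariance) are assumptions made for this example and are not part of OTT's API.

```python
import numpy as np


def ravel_gaussian(mean, cov):
    """Pack a (d,) mean and a (d, d) covariance into one vector of size d + d*d."""
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    return np.concatenate([mean, cov.ravel()])


def unravel_gaussian(x, dimension):
    """Inverse of ravel_gaussian (assumed layout: mean first, then covariance)."""
    mean = x[:dimension]
    cov = x[dimension:dimension + dimension**2].reshape(dimension, dimension)
    return mean, cov


def sqrtm_psd(a):
    """Square root of a symmetric PSD matrix via its eigendecomposition."""
    w, v = np.linalg.eigh(a)
    # Clip tiny negative eigenvalues caused by round-off before taking roots.
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def bures_sq(x, y, dimension):
    """Squared Bures-Wasserstein distance between two raveled Gaussians."""
    mx, cx = unravel_gaussian(x, dimension)
    my, cy = unravel_gaussian(y, dimension)
    sx = sqrtm_psd(cx)
    cross = np.trace(sqrtm_psd(sx @ cy @ sx))
    return np.sum((mx - my) ** 2) + np.trace(cx) + np.trace(cy) - 2.0 * cross


# Two 2-D Gaussians: N(0, I) and N((1, 0), 4 I).
x = ravel_gaussian([0.0, 0.0], np.eye(2))
y = ravel_gaussian([1.0, 0.0], 4.0 * np.eye(2))
print(x.shape, bures_sq(x, y, 2))
```

Each raveled vector has size `dimension + dimension**2`, which is presumably why the class takes `dimension` as an argument: it is needed to split the vector back into its mean and covariance parts.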

Parameters

Methods

all_pairs(x, y)

Compute matrix of all costs (including norms) for vectors in x / y.

all_pairs_pairwise(x, y)

Compute matrix of all pairwise-costs (no norms) for vectors in x / y.

barycenter(weights, xs, **kwargs)

Compute the Bures barycenter of weighted Gaussian distributions.

covariance_fixpoint_iter(covs, weights[, ...])

Iterate fixed-point updates to compute the barycenter of Gaussians.

norm(x)

Compute the norm of a Gaussian: squared 2-norm of the mean plus trace of the covariance.

pairwise(x, y)

Compute -2 times the Bures dot-product.
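The method summaries above suggest that a full cost splits as norm(x) + norm(y) + pairwise(x, y), with `all_pairs` filling a matrix of such costs. The following is a minimal NumPy sketch of that split, assuming the norm of a Gaussian is the squared 2-norm of its mean plus the trace of its covariance, and the Bures dot-product is the inner product of the means plus tr((Σx^{1/2} Σy Σx^{1/2})^{1/2}). The functions are standalone stand-ins that take `(mean, cov)` tuples instead of raveled vectors; they are not OTT's methods and may differ from the library's implementation.

```python
import numpy as np


def sqrtm_psd(a):
    """Square root of a symmetric PSD matrix via its eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def norm(g):
    """Assumed norm term: squared 2-norm of the mean plus trace of the covariance."""
    mean, cov = g
    return np.dot(mean, mean) + np.trace(cov)


def pairwise(gx, gy):
    """Cross term: -2 times the (assumed) Bures dot-product of two Gaussians."""
    (mx, cx), (my, cy) = gx, gy
    sx = sqrtm_psd(cx)
    bures_dot = np.dot(mx, my) + np.trace(sqrtm_psd(sx @ cy @ sx))
    return -2.0 * bures_dot


def all_pairs(xs, ys):
    """Matrix of full costs: norm(x_i) + norm(y_j) + pairwise(x_i, y_j)."""
    nx = np.array([norm(g) for g in xs])
    ny = np.array([norm(g) for g in ys])
    cross = np.array([[pairwise(gx, gy) for gy in ys] for gx in xs])
    return nx[:, None] + ny[None, :] + cross


xs = [(np.zeros(2), np.eye(2)), (np.ones(2), 2.0 * np.eye(2))]
ys = [(np.array([1.0, 0.0]), 4.0 * np.eye(2))]
costs = all_pairs(xs, ys)  # shape (len(xs), len(ys))
print(costs)
```

Keeping the norm terms separate from the cross term means the norms are computed once per point, and only the cross term has to be evaluated for every pair.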