ott.geometry.costs.Bures

class ott.geometry.costs.Bures(dimension, sqrtm_kw=None)

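Bures distance between a pair of Gaussian distributions, each given as a (mean, covariance matrix) pair. The resulting cost is the squared 2-Wasserstein distance between the two Gaussians.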

Parameters:
  • dimension (int) – Dimensionality of the data.

  • sqrtm_kw (Optional[Dict[str, Any]]) – Dictionary of keyword arguments to control the behavior of inner calls to sqrtm().
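The cost this class represents can be written out directly: for Gaussians (m_x, Σ_x) and (m_y, Σ_y) it is ||m_x - m_y||² + tr(Σ_x) + tr(Σ_y) - 2 tr((Σ_x^{1/2} Σ_y Σ_x^{1/2})^{1/2}). The NumPy sketch below evaluates that formula. It assumes each Gaussian is packed into one flat vector of length dimension + dimension², with the mean first and the row-major covariance after it; this layout, and the helpers psd_sqrtm, unpack and bures_cost, are hypothetical illustrations rather than OTT's API. The eigendecomposition-based square root merely stands in for the library's sqrtm() and ignores sqrtm_kw.

```python
import numpy as np


def psd_sqrtm(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric PSD matrix via eigendecomposition.

    Hypothetical stand-in for the library's sqrtm(); it takes no sqrtm_kw.
    """
    mat = 0.5 * (mat + mat.T)  # symmetrize to absorb round-off
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # clip tiny negative eigenvalues
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def unpack(x: np.ndarray, dimension: int):
    """Split a packed vector into (mean, covariance).

    Assumed layout (hypothetical): the first `dimension` entries are the mean,
    the remaining dimension**2 entries are the row-major covariance matrix.
    """
    mean = x[:dimension]
    cov = x[dimension:].reshape(dimension, dimension)
    return mean, cov


def bures_cost(x: np.ndarray, y: np.ndarray, dimension: int) -> float:
    """Squared 2-Wasserstein distance between two packed Gaussians."""
    mean_x, cov_x = unpack(x, dimension)
    mean_y, cov_y = unpack(y, dimension)
    sq_x = psd_sqrtm(cov_x)
    cross = psd_sqrtm(sq_x @ cov_y @ sq_x)  # (S_x^1/2 S_y S_x^1/2)^1/2
    mean_term = np.sum((mean_x - mean_y) ** 2)
    cov_term = np.trace(cov_x) + np.trace(cov_y) - 2.0 * np.trace(cross)
    return float(mean_term + cov_term)


# Two 2-D Gaussians with diagonal covariances.
x = np.concatenate([[0.0, 0.0], np.diag([1.0, 4.0]).ravel()])
y = np.concatenate([[3.0, 4.0], np.diag([4.0, 9.0]).ravel()])
print(bures_cost(x, y, dimension=2))
```

With diagonal (commuting) covariances the covariance term reduces to a sum of (√a_i - √b_i)², so this example should print 25 + 2 = 27 up to floating-point round-off.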

Methods

all_pairs(x, y)

Compute the matrix of all pairwise costs, including the norms.

all_pairs_pairwise(x, y)

Compute the matrix of all pairwise costs, excluding the norms.

barycenter(weights, xs[, tolerance, sqrtm_kw])

Compute the Bures barycenter of weighted Gaussian distributions.

covariance_fixpoint_iter(covs, weights[, ...])

Iterate fixed-point updates to compute the barycenter of Gaussians.

norm(x)

Compute the squared norm of a Gaussian: the squared 2-norm of its mean plus the trace of its covariance.

pairwise(x, y)

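Compute -2 × the Bures dot-product, where the Bures dot-product of (m_x, Σ_x) and (m_y, Σ_y) is <m_x, m_y> + tr((Σ_x^{1/2} Σ_y Σ_x^{1/2})^{1/2}).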

twist_operator(vec, dual_vec, variable)

Compute the inverse twist operator of the cost function.
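The methods norm, pairwise, all_pairs and all_pairs_pairwise fit together as cost(x, y) = norm(x) + norm(y) + pairwise(x, y): the norms carry ||m||² + tr(Σ), and the pairwise term carries the cross terms. The NumPy sketch below mirrors that split on stacked Gaussians. The packing layout (mean first, then the row-major covariance) and every helper name (gaussian_norm, gaussian_pairwise, cost_matrices) are hypothetical; this illustrates the arithmetic, not OTT's vectorized JAX implementation.

```python
import numpy as np


def psd_sqrtm(mat):
    # Eigendecomposition-based PSD square root (stand-in for the library's sqrtm()).
    mat = 0.5 * (mat + mat.T)
    w, v = np.linalg.eigh(mat)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def unpack(x, dimension):
    # Assumed packing: mean first, then the row-major flattened covariance.
    return x[:dimension], x[dimension:].reshape(dimension, dimension)


def gaussian_norm(x, dimension):
    # Role of `norm`: squared 2-norm of the mean plus trace of the covariance.
    mean, cov = unpack(x, dimension)
    return float(mean @ mean + np.trace(cov))


def gaussian_pairwise(x, y, dimension):
    # Role of `pairwise`: -2 * (<m_x, m_y> + tr((S_x^1/2 S_y S_x^1/2)^1/2)).
    mean_x, cov_x = unpack(x, dimension)
    mean_y, cov_y = unpack(y, dimension)
    sq_x = psd_sqrtm(cov_x)
    dot = mean_x @ mean_y + np.trace(psd_sqrtm(sq_x @ cov_y @ sq_x))
    return -2.0 * float(dot)


def cost_matrices(xs, ys, dimension):
    """Return (costs including norms, costs excluding norms) for stacked Gaussians."""
    pairwise_only = np.array(
        [[gaussian_pairwise(x, y, dimension) for y in ys] for x in xs]
    )
    norms_x = np.array([gaussian_norm(x, dimension) for x in xs])
    norms_y = np.array([gaussian_norm(y, dimension) for y in ys])
    # Broadcast the norms over rows and columns to get the full cost matrix.
    full = norms_x[:, None] + norms_y[None, :] + pairwise_only
    return full, pairwise_only


rng = np.random.default_rng(0)
d = 3


def random_gaussian():
    # Random mean and a well-conditioned SPD covariance, packed into one vector.
    a = rng.normal(size=(d, d))
    cov = a @ a.T + 0.1 * np.eye(d)
    return np.concatenate([rng.normal(size=d), cov.ravel()])


xs = np.stack([random_gaussian() for _ in range(4)])
full, pairwise_only = cost_matrices(xs, xs, d)
print(np.round(np.diag(full), 8))  # self-costs vanish up to round-off
```

Keeping the norms separate is what allows a variant that excludes them, as all_pairs_pairwise does; the full matrix is then recovered by adding the broadcast norms.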
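The fixed-point iteration behind covariance_fixpoint_iter and barycenter can be sketched with a common scheme for Gaussian Wasserstein barycenters: the barycenter mean is the weighted mean of the input means, and the covariance is refined by S ← S^{-1/2} (Σ_i w_i (S^{1/2} Σ_i S^{1/2})^{1/2})² S^{-1/2} until it stops changing. Whether OTT uses exactly this update, initialization and stopping rule is an assumption; gaussian_barycenter, its max_iter argument and its default tolerance are hypothetical and are not the library's signature or defaults.

```python
import numpy as np


def psd_sqrtm(mat):
    # Eigendecomposition-based PSD square root (stand-in for the library's sqrtm()).
    mat = 0.5 * (mat + mat.T)
    w, v = np.linalg.eigh(mat)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def gaussian_barycenter(weights, means, covs, tolerance=1e-10, max_iter=500):
    """Weighted barycenter of Gaussians via a fixed-point iteration (hypothetical helper).

    weights: (n,) nonnegative and summing to 1; means: (n, d); covs: (n, d, d).
    """
    weights = np.asarray(weights, dtype=float)
    means = np.asarray(means, dtype=float)
    covs = np.asarray(covs, dtype=float)
    mean = weights @ means  # barycenter mean: weighted mean of the means
    cov = np.tensordot(weights, covs, axes=1)  # initial guess: weighted mean of covariances
    for _ in range(max_iter):
        sq = psd_sqrtm(cov)
        inv_sq = np.linalg.inv(sq)
        # Weighted average of (S^1/2 Sigma_i S^1/2)^1/2.
        avg = sum(w * psd_sqrtm(sq @ c @ sq) for w, c in zip(weights, covs))
        new_cov = inv_sq @ avg @ avg @ inv_sq
        new_cov = 0.5 * (new_cov + new_cov.T)  # keep the iterate symmetric
        if np.linalg.norm(new_cov - cov) < tolerance:
            return mean, new_cov
        cov = new_cov
    return mean, cov


means = np.array([[0.0, 0.0], [2.0, 4.0]])
covs = np.stack([np.diag([1.0, 4.0]), np.diag([9.0, 16.0])])
mean_bar, cov_bar = gaussian_barycenter([0.5, 0.5], means, covs)
print(mean_bar)
print(np.round(cov_bar, 6))
```

With diagonal (commuting) covariances the square root of the barycenter covariance is the weighted mean of the input square roots, so this example should return mean [1, 2] and covariance diag(4, 9) up to round-off.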