ott.geometry.costs.UnbalancedBures

class ott.geometry.costs.UnbalancedBures(dimension, gamma=1.0, sigma=1.0, **kwargs)

Regularized/unbalanced Bures distance between two triplets of (mass, mean, covariance).

This cost implements the value defined in Equations 37, 39, and 40 of https://arxiv.org/pdf/2006.02572.pdf, and follows the notation used there. Inputs are assumed to be triplets (mass, mean, covariance) raveled into vectors, in that order.
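The input layout described above can be sketched as follows. This is a minimal illustration, not part of the library: the helpers `pack_gaussian` and `unpack_gaussian` are hypothetical names, NumPy stands in for `jax.numpy`, and row-major raveling of the covariance is an assumption.

```python
import numpy as np


def pack_gaussian(mass, mean, cov):
    """Ravel (mass, mean, covariance) into one flat vector, in that order.

    Hypothetical helper; assumes the covariance is raveled row-major.
    """
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    return np.concatenate([np.atleast_1d(float(mass)), mean, cov.ravel()])


def unpack_gaussian(x, dimension):
    """Split a vector of length 1 + d + d*d back into (mass, mean, cov)."""
    mass = x[0]
    mean = x[1:1 + dimension]
    cov = x[1 + dimension:].reshape(dimension, dimension)
    return mass, mean, cov


d = 2
x = pack_gaussian(0.5, [1.0, -1.0], np.eye(d))
mass, mean, cov = unpack_gaussian(x, d)
```

For dimension d, each packed vector has 1 + d + d*d entries, so a batch of n Gaussians would be an array of shape (n, 1 + d + d*d) under this layout.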

Methods

all_pairs(x, y)

Compute matrix of all costs (including norms) for vectors in x / y.

all_pairs_pairwise(x, y)

Compute matrix of all pairwise-costs (no norms) for vectors in x / y.

barycenter(weights, xs)

Return type: float

norm(x)

Compute norm of Gaussian for unbalanced Bures.

pairwise(x, y)

Compute dot-product for unbalanced Bures.
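The split between `all_pairs` (including norms) and `all_pairs_pairwise` (no norms) can be illustrated with a toy cost. The sketch below assumes the full cost is `norm(x) + norm(y) + pairwise(x, y)`, which is one reading of the method summaries above and not confirmed by this page. It uses the squared Euclidean distance as a stand-in, not the unbalanced Bures formula, and every function is a hypothetical re-implementation in NumPy.

```python
import numpy as np


def norm(x):
    # Toy per-point term: squared Euclidean norm.
    return np.sum(x ** 2)


def pairwise(x, y):
    # Toy cross term: minus twice the dot product.
    return -2.0 * np.dot(x, y)


def all_pairs_pairwise(xs, ys):
    # Matrix of cross terms only (no norms).
    return np.array([[pairwise(x, y) for y in ys] for x in xs])


def all_pairs(xs, ys):
    # Assumed decomposition: norm(x) + norm(y) + pairwise(x, y).
    norms_x = np.array([norm(x) for x in xs])
    norms_y = np.array([norm(y) for y in ys])
    return norms_x[:, None] + norms_y[None, :] + all_pairs_pairwise(xs, ys)


xs = np.array([[0.0, 0.0], [1.0, 0.0]])
ys = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
full = all_pairs(xs, ys)            # squared distances for this toy cost
cross = all_pairs_pairwise(xs, ys)  # cross terms only
```

With this toy cost, `full[i, j]` equals the squared distance between `xs[i]` and `ys[j]`, while `cross` omits the per-point terms.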

Parameters