ott.geometry.costs.UnbalancedBures
class ott.geometry.costs.UnbalancedBures(dimension, gamma=1.0, sigma=1.0, **kwargs)
Regularized, unbalanced Bures distance between two triplets of (mass, mean, covariance).
This cost implements the value defined in Equations 37, 39 and 40 of https://arxiv.org/pdf/2006.02572.pdf and follows that paper's notation. Inputs are assumed to be triplets (mass, mean, covariance), each raveled into a single vector in that order.
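Because each input is a (mass, mean, covariance) triplet raveled into one vector, a point for a cost of dimension d should have length 1 + d + d·d. The sketch below shows one way to pack and unpack such vectors with jax.numpy. It assumes the covariance is flattened row-major (immaterial for a symmetric matrix); `ravel_triplet` and `unravel_triplet` are hypothetical helpers for illustration, not part of ott.

```python
import jax.numpy as jnp

def ravel_triplet(mass, mean, cov):
    """Hypothetical helper: pack (mass, mean, covariance) into one flat vector, in that order."""
    mass = jnp.atleast_1d(jnp.asarray(mass, dtype=jnp.float32))
    # Layout: [mass | mean (d entries) | covariance (d * d entries, row-major)]
    return jnp.concatenate([mass, jnp.ravel(mean), jnp.ravel(cov)])

def unravel_triplet(x, dimension):
    """Hypothetical helper: split a flat vector back into (mass, mean, covariance)."""
    mass = x[0]
    mean = x[1:1 + dimension]
    cov = x[1 + dimension:].reshape(dimension, dimension)
    return mass, mean, cov

d = 2
x = ravel_triplet(1.5, jnp.array([0.0, 1.0]), jnp.eye(d))
print(x.shape)  # (7,) since 1 + d + d * d = 7
```

Stacking several such vectors row by row gives the kind of x / y arrays the methods below operate on.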
Methods
all_pairs(x, y): Compute matrix of all costs (including norms) for vectors in x / y.
all_pairs_pairwise(x, y): Compute matrix of all pairwise-costs (no norms) for vectors in x / y.
barycenter(weights, xs)
norm(x): Compute norm of Gaussian for unbalanced Bures.
pairwise(x, y): Compute dot-product for unbalanced Bures.
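The methods above distinguish costs "including norms" (all_pairs) from pairwise costs with "no norms" (all_pairs_pairwise). The minimal sketch below illustrates that convention. It assumes the full cost decomposes as pairwise(x, y) + norm(x) + norm(y) and that the all-pairs matrices are built by nesting `jax.vmap`. The toy squared-Euclidean `norm` and `pairwise` functions are hypothetical stand-ins, not the unbalanced Bures formulas, and the code does not call ott.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy cost: squared Euclidean distance written as
# norm(x) + norm(y) + pairwise(x, y), with pairwise = -2 <x, y>.
def norm(x):
    return jnp.sum(x ** 2)

def pairwise(x, y):
    return -2.0 * jnp.vdot(x, y)

def cost(x, y):
    # Full cost: pairwise term plus both norms.
    return pairwise(x, y) + norm(x) + norm(y)

def all_pairs(x, y):
    # [n, m] matrix of full costs between rows of x and rows of y.
    return jax.vmap(lambda xi: jax.vmap(lambda yj: cost(xi, yj))(y))(x)

def all_pairs_pairwise(x, y):
    # [n, m] matrix of the pairwise term only (no norms).
    return jax.vmap(lambda xi: jax.vmap(lambda yj: pairwise(xi, yj))(y))(x)

x = jnp.array([[0.0, 0.0], [1.0, 0.0]])
y = jnp.array([[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]])
full = all_pairs(x, y)           # shape (2, 3)
dots = all_pairs_pairwise(x, y)  # shape (2, 3)
```

Separating the norms from the pairwise term lets the per-point norms be computed once per row of x and y rather than once per pair.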