ott.geometry.costs.UnbalancedBures

class ott.geometry.costs.UnbalancedBures(dimension, *, sigma=1.0, gamma=1.0, **kwargs)

Unbalanced Bures distance between two triplets of (mass, mean, cov).

This cost uses the notation defined in [Janati et al., 2020], eqs. 37, 39, and 40.

Parameters:
  • dimension (int) – Dimensionality of the data.

  • sigma (float) – Entropic regularization.

  • gamma (float) – KL-divergence regularization for the marginals.

  • kwargs (Any) – Keyword arguments for sqrtm().
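
Because the cost compares triplets of (mass, mean, cov), each point has to be passed as one flat vector. The sketch below shows one way to build such a vector with NumPy. It assumes the layout is the mass first, then the mean, then the raveled covariance, for a total length of 1 + dimension + dimension ** 2. That layout is an assumption and should be checked against the library source. The helpers pack_triplet and unpack_triplet are hypothetical names, not part of ott.

```python
import numpy as np


def pack_triplet(mass, mean, cov):
    """Flatten (mass, mean, cov) into one vector of length 1 + d + d**2.

    Assumed layout: [mass, mean_1..mean_d, cov raveled row by row].
    """
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    d = mean.shape[0]
    if cov.shape != (d, d):
        raise ValueError(f"expected cov of shape {(d, d)}, got {cov.shape}")
    return np.concatenate([[float(mass)], mean, cov.ravel()])


def unpack_triplet(x, dimension):
    """Inverse of pack_triplet for a single flat vector."""
    x = np.asarray(x, dtype=float)
    mass = x[0]
    mean = x[1:1 + dimension]
    cov = x[1 + dimension:].reshape(dimension, dimension)
    return mass, mean, cov


# One weighted 2-D Gaussian: mass 0.5, mean (1, -1), symmetric covariance.
x = pack_triplet(0.5, [1.0, -1.0], [[2.0, 0.3], [0.3, 1.0]])
mass, mean, cov = unpack_triplet(x, dimension=2)
```

With dimension=2, each packed vector has 7 entries. A batch of such vectors stacked along the first axis would then be the kind of input the methods below operate on.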

Methods

all_pairs(x, y)

Compute matrix of all pairwise costs, including the norms.

all_pairs_pairwise(x, y)

Compute matrix of all pairwise costs, excluding the norms.

barycenter(weights, xs)

Barycentric operator.

norm(x)

Compute norm of Gaussian for unbalanced Bures.

pairwise(x, y)

Compute dot-product for unbalanced Bures.

twist_operator(vec, dual_vec, variable)

Twist inverse operator of the cost function.
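
The method summaries for all_pairs and all_pairs_pairwise differ only in whether the norms are included. The sketch below illustrates how such a split can work, assuming the full cost is the sum of a per-point term norm(x) + norm(y) and a cross term pairwise(x, y). It uses squared Euclidean distance as a stand-in, with norm(x) = ||x||² and pairwise(x, y) = -2⟨x, y⟩. It does not implement the unbalanced Bures formulas, and the functions are plain NumPy re-creations, not the library's methods.

```python
import numpy as np


def norm(x):
    # Per-point term: squared Euclidean norm of each row, shape [n].
    return np.sum(x ** 2, axis=-1)


def pairwise(x, y):
    # Cross term for a single pair of points.
    return -2.0 * np.dot(x, y)


def all_pairs_pairwise(x, y):
    # Cross terms only (norms excluded), shape [n, m].
    return np.array([[pairwise(xi, yj) for yj in y] for xi in x])


def all_pairs(x, y):
    # Cross terms plus per-point terms, broadcast over rows and columns.
    return norm(x)[:, None] + norm(y)[None, :] + all_pairs_pairwise(x, y)


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2))
y = rng.normal(size=(4, 2))

full = all_pairs(x, y)
# Reference value: squared distances computed directly from differences.
direct = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
```

For this stand-in cost, full matches direct up to floating-point error. Splitting off the per-point terms means they are computed once per point rather than once per pair.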