ott.tools.gaussian_mixture.gaussian_mixture_pair.GaussianMixturePair
- class ott.tools.gaussian_mixture.gaussian_mixture_pair.GaussianMixturePair(gmm0, gmm1, epsilon=0.01, tau=1.0, lock_gmm1=False)
Coupled pair of Gaussian mixture models.
Includes methods used in estimating an optimal pairing between GMM components using the Wasserstein-like method described in [Delon and Desolneux, 2020], as well as a generalization that allows for the reweighting of components.
[Delon and Desolneux, 2020] propose fitting a pair of GMMs to a pair of point clouds in such a way that the sum of the log likelihoods of the points, minus a weighted penalty involving a Wasserstein-like distance between the GMMs, is maximized. Their proposed algorithm uses EM, with a balanced Sinkhorn algorithm estimating a coupling between the GMMs at each EM step.
Our generalization of this algorithm allows for a mismatch between the marginals of the coupling and the GMM component weights. This mismatch can be interpreted as components being reweighted rather than transported. We penalize reweighting with a generalized KL-divergence penalty, and we give the option to use the unbalanced Sinkhorn algorithm rather than the balanced one to compute the divergence between GMMs.
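The reweighting penalty can be illustrated with a small NumPy sketch of the generalized KL divergence, which, unlike the ordinary KL divergence, stays well defined when its arguments do not sum to one. The function name `generalized_kl`, the example numbers, and the choice to compare a coupling's row marginal against the component weights (in that argument order) are assumptions made for illustration; the class itself may compute and weight the penalty differently.

```python
import numpy as np


def generalized_kl(p, q, eps=1e-12):
    """Generalized KL divergence between nonnegative vectors p and q.

    KL(p || q) = sum_i p_i * log(p_i / q_i) - sum_i p_i + sum_i q_i.
    The two mass terms make the value nonnegative even when p and q
    have different total mass.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Entries with p_i == 0 contribute 0 (convention 0 * log 0 = 0).
    safe_p = np.where(p > 0, p, 1.0)
    log_term = np.where(p > 0, p * np.log(safe_p / np.maximum(q, eps)), 0.0)
    return float(np.sum(log_term) - np.sum(p) + np.sum(q))


# Hypothetical data: weights of a 3-component GMM and a coupling whose
# row sums do not match those weights exactly.
weights = np.array([0.5, 0.3, 0.2])
coupling = np.array([
    [0.40, 0.05],
    [0.10, 0.25],
    [0.05, 0.10],
])
row_marginal = coupling.sum(axis=1)

# Positive when the marginal differs from the weights, zero when they agree.
penalty = generalized_kl(row_marginal, weights)
```

A coupling whose marginals match the component weights exactly incurs no penalty, which corresponds to the balanced case in which mass is only transported.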
- Parameters:
gmm0 (GaussianMixture)
gmm1 (GaussianMixture)
epsilon (float)
tau (float)
lock_gmm1 (bool)
Methods
Get a Bures Geometry for the two GMMs.
Get matrix of \(W_2^2\) costs between all pairs of components.
get_normalized_sinkhorn_coupling(sinkhorn_output): Get the normalized coupling matrix for the specified Sinkhorn output.
get_sinkhorn(cost_matrix, **kwargs): Get the output of Sinkhorn's method for a given cost matrix.
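For intuition about the matrix of \(W_2^2\) costs between all pairs of components listed above, the squared 2-Wasserstein distance between two Gaussians has the closed form \(\|m_0 - m_1\|^2 + \operatorname{tr}\bigl(\Sigma_0 + \Sigma_1 - 2(\Sigma_0^{1/2}\Sigma_1\Sigma_0^{1/2})^{1/2}\bigr)\). The NumPy sketch below evaluates it for every pair of components. The helper names `sqrtm_psd`, `w2_squared`, and `pairwise_w2_squared` are hypothetical and not part of the library, whose own implementation is JAX-based and may differ.

```python
import numpy as np


def sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(a)
    # Clip tiny negative eigenvalues caused by floating-point error.
    vals = np.clip(vals, 0.0, None)
    return (vecs * np.sqrt(vals)) @ vecs.T


def w2_squared(mean0, cov0, mean1, cov1):
    """Closed-form squared 2-Wasserstein distance between two Gaussians."""
    root0 = sqrtm_psd(cov0)
    cross = sqrtm_psd(root0 @ cov1 @ root0)
    mean_term = np.sum((mean0 - mean1) ** 2)
    cov_term = np.trace(cov0 + cov1 - 2.0 * cross)
    return float(mean_term + cov_term)


def pairwise_w2_squared(means0, covs0, means1, covs1):
    """Matrix whose (i, j) entry is W_2^2 between component i and component j."""
    return np.array([
        [w2_squared(m0, c0, m1, c1) for m1, c1 in zip(means1, covs1)]
        for m0, c0 in zip(means0, covs0)
    ])


# Hypothetical components: one component in the first mixture, two in the second.
means0 = np.array([[0.0, 0.0]])
covs0 = np.array([np.eye(2)])
means1 = np.array([[0.0, 0.0], [3.0, 4.0]])
covs1 = np.array([np.eye(2), 4.0 * np.eye(2)])

costs = pairwise_w2_squared(means0, covs0, means1, covs1)  # shape (1, 2)
```

A matrix of this shape, with one row per component of the first mixture and one column per component of the second, is the kind of input a Sinkhorn solver takes to estimate a coupling between components.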
Attributes