ott.tools.gaussian_mixture.gaussian_mixture_pair.GaussianMixturePair


class ott.tools.gaussian_mixture.gaussian_mixture_pair.GaussianMixturePair(gmm0, gmm1, epsilon=0.01, tau=1.0, lock_gmm1=False)

Coupled pair of Gaussian mixture models.

Includes methods for estimating an optimal pairing between GMM components using the Wasserstein-like method described in [Delon and Desolneux, 2020], as well as a generalization that allows components to be reweighted.

[Delon and Desolneux, 2020] propose fitting a pair of GMMs to a pair of point clouds so as to maximize the sum of the log likelihoods of the points minus a weighted penalty involving a Wasserstein-like distance between the GMMs. Their algorithm uses EM and, at each EM step, runs a balanced Sinkhorn algorithm to estimate a coupling between the GMMs.
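One way to write the criterion just described is shown below. The notation is ours, not the library's: \(p_{\theta_0}\) and \(p_{\theta_1}\) are the two mixture densities, \(x_i\) and \(y_j\) the two point clouds, \(w^0\) and \(w^1\) the component weights, \(\mathcal{N}^0_k\) and \(\mathcal{N}^1_l\) the Gaussian components, and \(\lambda > 0\) a hypothetical penalty weight. The exact weighting used in [Delon and Desolneux, 2020] may differ.

\[
\max_{\theta_0, \theta_1} \; \sum_i \log p_{\theta_0}(x_i) + \sum_j \log p_{\theta_1}(y_j) - \lambda \, \mathrm{MW}_2^2(\theta_0, \theta_1),
\qquad
\mathrm{MW}_2^2(\theta_0, \theta_1) = \min_{\pi \in \Pi(w^0, w^1)} \sum_{k, l} \pi_{kl} \, W_2^2\big(\mathcal{N}^0_k, \mathcal{N}^1_l\big)
\]

Here \(\Pi(w^0, w^1)\) is the set of nonnegative matrices whose row sums equal \(w^0\) and whose column sums equal \(w^1\). The balanced Sinkhorn step inside each EM iteration solves an entropy-regularized version of this inner minimization, with regularization strength epsilon.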

Our generalization of this algorithm allows a mismatch between the marginals of the coupling and the GMM component weights. This mismatch can be interpreted as components being reweighted rather than transported. We penalize reweighting with a generalized KL-divergence penalty, and we offer the option of using the unbalanced Sinkhorn algorithm, rather than the balanced one, to compute the divergence between GMMs.
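The sketch below makes the reweighting penalty concrete. It implements the generalized KL divergence \(\mathrm{KL}(p \| q) = \sum_i p_i \log(p_i / q_i) - p_i + q_i\) and the standard scaling iterations for entropic OT with KL-relaxed marginals, where each Sinkhorn update is raised to the power tau. This is plain Python on lists, not the OTT implementation. The function names are hypothetical, and we assume the convention tau = rho / (rho + epsilon), with tau = 1.0 recovering balanced Sinkhorn.

```python
import math


def generalized_kl(p, q):
    """Generalized KL: sum_i p_i * log(p_i / q_i) - p_i + q_i.

    Unlike the usual KL divergence, p and q may have different total mass.
    """
    total = 0.0
    for p_i, q_i in zip(p, q):
        if p_i > 0.0:
            total += p_i * math.log(p_i / q_i)
        total += q_i - p_i
    return total


def unbalanced_sinkhorn(cost, a, b, epsilon, tau, n_iter=200):
    """Entropic OT with KL-relaxed marginals via scaling iterations.

    Assumes tau = rho / (rho + epsilon); tau = 1.0 gives balanced Sinkhorn.
    Returns the coupling P[i][j] = u[i] * K[i][j] * v[j].
    """
    n, m = len(a), len(b)
    kernel = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        # Row scaling u = (a / K v) ** tau; tau < 1 only partially enforces a.
        kv = [sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(a[i] / kv[i]) ** tau for i in range(n)]
        # Column scaling v = (b / K^T u) ** tau.
        ktu = [sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(b[j] / ktu[j]) ** tau for j in range(m)]
    return [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]


cost = [[0.0, 1.0], [1.0, 0.0]]
a, b = [0.8, 0.2], [0.2, 0.8]

balanced = unbalanced_sinkhorn(cost, a, b, epsilon=1.0, tau=1.0)
unbalanced = unbalanced_sinkhorn(cost, a, b, epsilon=1.0, tau=0.5)

# In the unbalanced case the row marginal drifts away from a; that drift
# is what the generalized KL term penalizes.
row_marginal = [sum(row) for row in unbalanced]
reweighting_penalty = generalized_kl(row_marginal, a)
```

With tau = 1.0 the coupling's row and column sums match a and b. With tau = 0.5 they do not, and the generalized KL between the coupling's row sums and a measures how much mass was reweighted instead of transported.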

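Parameters:

gmm0 (GaussianMixture): first GMM of the pair.

gmm1 (GaussianMixture): second GMM of the pair.

epsilon (float): entropic regularization strength used by the Sinkhorn algorithm (default 0.01).

tau (float): controls how strongly reweighting of components is penalized; tau = 1.0 (the default) corresponds to the balanced case with no reweighting.

lock_gmm1 (bool): whether gmm1 is held fixed (default False).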

Methods

get_bures_geometry()

Get a Bures Geometry for the two GMMs.

get_cost_matrix()

Get matrix of \(W_2^2\) costs between all pairs of components.

get_normalized_sinkhorn_coupling(sinkhorn_output)

Get the normalized coupling matrix for the specified Sinkhorn output.

get_sinkhorn(cost_matrix, **kwargs)

Get the output of Sinkhorn's method for a given cost matrix.
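The three methods above fit together as a pipeline: build a cost matrix, run Sinkhorn on it, then normalize the resulting coupling. For full covariances, the \(W_2^2\) cost between \(\mathcal{N}(m_0, S_0)\) and \(\mathcal{N}(m_1, S_1)\) is the Bures-Wasserstein formula \(\|m_0 - m_1\|^2 + \operatorname{tr}\big(S_0 + S_1 - 2 (S_0^{1/2} S_1 S_0^{1/2})^{1/2}\big)\). The pure-Python sketch below assumes diagonal covariances, where the trace term reduces to a sum of squared differences of standard deviations. The helper names (w2_squared_diag, cost_matrix, sinkhorn_coupling, normalize_coupling) are hypothetical and are not the OTT API. We also assume "normalized" means rescaled to unit total mass, which the library may define differently.

```python
import math


def w2_squared_diag(mean0, var0, mean1, var1):
    """Closed-form W_2^2 between Gaussians with diagonal covariances.

    var0 and var1 hold per-dimension variances (the covariance diagonals).
    """
    mean_term = sum((x - y) ** 2 for x, y in zip(mean0, mean1))
    cov_term = sum((math.sqrt(s) - math.sqrt(t)) ** 2 for s, t in zip(var0, var1))
    return mean_term + cov_term


def cost_matrix(components0, components1):
    """Matrix of W_2^2 costs between all pairs of (mean, var) components."""
    return [[w2_squared_diag(m0, v0, m1, v1) for m1, v1 in components1]
            for m0, v0 in components0]


def sinkhorn_coupling(cost, a, b, epsilon, tau=1.0, n_iter=200):
    """Scaling iterations; tau < 1 relaxes the marginals (unbalanced case)."""
    n, m = len(a), len(b)
    kernel = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        kv = [sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(a[i] / kv[i]) ** tau for i in range(n)]
        ktu = [sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(b[j] / ktu[j]) ** tau for j in range(m)]
    return [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]


def normalize_coupling(coupling):
    """Rescale a coupling so its entries sum to 1 (assumed meaning of 'normalized')."""
    total = sum(sum(row) for row in coupling)
    return [[p / total for p in row] for row in coupling]


# Two 1-D mixtures, each a list of (mean, variance-diagonal) components.
gmm0 = [([0.0], [1.0]), ([5.0], [1.0])]
gmm1 = [([0.5], [1.0]), ([5.0], [4.0])]
weights0, weights1 = [0.5, 0.5], [0.5, 0.5]

costs = cost_matrix(gmm0, gmm1)  # [[0.25, 26.0], [20.25, 1.0]]
coupling = sinkhorn_coupling(costs, weights0, weights1, epsilon=1.0, tau=0.5)
normalized = normalize_coupling(coupling)
```

Because tau < 1 here, the raw coupling's total mass is not exactly 1, so the normalization step actually changes the entries. Nearly all of the normalized mass sits on the diagonal, pairing each component with its nearby counterpart.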

Attributes

dtype

epsilon

gmm0

gmm1

lock_gmm1

rho

tau
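The attribute list includes rho alongside epsilon and tau. One plausible reading, assumed here and borrowed from the convention OTT's Sinkhorn solvers use for their unbalanced tau parameters, is tau = rho / (rho + epsilon), where rho weights the generalized KL reweighting penalty. Under that assumption rho is determined by the other two attributes, as in this sketch. The helper names are hypothetical.

```python
import math


def rho_from_tau(epsilon, tau):
    """Reweighting penalty rho implied by tau, assuming tau = rho / (rho + epsilon)."""
    if not 0.0 < tau <= 1.0:
        raise ValueError("tau must lie in (0, 1]")
    if tau == 1.0:
        # Infinite penalty: marginals are enforced exactly (balanced case).
        return math.inf
    return epsilon * tau / (1.0 - tau)


def tau_from_rho(epsilon, rho):
    """Inverse map; an infinite rho gives tau = 1.0."""
    if math.isinf(rho):
        return 1.0
    return rho / (rho + epsilon)


# Constructor defaults epsilon=0.01, tau=1.0 correspond to the balanced case.
default_rho = rho_from_tau(0.01, 1.0)
relaxed_rho = rho_from_tau(0.01, 0.5)
```

Under this convention, tau = 0.5 makes the reweighting penalty exactly as strong as the entropic regularization (rho = epsilon). Smaller tau makes reweighting cheaper relative to transport.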