# Fitting Pairs of Coupled GMMs

Several papers have recently proposed a Wasserstein-like distance measure between Gaussian mixture models. The idea is that:

1. there is an analytic solution for the Wasserstein distance between two Gaussians, and

2. if one limits the set of allowed couplings between GMMs to the space of Gaussian mixtures, one can define a Wasserstein-like distance between a pair of GMMs in terms of the Wasserstein distance between their components.
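For reference, consider two Gaussians $$\mathcal{N}(m_0, \Sigma_0)$$ and $$\mathcal{N}(m_1, \Sigma_1)$$. The analytic solution in point 1 is the standard closed form below. Its trace term is the squared Bures distance between the two covariance matrices, which is where the name "Bures cost" used later comes from.

$W_2^2\left(\mathcal{N}(m_0, \Sigma_0), \mathcal{N}(m_1, \Sigma_1)\right) = \|m_0 - m_1\|^2 + \mathrm{tr}\left(\Sigma_0 + \Sigma_1 - 2\left(\Sigma_0^{1/2} \Sigma_1 \Sigma_0^{1/2}\right)^{1/2}\right)$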

In , the distance $$MW_2$$ between two GMMs, $$\mu_0$$ and $$\mu_1$$, is defined as follows:

$MW_2^2(\mu_0, \mu_1) = \inf_{\gamma\in \Pi(\mu_0, \mu_1) \cap GMM_{2d}(\infty)} \int_{\mathbb{R}^d\times \mathbb{R}^d} \|y_0-y_1\|^2 d\gamma(y_0, y_1)$

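where $$\Pi(\mu_0, \mu_1)$$ is the set of probability measures on $$(\mathbb{R}^d)^2$$ having $$\mu_0$$ and $$\mu_1$$ as marginals, and $$GMM_d(K)$$ is the set of Gaussian mixtures in $$\mathbb{R}^d$$ with fewer than $$K$$ components (see (4.1)).

Writing $$\pi_0$$ and $$\pi_1$$ for the component weights of $$\mu_0$$ and $$\mu_1$$, this infimum reduces to a discrete optimal transport problem between the mixture components: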

$MW_2^2(\mu_0, \mu_1) = \min_{w \in \Pi(\pi_0, \pi_1)} \sum_{k,l} w_{kl} W_2^2(\mu_0^k, \mu_1^l)$

where $$\Pi(\pi_0, \pi_1)$$ is the subset of the simplex $$\Gamma_{K_0, K_1}$$ with marginals $$\pi_0$$ and $$\pi_1$$, and $$W^2_2(\mu_0^k, \mu_1^l)$$ is the squared 2-Wasserstein distance between component $$k$$ of $$\mu_0$$ and component $$l$$ of $$\mu_1$$ (see 4.4).
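The discrete problem only needs the $$K_0 \times K_1$$ matrix of closed-form $$W_2^2(\mu_0^k, \mu_1^l)$$ values. Below is a minimal NumPy sketch of that cost matrix. It assumes each mixture is given as an array of means and an array of full covariance matrices. The helper names are illustrative and are not part of OTT.

```python
import numpy as np


def sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = np.linalg.eigh(m)
    eigvals = np.clip(eigvals, 0.0, None)  # guard against tiny negative round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def w2_squared_gaussians(mean0, cov0, mean1, cov1):
    """Closed-form squared 2-Wasserstein distance between two Gaussians."""
    root0 = sqrtm_psd(cov0)
    cross = sqrtm_psd(root0 @ cov1 @ root0)
    bures_sq = np.trace(cov0 + cov1 - 2.0 * cross)  # squared Bures term
    return float(np.sum((mean0 - mean1) ** 2) + bures_sq)


def component_cost_matrix(means0, covs0, means1, covs1):
    """K0 x K1 matrix whose (k, l) entry is W2^2 between components k and l."""
    return np.array(
        [
            [w2_squared_gaussians(m0, c0, m1, c1) for m1, c1 in zip(means1, covs1)]
            for m0, c0 in zip(means0, covs0)
        ]
    )


# Two small 2-component mixtures in R^2 (illustrative values only).
means0 = np.array([[0.0, 0.0], [3.0, 0.0]])
covs0 = np.array([np.eye(2), 0.5 * np.eye(2)])
means1 = np.array([[0.0, 1.0], [3.0, -1.0]])
covs1 = np.array([np.eye(2), 2.0 * np.eye(2)])

cost = component_cost_matrix(means0, covs0, means1, covs1)
print(cost)
```

For isotropic covariances $$\sigma_0^2 I$$ and $$\sigma_1^2 I$$ in $$\mathbb{R}^d$$, the Bures term simplifies to $$d(\sigma_0 - \sigma_1)^2$$. That makes the entries easy to check by hand.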

We can obtain a regularized solution to this minimization problem by applying the Sinkhorn algorithm with the Bures cost function.
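To make the regularized problem concrete, here is a minimal log-domain Sinkhorn sketch in plain NumPy. It takes a precomputed matrix of component costs, such as the $$W_2^2(\mu_0^k, \mu_1^l)$$ values defined above, together with the component weights. This is an illustration, not OTT's implementation. It regularizes with $$\epsilon KL(w||\pi_0 \pi_1^T)$$ and runs a fixed number of iterations without a convergence check. The function names are hypothetical.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def sinkhorn_balanced(cost, a, b, epsilon, n_iters=500):
    """Entropic OT between weight vectors a and b (log-domain sketch).

    Approximately solves min_w <cost, w> + epsilon * KL(w || a b^T)
    over couplings w with marginals a and b.
    """
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iters):
        # f-update: makes the row marginal of the coupling equal to a
        f = -epsilon * logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        # g-update: makes the column marginal of the coupling equal to b
        g = -epsilon * logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
    # w_kl = a_k b_l exp((f_k + g_l - cost_kl) / epsilon)
    log_w = log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon
    return np.exp(log_w)


# Two source components and three target components (illustrative values).
cost = np.array([[0.0, 1.0, 0.5], [1.0, 0.0, 0.5]])
a = np.array([0.5, 0.5])
b = np.array([0.25, 0.25, 0.5])
w = sinkhorn_balanced(cost, a, b, epsilon=0.1)
print(np.round(w, 3))
print("row sums:", w.sum(axis=1), "column sums:", w.sum(axis=0))
```

Because the problem size is $$K_0 \times K_1$$ rather than the number of data points, each Sinkhorn solve is cheap even for large point clouds.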

suggest an application of $$MW_2$$: we can approximate an optimal transport map between two point clouds by simultaneously fitting a GaussianMixture model to each PointCloud and minimizing the $$MW_2$$ distance between the fitted GMMs (see section 6). The approach scales well to large point clouds since the Sinkhorn algorithm is applied only to the mixture components rather than to individual points. The resulting couplings are easy to interpret since they involve relatively small numbers of components, and the transport maps are mixtures of piecewise linear maps.
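The linear pieces of such a map come from the fact that the optimal transport map between two Gaussians is affine. The sketch below implements the standard closed form $$T(x) = m_1 + A(x - m_0)$$, with $$A = \Sigma_0^{-1/2}(\Sigma_0^{1/2}\Sigma_1\Sigma_0^{1/2})^{1/2}\Sigma_0^{-1/2}$$, for a single pair of components. It does not show how OTT combines these per-component maps across a fitted coupling. The function names are hypothetical.

```python
import numpy as np


def sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = np.linalg.eigh(m)
    return (eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))) @ eigvecs.T


def gaussian_ot_map(mean0, cov0, mean1, cov1):
    """Return (transport, A) for the affine OT map N(mean0, cov0) -> N(mean1, cov1)."""
    root0 = sqrtm_psd(cov0)
    inv_root0 = np.linalg.inv(root0)
    a_mat = inv_root0 @ sqrtm_psd(root0 @ cov1 @ root0) @ inv_root0

    def transport(x):
        # x has shape (n, d); apply x -> mean1 + A (x - mean0) to each row
        return mean1 + (x - mean0) @ a_mat.T

    return transport, a_mat


mean0 = np.array([0.0, 0.0])
cov0 = np.array([[2.0, 0.5], [0.5, 1.0]])
mean1 = np.array([3.0, -1.0])
cov1 = np.array([[1.0, -0.3], [-0.3, 0.5]])
transport, a_mat = gaussian_ot_map(mean0, cov0, mean1, cov1)

# Push samples from N(mean0, cov0) through the map; their empirical mean and
# covariance should be close to mean1 and cov1.
rng = np.random.default_rng(0)
x = rng.multivariate_normal(mean0, cov0, size=100_000)
y = transport(x)
print(y.mean(axis=0))
print(np.cov(y, rowvar=False))
```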

Here we demonstrate the approach on some synthetic data.

import sys

# Install OTT when running on Google Colab
if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

from ott.tools.gaussian_mixture import (
    fit_gmm,
    fit_gmm_pair,
    gaussian_mixture,
    gaussian_mixture_pair,
    probabilities,
)

def get_cov_ellipse(mean, cov, n_sds=2, **kwargs):
    """Get a matplotlib Ellipse patch for a given mean and covariance.

    The ellipse extends `n_sds` standard deviations on each side of the mean
    along each principal axis of `cov`.
    """
    # Find and sort eigenvalues and eigenvectors into descending order
    eigvals, eigvecs = jnp.linalg.eigh(cov)
    order = eigvals.argsort()[::-1]
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]

    # The anti-clockwise angle to rotate our ellipse by
    vx, vy = eigvecs[:, 0][0], eigvecs[:, 0][1]
    theta = np.arctan2(vy, vx)

    # Width and height of ellipse to draw
    width, height = 2 * n_sds * np.sqrt(eigvals)
    return matplotlib.patches.Ellipse(
        xy=mean, width=width, height=height, angle=np.degrees(theta), **kwargs
    )

key = jax.random.PRNGKey(0)


## Generate synthetic data

Construct two GaussianMixture models that we’ll use to generate some samples.

The two GMMs have small differences in their means, covariances, and in their weights.

mean_generator0 = jnp.array([[2.0, -1.0], [-2.0, 0.0], [4.0, 3.0]])
cov_generator0 = 3.0 * jnp.array(
    [
        [[0.2, 0.0], [0.0, 0.1]],
        [[0.6, 0.0], [0.0, 0.3]],
        [[0.5, -0.4], [-0.4, 0.5]],
    ]
)
weights_generator0 = jnp.array([0.2, 0.2, 0.6])

gmm_generator0 = (
    gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
        mean=mean_generator0,
        cov=cov_generator0,
        component_weights=weights_generator0,
    )
)

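def rot(m, theta):
    # Left-multiply m by the rotation matrix for an angle of theta degrees
    theta_rad = theta * 2.0 * np.pi / 360.0
    m_rot = jnp.array(
        [
            [jnp.cos(theta_rad), -jnp.sin(theta_rad)],
            [jnp.sin(theta_rad), jnp.cos(theta_rad)],
        ]
    )
    return jnp.matmul(m_rot, m)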

# shift the means to the right by varying amounts
mean_generator1 = mean_generator0 + jnp.array(
    [[1.0, -0.5], [-1.0, -1.0], [-1.0, 0.0]]
)
# rotate the covariances a bit
cov_generator1 = jnp.stack(
    [
        rot(cov_generator0[0, :], 5),
        rot(cov_generator0[1, :], -5),
        rot(cov_generator0[2, :], -10),
    ],
    axis=0,
)
weights_generator1 = jnp.array([0.4, 0.4, 0.2])

gmm_generator1 = (
    gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
        mean=mean_generator1,
        cov=cov_generator1,
        component_weights=weights_generator1,
    )
)

N = 10_000
key, subkey0, subkey1 = jax.random.split(key, num=3)
samples_gmm0 = gmm_generator0.sample(key=subkey0, size=N)
samples_gmm1 = gmm_generator1.sample(key=subkey1, size=N)

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(12, 6))
axes[0].scatter(samples_gmm0[:, 0], samples_gmm0[:, 1], marker=".", alpha=0.25)
axes[0].set_title("Samples from generating GMM 0")
axes[1].scatter(samples_gmm1[:, 0], samples_gmm1[:, 1], marker=".", alpha=0.25)
axes[1].set_title("Samples from generating GMM 1")
plt.show()


## Fit a pair of coupled GMMs

# As a starting point for our optimization, we pool the two sets of samples
# and fit a single GMM to the combined samples
samples = jnp.concatenate([samples_gmm0, samples_gmm1])
key, subkey = jax.random.split(key)
gmm_init = fit_gmm.initialize(
    key=subkey, points=samples, point_weights=None, n_components=3, verbose=True
)
pooled_gmm = fit_gmm.fit_model_em(
    gmm=gmm_init, points=samples, point_weights=None, steps=20
)

%%time
# Now we use EM to fit a GMM to each set of samples while penalizing the
# distance between the pair of GMMs
EPSILON = 1.0e-2  # regularization weight for the Sinkhorn algorithm
WEIGHT_TRANSPORT = 0.01  # weight for the MW2 distance penalty between the GMMs
# tau=1.0 corresponds to the balanced Sinkhorn algorithm
pair_init = gaussian_mixture_pair.GaussianMixturePair(
    gmm0=pooled_gmm, gmm1=pooled_gmm, epsilon=EPSILON, tau=1.0
)

fit_model_em_fn = fit_gmm_pair.get_fit_model_em_fn(
    weight_transport=WEIGHT_TRANSPORT, jit=True
)

pair, loss = fit_model_em_fn(
    pair=pair_init,
    points0=samples_gmm0,
    points1=samples_gmm1,
    point_weights0=None,
    point_weights1=None,
    em_steps=30,
    m_steps=20,
    verbose=True,
)

  0 -3.862 -3.877 transport:0.011 objective:-7.739
1 -3.835 -3.848 transport:0.041 objective:-7.683
2 -3.814 -3.822 transport:0.129 objective:-7.637
3 -3.794 -3.802 transport:0.262 objective:-7.598
4 -3.775 -3.781 transport:0.448 objective:-7.561
5 -3.761 -3.763 transport:0.676 objective:-7.531
6 -3.748 -3.746 transport:0.954 objective:-7.504
7 -3.737 -3.731 transport:1.259 objective:-7.481
8 -3.725 -3.717 transport:1.628 objective:-7.458
9 -3.716 -3.704 transport:2.050 objective:-7.440
10 -3.704 -3.694 transport:2.502 objective:-7.423
11 -3.694 -3.683 transport:2.773 objective:-7.405
12 -3.686 -3.675 transport:3.313 objective:-7.395
13 -3.678 -3.667 transport:3.388 objective:-7.378
14 -3.671 -3.661 transport:3.730 objective:-7.369
15 -3.664 -3.654 transport:3.836 objective:-7.357
16 -3.659 -3.652 transport:3.993 objective:-7.351
17 -3.652 -3.646 transport:4.349 objective:-7.341
18 -3.648 -3.644 transport:4.494 objective:-7.337
19 -3.643 -3.640 transport:4.606 objective:-7.329
20 -3.638 -3.638 transport:5.005 objective:-7.326
21 -3.635 -3.635 transport:5.146 objective:-7.321
22 -3.629 -3.633 transport:5.559 objective:-7.318
23 -3.626 -3.630 transport:5.715 objective:-7.314
24 -3.622 -3.628 transport:6.157 objective:-7.311
25 -3.621 -3.625 transport:6.335 objective:-7.309
26 -3.617 -3.624 transport:6.571 objective:-7.307
27 -3.616 -3.620 transport:6.952 objective:-7.306
28 -3.613 -3.620 transport:6.986 objective:-7.302
29 -3.612 -3.617 transport:7.067 objective:-7.300
CPU times: user 2min 4s, sys: 3.74 s, total: 2min 8s
Wall time: 2min 6s

colors = ["red", "green", "blue"]
fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True)
for i, (gmm, samples) in enumerate(
    [(pair.gmm0, samples_gmm0), (pair.gmm1, samples_gmm1)]
):
    # Assign each sample to the component with the highest posterior probability
    log_assignment_prob = gmm.get_log_component_posterior(samples)
    assignment = jnp.argmax(log_assignment_prob, axis=-1)
    for j, component in enumerate(gmm.components()):
        subset = assignment == j
        axes[i].scatter(
            samples[subset, 0],
            samples[subset, 1],
            marker=".",
            alpha=0.01,
            color=colors[j],
            label=j,
        )
        # Outline the fitted component with a 2-standard-deviation ellipse
        ellipse = get_cov_ellipse(
            component.loc,
            component.covariance(),
            n_sds=2,
            ec=colors[j],
            fill=False,
            lw=2,
        )
        axes[i].add_artist(ellipse)
    legend = axes[i].legend()
    # Make legend markers opaque (`legend_handles` requires matplotlib >= 3.7)
    for lh in legend.legend_handles:
        lh.set_alpha(1)
    axes[i].set_title(f"Fitted GMM {i} and samples")
plt.show()

print("Fitted GMM 0 masses", pair.gmm0.component_weights)
print("Fitted GMM 1 masses", pair.gmm1.component_weights)
print("Mass transfer, rows=source, columns=destination")
cost_matrix = pair.get_cost_matrix()
sinkhorn_output = pair.get_sinkhorn(cost_matrix=cost_matrix)
print(pair.get_normalized_sinkhorn_coupling(sinkhorn_output=sinkhorn_output))

Fitted GMM 0 masses [0.54755753 0.2253582  0.22708425]
Fitted GMM 1 masses [0.26321912 0.374353   0.36242783]
Mass transfer, rows=source, columns=destination
[[0.32590222 0.         0.22165503]
[0.         0.22536384 0.        ]
[0.         0.         0.22707897]]


## Reweighting components

In the approach above, we can only change the weights of components by transferring mass between them. In some settings, allowing reweighting of components can lead to couplings that are easier to interpret. For example, in a biological application in which points correspond to a population of featurized representations of organisms, mixture components might capture subpopulations and a component reweighting might correspond to a prevalence change for the subpopulation.

We can generalize the approach above to allow component reweighting by using an unbalanced variant of $$MW_2$$ as our measure of distance between GMMs.

Recall that

$MW_2^2(\mu_0, \mu_1) = \min_{w \in \Pi(\pi_0, \pi_1)} \sum_{k,l} w_{kl} W_2^2(\mu_0^k, \mu_1^l)$

We use the Sinkhorn algorithm to obtain a solution to a regularized version of the above minimization:

$MW_2^2(\mu_0, \mu_1) \approx \min_{w \in \Pi(\pi_0, \pi_1)} \sum_{k,l} w_{kl} W_2^2(\mu_0^k, \mu_1^l) + \epsilon KL(w||\pi_0 \pi_1^T)$

## An unbalanced Wasserstein divergence for GMMs

We define $$UW_2^2$$, an unbalanced version of $$MW_2^2$$, as follows:

$UW_2^2(\mu_0, \mu_1) = \min_{w_{k,l} \geq 0} \sum_{k,l} w_{kl} W_2^2(\mu_0^k, \mu_1^l) + \rho KL(w_{k \cdot}||\pi_0^k) + \rho KL(w_{\cdot l}||\pi_1^l)$

where $$KL(f||g)$$ is the generalized KL divergence,

$KL(f||g) = \sum_i f_i \log \frac{f_i}{g_i} - f_i + g_i$

which does not assume that either $$\sum f_i = 1$$ or $$\sum g_i = 1$$.
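As a quick illustration, the generalized KL divergence can be written in a few lines of NumPy. The sketch uses the convention $$0 \log 0 = 0$$, so a zero entry of $$f$$ contributes $$g_i$$. When both vectors sum to $$1$$, the $$-f_i + g_i$$ terms cancel and the usual KL divergence is recovered. The function name is illustrative.

```python
import numpy as np


def generalized_kl(f, g):
    """sum_i f_i log(f_i / g_i) - f_i + g_i for nonnegative f and positive g."""
    f = np.asarray(f, dtype=float)
    g = np.asarray(g, dtype=float)
    safe_f = np.where(f > 0, f, 1.0)  # avoid log(0); those terms are masked out
    log_term = np.where(f > 0, f * np.log(safe_f / g), 0.0)
    return float(np.sum(log_term - f + g))


# Both arguments normalized: matches the standard KL divergence.
print(generalized_kl([0.4, 0.6], [0.2, 0.8]))
# f has total mass 2: the divergence picks up a mass-mismatch penalty.
print(generalized_kl([0.4, 1.6], [0.2, 0.8]))
```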

As above, we add an entropic regularization term, $$\epsilon KL(w||\pi_0 \pi_1^T)$$, which makes the objective strictly convex, and solve the resulting problem with the unbalanced Sinkhorn algorithm.
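Below is a minimal log-domain sketch of unbalanced Sinkhorn for this objective. It is an illustration, not OTT's implementation. Each dual update is the balanced one multiplied by $$\tau = \rho / (\rho + \epsilon)$$, the same relation the tutorial uses later when it sets `TAU` from `RHO` and `EPSILON`. As $$\tau \to 1$$, the balanced algorithm is recovered. The function names and cost values are hypothetical.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def sinkhorn_unbalanced(cost, a, b, epsilon, rho, n_iters=1000):
    """Approximately solve, with generalized KL divergences,

    min_{w >= 0} <cost, w> + epsilon * KL(w || a b^T)
                 + rho * KL(w 1 || a) + rho * KL(w^T 1 || b).
    """
    tau = rho / (rho + epsilon)
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iters):
        # Damped dual updates; the marginals are no longer matched exactly
        f = -tau * epsilon * logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        g = -tau * epsilon * logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
    log_w = log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon
    return np.exp(log_w), f, g


a = np.array([0.2, 0.8])
b = np.array([0.4, 0.6])
cost = np.array([[0.5, 2.0], [1.5, 0.3]])  # illustrative W2^2 values
w, f, g = sinkhorn_unbalanced(cost, a, b, epsilon=0.1, rho=1.0)
print("coupling:\n", w)
print("row marginal:", w.sum(axis=1), "vs", a)
print("column marginal:", w.sum(axis=0), "vs", b)
print("total mass:", w.sum())
```

At a fixed point the potentials satisfy $$f_k = -\rho \log(\sum_l w_{kl} / \pi_0^k)$$, and similarly for $$g$$. In other words, each potential is exactly the marginal-mismatch price that the KL penalty charges.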

## Interpreting the results

The coupling matrix $$W$$ we obtain from the unbalanced Sinkhorn algorithm has marginals that do not necessarily match the component weights of our GMMs, and it’s worth looking in detail at an example to see how we might interpret this mismatch.

### Marginal mismatch

Suppose we have a pair of 2-component GMMs:

• $$\mu_0$$ with component weights $$0.2$$ and $$0.8$$, and

• $$\mu_1$$ with component weights $$0.4$$ and $$0.6$$.

Suppose the unbalanced Sinkhorn algorithm yields the coupling matrix

$\begin{split}W = \begin{pmatrix}0.3 & 0.1\\0.2 & 0.4 \end{pmatrix}\end{split}$

The first row of the coupling matrix $$W$$ indicates that $$0.4$$ units of mass flow out of the first component of $$\mu_0$$, $$0.3$$ units to the first component of $$\mu_1$$ and $$0.1$$ to the second component of $$\mu_1$$. However, the first component of $$\mu_0$$ only has $$0.2$$ units of mass!

Similarly, the first column of $$W$$ indicates that $$0.5$$ units of mass flow into the first component of $$\mu_1$$, $$0.3$$ from the first component of $$\mu_0$$ and $$0.2$$ from the second component of $$\mu_0$$. Again, while $$0.5$$ units of mass flow in, the first component of $$\mu_1$$ only has $$0.4$$ units of mass.

### Reweighting points

Our interpretation is this: points from $$\mu_0$$ undergo two reweightings during transport, the first as they leave a component in $$\mu_0$$ and the second as they enter a component in $$\mu_1$$. Each of these reweightings has a cost that is reflected in the KL divergence between the marginals of the coupling matrix and the weights of the corresponding GaussianMixture components.

Suppose we transport a point with weight 1 from the first component of $$\mu_0$$ to the first component of $$\mu_1$$.

• We see from the coupling matrix that the first component of $$\mu_0$$ has mass $$0.2$$ but has an outflow of $$0.4$$. To achieve the indicated outflow, we double the weight of our point as it leaves the first component of $$\mu_0$$, so now our point has a weight of $$2$$.

• We see that the first component of $$\mu_1$$ has a mass of $$0.4$$ but an inflow of $$0.5$$. To reconcile the inflow with the component's mass, we multiply the weight of incoming points by $$0.4 / 0.5 = 0.8$$.

The net effect is that the weight of our point increases by a factor of $$2 \times 0.8 = 1.6$$.
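The arithmetic in the two bullets above extends to every pair $$(k, l)$$. The leaving factor is outflow divided by component mass in $$\mu_0$$, and the entering factor is component mass divided by inflow in $$\mu_1$$. This short NumPy sketch computes all four net factors for the example coupling $$W$$.

```python
import numpy as np

pi0 = np.array([0.2, 0.8])  # component weights of mu_0
pi1 = np.array([0.4, 0.6])  # component weights of mu_1
W = np.array([[0.3, 0.1], [0.2, 0.4]])  # coupling from the unbalanced Sinkhorn

outflow = W.sum(axis=1)  # mass leaving each component of mu_0
inflow = W.sum(axis=0)  # mass entering each component of mu_1

leave_factor = outflow / pi0  # reweighting as a point leaves mu_0^k
enter_factor = pi1 / inflow  # reweighting as a point enters mu_1^l

# Net factor for a point of weight 1 travelling from mu_0^k to mu_1^l
net_factor = leave_factor[:, None] * enter_factor[None, :]
print(net_factor)
```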

### Non-normalized couplings

One point that is worth emphasizing: in the unbalanced case, the coupling matrix we obtain from the Sinkhorn algorithm need not have a total mass of $$1$$!

Let’s look at the objective function in more detail to see why this might happen.

Recall that $$UW_2^2$$ penalizes mismatches between the marginals of the coupling matrix and the GMM component weights via the generalized KL divergence,

$KL(f||g) = \sum_i f_i \log \frac{f_i}{g_i} - f_i + g_i$

In the divergence above, $$f$$ is a marginal of the coupling, which may not sum to $$1$$, and $$g$$ is the set of weights for a GMM and does sum to 1. Let $$p_i = \frac{f_i}{\sum_i f_i} = \frac{f_i}{F}$$ be the normalized marginal of the coupling. We have

$\begin{split}KL(f||g) = \sum_i F p_i \log \frac{F p_i}{g_i} - F p_i + g_i \\ = F \sum_i \left(p_i \log \frac{p_i}{g_i} + p_i \log F \right) - F + 1 \\ = F \sum_i p_i \log \frac{p_i}{g_i} + F \log F - F + 1 \\ = F KL(p||g) + (F \log F - F + 1)\end{split}$

Thus, having a non-normalized coupling scales each KL divergence penalty by the total mass of the coupling, $$F$$, and adds a penalty of the form $$F \log F - F + 1$$.
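A quick numerical check of this decomposition with randomly chosen weights (the specific values carry no meaning). Here $$g$$ plays the role of the GMM weights and $$f = F p$$ is a marginal of a coupling with total mass $$F$$.

```python
import numpy as np


def generalized_kl(f, g):
    """sum_i f_i log(f_i / g_i) - f_i + g_i (all entries assumed positive)."""
    return float(np.sum(f * np.log(f / g) - f + g))


rng = np.random.default_rng(0)
g = rng.random(4)
g /= g.sum()  # GMM component weights: sum to 1
p = rng.random(4)
p /= p.sum()  # normalized marginal of the coupling
F = 0.7  # total mass of the non-normalized coupling
f = F * p  # non-normalized marginal

lhs = generalized_kl(f, g)
rhs = F * generalized_kl(p, g) + (F * np.log(F) - F + 1.0)
print(lhs, rhs)
```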

In addition, the transport cost for the non-normalized coupling is simply the transport cost for the normalized coupling scaled by the same factor $$F$$.

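The entropic regularization term $$\epsilon KL(W||\pi_0 \pi_1^T)$$ decomposes in the same way, since $$\pi_0 \pi_1^T$$ also sums to $$1$$. The result is that the cost for a non-normalized coupling $$W$$ that sums to $$F$$ is $$F$$ times the cost for the normalized coupling $$W/F$$, plus $$(\epsilon + 2\rho)(F \log F - F + 1)$$.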

For $$F \geq 0$$, the function $$F \log F - F + 1$$ is strictly convex, has a minimum of $$0$$ at $$1$$ and is $$1$$ at $$0$$ and $$e$$.

# @title x log x - x + 1  { display-mode: "form" }
x = np.arange(0, 4, 0.1)
y = x * jnp.log(x) - x + 1
y = y.at[0].set(1.0)
plt.plot(x, y)
plt.title("y = x log x - x + 1")
plt.show()


We should never get an $$F$$ larger than $$1$$. The cost of the normalized coupling is nonnegative, so scaling it by $$F > 1$$ can only increase it, and the penalty $$F \log F - F + 1$$ is positive for every $$F \neq 1$$. If we use the balanced Sinkhorn algorithm, we will always have $$F = 1$$.

The case $$F \in (0, 1)$$ can be interpreted to mean that all points are down-weighted for transport to reduce the overall cost. In effect, part of the transport and reweighting costs is shifted into the normalization penalty, $$(\epsilon + 2 \rho)(F \log F - F + 1)$$.

The net effect of this flexibility is that the total regularized cost can never exceed $$\epsilon + 2 \rho$$, which is the cost of the empty coupling $$F = 0$$ (where $$F \log F - F + 1 = 1$$). This bound is worth considering when setting the various weights used in the overall optimization.
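This bound can be made explicit. Write $$c \geq 0$$ for the regularized cost of the normalized coupling $$W/F$$; the total cost as a function of $$F$$ is then $$F c + (\epsilon + 2\rho)(F \log F - F + 1)$$. Setting its derivative $$c + (\epsilon + 2\rho) \log F$$ to zero gives $$F^* = \exp(-c / (\epsilon + 2\rho)) \leq 1$$. The optimal value is $$(\epsilon + 2\rho)(1 - F^*)$$. The sketch below checks this against a grid search. The value of `c` is hypothetical, and `epsilon` and `rho` match the tutorial's `EPSILON` and `RHO`.

```python
import numpy as np

epsilon, rho = 1.0e-2, 1.0  # same values as EPSILON and RHO in this tutorial
c = 0.8  # hypothetical regularized cost of the normalized coupling W / F
lam = epsilon + 2.0 * rho


def total_cost(F):
    # F * cost(W / F) + (epsilon + 2 rho) * (F log F - F + 1)
    return F * c + lam * (F * np.log(F) - F + 1.0)


F_star = np.exp(-c / lam)  # closed-form minimizer
F_grid = np.linspace(1e-6, 3.0, 300_001)
F_best = F_grid[np.argmin(total_cost(F_grid))]
print(F_star, F_best)
print(total_cost(F_star), lam * (1.0 - F_star))
```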

%%time
# here we use a larger transport weight because the transport cost is smaller
# (see discussion above)
WEIGHT_TRANSPORT = 0.1
RHO = 1.0
TAU = RHO / (RHO + EPSILON)

# Again for our initial model, we will use a GMM fit on the pooled points
pair_init2 = gaussian_mixture_pair.GaussianMixturePair(
    gmm0=pooled_gmm, gmm1=pooled_gmm, epsilon=EPSILON, tau=TAU
)

fit_model_em_fn2 = fit_gmm_pair.get_fit_model_em_fn(
    weight_transport=WEIGHT_TRANSPORT, jit=True
)

pair2, loss = fit_model_em_fn2(
    pair=pair_init2,
    points0=samples_gmm0,
    points1=samples_gmm1,
    point_weights0=None,
    point_weights1=None,
    em_steps=30,
    m_steps=20,
    verbose=True,
)

  0 -3.862 -3.877 transport:0.011 objective:-7.740
1 -3.835 -3.848 transport:0.019 objective:-7.685
2 -3.814 -3.822 transport:0.042 objective:-7.640
3 -3.794 -3.801 transport:0.061 objective:-7.601
4 -3.776 -3.781 transport:0.093 objective:-7.566
5 -3.761 -3.763 transport:0.117 objective:-7.536
6 -3.749 -3.747 transport:0.152 objective:-7.511
7 -3.736 -3.731 transport:0.182 objective:-7.485
8 -3.725 -3.717 transport:0.220 objective:-7.463
9 -3.715 -3.704 transport:0.265 objective:-7.445
10 -3.704 -3.694 transport:0.296 objective:-7.427
11 -3.695 -3.682 transport:0.346 objective:-7.412
12 -3.687 -3.675 transport:0.372 objective:-7.398
13 -3.677 -3.667 transport:0.427 objective:-7.386
14 -3.671 -3.661 transport:0.454 objective:-7.378
15 -3.662 -3.655 transport:0.500 objective:-7.368
16 -3.657 -3.652 transport:0.522 objective:-7.361
17 -3.651 -3.647 transport:0.563 objective:-7.354
18 -3.646 -3.645 transport:0.566 objective:-7.348
19 -3.640 -3.642 transport:0.603 objective:-7.342
20 -3.637 -3.639 transport:0.608 objective:-7.336
21 -3.632 -3.636 transport:0.643 objective:-7.332
22 -3.629 -3.634 transport:0.639 objective:-7.327
23 -3.624 -3.632 transport:0.674 objective:-7.323
24 -3.623 -3.629 transport:0.684 objective:-7.320
25 -3.618 -3.627 transport:0.711 objective:-7.316
26 -3.617 -3.624 transport:0.731 objective:-7.314
27 -3.614 -3.625 transport:0.734 objective:-7.312
28 -3.613 -3.622 transport:0.745 objective:-7.309
29 -3.611 -3.622 transport:0.748 objective:-7.308
CPU times: user 2min 11s, sys: 4.29 s, total: 2min 15s
Wall time: 2min 14s

print("Fitted GMM 0 masses", pair2.gmm0.component_weights)
print("Fitted GMM 1 masses", pair2.gmm1.component_weights)
cost_matrix = pair2.get_cost_matrix()
sinkhorn_output = pair2.get_sinkhorn(cost_matrix=cost_matrix)
print("Normalized coupling")
print(pair2.get_normalized_sinkhorn_coupling(sinkhorn_output=sinkhorn_output))

Fitted GMM 0 masses [0.56643116 0.19865721 0.23491158]
Fitted GMM 1 masses [0.255262   0.39326635 0.35147163]
Normalized coupling
[[0.4567368  0.         0.        ]
[0.         0.25642195 0.        ]
[0.         0.         0.2868412 ]]


Notice above that neither marginal of the fitted coupling matches the corresponding GMM masses. One way to interpret the coupling is as follows:

Mass is reweighted at two stages: first, as it leaves a component of GMM 0, and second, as it enters a component of GMM 1.

Because the coupling is diagonal, the net reweighting of each component is simply the ratio of its masses in the two fitted GMMs. The heaviest component is down-weighted by a factor of roughly $$2$$ ($$0.566 \to 0.255$$). The two lighter components are up-weighted by factors of roughly $$2$$ ($$0.199 \to 0.393$$) and $$1.5$$ ($$0.235 \to 0.351$$).
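The reweighting factors can be computed directly from the masses and the normalized coupling printed above for the unbalanced fit. This small sketch simply copies those printed values. It applies the same leaving and entering factors as in the worked example with two components.

```python
import numpy as np

# Values printed above for the unbalanced fit (pair2).
gmm0_masses = np.array([0.56643116, 0.19865721, 0.23491158])
gmm1_masses = np.array([0.255262, 0.39326635, 0.35147163])
coupling = np.array(
    [
        [0.4567368, 0.0, 0.0],
        [0.0, 0.25642195, 0.0],
        [0.0, 0.0, 0.2868412],
    ]
)

leave_factor = coupling.sum(axis=1) / gmm0_masses  # reweighting on leaving GMM 0
enter_factor = gmm1_masses / coupling.sum(axis=0)  # reweighting on entering GMM 1
# The coupling is diagonal, so component k of GMM 0 maps only to component k of GMM 1
net_factor = leave_factor * enter_factor
print(np.round(net_factor, 2))
```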