Source code for

# Copyright OTT-JAX
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import jax
import jax.numpy as jnp

from ott.geometry import costs, geometry, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from import gaussian_mixture

__all__ = ["GaussianMixturePair"]

[docs] @jax.tree_util.register_pytree_node_class class GaussianMixturePair: """Coupled pair of Gaussian mixture models. Includes methods used in estimating an optimal pairing between GMM components using the Wasserstein-like method described in :cite:`delon:20`, as well as generalization that allows for the reweighting of components. :cite:`delon:20` propose fitting a pair of GMMs to a pair of point clouds in such a way that the sum of the log likelihood of the points minus a weighted penalty involving a Wasserstein-like distance between the GMMs. Their proposed algorithm involves using EM in which a balanced Sinkhorn algorithm is used to estimate a coupling between the GMMs at each step of EM. Our generalization of this algorithm allows for a mismatch between the marginals of the coupling and the GMM component weights. This mismatch can be interpreted as components being reweighted rather than being transported. We penalize reweighting with a generalized KL-divergence penalty, and we give the option to use the unbalanced Sinkhorn algorithm rather than the balanced to compute the divergence between GMMs. """ def __init__( self, gmm0: gaussian_mixture.GaussianMixture, gmm1: gaussian_mixture.GaussianMixture, epsilon: float = 1e-2, tau: float = 1.0, lock_gmm1: bool = False, ): """Constructor. When fitting a pair of coupled GMMs with *no* reweighting of components using the algorithm in :cite:`delon:20`, set tau = 1. The coupling between components will be determined via the balanced Sinkhorn algorithm. When fitting a pair of coupled GMMs in which reweighting of components is allowed, set tau to a value in (0, 1). The resulting coupling will penalize the generalized KL divergence between the coupling's marginals and the GMM component weights with a weight of rho = epsilon tau / (1 - tau). Args: gmm0: first GMM in the pair gmm1: second GMM in the pair epsilon: regularization weight to use for the Sinkhorn algorithm tau: encodes the weight, rho, to use for the generalized KL divergence between the coupling's marginals and GMM component weights as rho = epsilon tau / (1 - tau) lock_gmm1: indicates whether the parameters of gmm1 should be modified during optimization """ # noqa: D401 self._gmm0 = gmm0 self._gmm1 = gmm1 self._epsilon = epsilon self._tau = tau self._lock_gmm1 = lock_gmm1 @property def dtype(self): # noqa: D102 return self.gmm0.dtype @property def gmm0(self): # noqa: D102 return self._gmm0 @property def gmm1(self): # noqa: D102 return self._gmm1 @property def epsilon(self): # noqa: D102 return self._epsilon @property def tau(self): # noqa: D102 return self._tau @property def rho(self): # noqa: D102 return self.epsilon * self.tau / (1.0 - self.tau) @property def lock_gmm1(self): # noqa: D102 return self._lock_gmm1
[docs] def get_bures_geometry(self) -> pointcloud.PointCloud: """Get a Bures Geometry for the two GMMs.""" mean0 = self.gmm0.loc dimension = mean0.shape[-1] cov0 = self.gmm0.covariance cov0 = cov0.reshape(cov0.shape[:-2] + (dimension * dimension,)) x = jnp.concatenate([mean0, cov0], axis=-1) mean1 = self.gmm1.loc cov1 = self.gmm1.covariance cov1 = cov1.reshape(cov1.shape[:-2] + (dimension * dimension,)) y = jnp.concatenate([mean1, cov1], axis=-1) return pointcloud.PointCloud( x=x, y=y, cost_fn=costs.Bures(dimension=dimension), epsilon=self.epsilon )
[docs] def get_cost_matrix(self) -> jnp.ndarray: """Get matrix of :math:`W_2^2` costs between all pairs of components.""" return self.get_bures_geometry().cost_matrix
[docs] def get_sinkhorn( self, cost_matrix: jnp.ndarray, **kwargs: Any ) -> sinkhorn.SinkhornOutput: """Get the output of Sinkhorn's method for a given cost matrix.""" # We use a Geometry here rather than the PointCloud created in # get_bures_geometry to avoid recomputing the cost matrix, since # the cost matrix is quite expensive geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=self.epsilon) prob = linear_problem.LinearProblem( geom, a=self.gmm0.component_weights, b=self.gmm1.component_weights, tau_a=self.tau, tau_b=self.tau ) return sinkhorn.Sinkhorn(**kwargs)(prob)
[docs] def get_normalized_sinkhorn_coupling( self, sinkhorn_output: sinkhorn.SinkhornOutput, ) -> jnp.ndarray: """Get the normalized coupling matrix for the specified Sinkhorn output. Args: sinkhorn_output: Sinkhorn algorithm output as returned by :meth:`get_sinkhorn`. Returns: A coupling matrix that tells how much of the mass of each component of :attr:`gmm0` is mapped to each component of :attr:`gmm1`. """ return sinkhorn_output.matrix / jnp.sum(sinkhorn_output.matrix)
def tree_flatten(self): """Method used by jax.tree_util to flatten a GaussianMixturePair. We control the subset of parameters that we will optimize in fit_gmm_pair by selectively placing them in either children (the parameters to optimize) or aux_data (the parameters to leave alone). Returns: A tuple of child pytrees and a dict of auxiliary data. """ # noqa: D401 children = [self.gmm0] aux_data = { "epsilon": self.epsilon, "tau": self.tau, "lock_gmm1": self.lock_gmm1 } if self.lock_gmm1: aux_data["gmm1"] = self.gmm1 else: children.append(self.gmm1) return tuple(children), aux_data @classmethod def tree_unflatten(cls, aux_data, children): """Method used by jax.tree_util to unflatten a GaussianMixturePair. tree_flatten controls which parameters get optimized by placing them in either children or aux_data; here we invert the process. Args: aux_data: auxiliary data that is passed to the constructor as kwargs children: child pytrees passed to the constructor as args Returns: A GaussianMixturePair. """ # noqa: D401 children = list(children) if "gmm1" in aux_data: gmm1 = aux_data.pop("gmm1") children.insert(1, gmm1) return cls(*children, **aux_data) def __repr__(self): class_name = type(self).__name__ children, aux = self.tree_flatten() return "{}({})".format( class_name, ", ".join([repr(c) for c in children] + [f"{k}: {repr(v)}" for k, v in aux.items()]) ) def __hash__(self): return jax.tree_util.tree_flatten(self).__hash__() def __eq__(self, other): return jax.tree_util.tree_flatten(self) == jax.tree_util.tree_flatten(other)