Source code for ott.tools.gaussian_mixture.fit_gmm_pair

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Fit 2 GMMs to 2 point clouds using likelihood and (approx) W2 distance.

Suppose we have two large point clouds and want to estimate a coupling and a
W2 distance between them. :cite:`delon:20` propose fitting a GMM to each
point cloud while simultaneously minimizing a Wasserstein-like distance
called MW2 between the fitted GMMs. MW2 is an upper bound on W2,
the Wasserstein distance between the GMMs. Here we implement
their algorithm as well as a generalization that allows for reweightings using
generalized, penalized expectation-maximization
(see section 6.2 of :cite:`delon:20`).

As in `fit_gmm.py`, we assume that the observations $X_0$ and $X_1$ from
batches 0 and 1 are generated by GMMs with parameters $\Theta_0$ and $\Theta_1$,
respectively. We will use $\Theta$ to denote the combined parameters
for the two GMMs. We denote the (unobserved) components that gave rise to the
observations $X_i$ as $Z_i$.

Our goal is to maximize a weighted sum of the likelihood of the observations $X$
under the fitted GMMs and a measure of distance, $MW_2$, between the fitted
GMMs. The problem would be a straightforward maximization exercise if we knew
the components $Z$ that generated each observation $X$. Because the $Z$ are
unobserved, however, we use EM:

We start with an initial estimate of $\Theta$, $\Theta^{(t)}$.

* The E-step: We use the current $\Theta^{(t)}$ to estimate the likelihood of
  all possible cluster attributions for each observation $X$.

* The M-step: We form the function $Q(\Theta|\Theta^{(t)})$,
  the log likelihood of our observations averaged over all possible
  assignments. We then obtain an updated parameter estimate, $\Theta^{(t+1)}$,
  by numerically maximizing the sum of $Q$ and our GMM distance penalty.

It can be shown that if we maximize the penalized $Q$ above, this procedure will
increase or leave unchanged the penalized log likelihood for $\Theta$. We
iterate over these two steps until convergence. Note that the resulting
estimate for $\Theta$ may only be a *local* maximum of the penalized
likelihood function.


Sample usage:

# (Note that we usually initialize a pair to a single GMM that we fit to a
# pooled set, then the two GMMs separate as we optimize the pair.)
pair_init = gaussian_mixture_pair.GaussianMixturePair(
    gmm0=gmm0,
    gmm1=gmm1,
    epsilon=1.e-2,
    tau=1.)
fit_model_em_fn = fit_gmm_pair.get_fit_model_em_fn(
    weight_transport=0.1,
    weight_splitting=1.,
    epsilon=pair_init.epsilon,
    jit=True)
pair, loss = fit_model_em_fn(
    pair=pair_init,
    points0=samples_gmm0,
    points1=samples_gmm1,
    point_weights0=None,
    point_weights1=None,
    em_steps=30,
    m_steps=20,
    verbose=True)
"""
# TODO(geoffd): look into refactoring so we jit higher level functions

import functools
import math
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

from ott.tools.gaussian_mixture import (
    fit_gmm,
    gaussian_mixture,
    gaussian_mixture_pair,
)

__all__ = ["get_fit_model_em_fn"]

LOG2 = math.log(2)


class Observations(NamedTuple):
  """Weighted observations and their E-step assignment probabilities."""

  points: jnp.ndarray
  point_weights: jnp.ndarray
  assignment_probs: jnp.ndarray


# Model fit


def get_q(
    gmm: gaussian_mixture.GaussianMixture, obs: Observations
) -> jnp.ndarray:
  r"""Get Q(\Theta|\Theta^{(t)}).

  Here Q is the log likelihood for our observations based on the current
  parameter estimates for \Theta and averaged over the current component
  assignment probabilities. See the overview of EM above for more details.

  Args:
    gmm: GMM model parameterized by Theta
    obs: weighted observations with component assignments computed in the E step
      for \Theta^{(t)}

  Returns:
    Q(\Theta|\Theta^{(t)})
  """
  # Q = E_Z log p(X, Z| Theta)
  # = \sum_Z P(Z|X, Theta^(t)) [log p(X, Z | Theta)]
  # Here P(Z|X, theta^(t)) is the set of assignment probabilities
  # we computed in the E step.
  # log p(X, Z| theta) is given by
  log_p_x_z = (
      gmm.conditional_log_prob(obs.points) +  # p(X | Z, theta)
      gmm.log_component_weights()
  )  # p(Z | theta)
  return (
      jnp.sum(
          obs.point_weights *
          jnp.sum(log_p_x_z * obs.assignment_probs, axis=-1),
          axis=0
      ) / jnp.sum(obs.point_weights, axis=0)
  )


# Objective function


@functools.lru_cache
def get_objective_fn(weight_transport: float):
  """Get the total loss function with static parameters in a closure.

  Args:
    weight_transport: weight for the transport penalty

  Returns:
    A function that returns the objective for a GaussianMixturePair.
  """

  def _objective_fn(
      pair: gaussian_mixture_pair.GaussianMixturePair,
      obs0: Observations,
      obs1: Observations,
  ) -> jnp.ndarray:
    """Compute the objective function for a pair of GMMs.

    Args:
      pair: pair of GMMs + coupling for which to evaluate the objective
      obs0: first set of observations
      obs1: second set of observations

    Returns:
      The objective to be minimized in the M-step.
    """
    q0 = get_q(gmm=pair.gmm0, obs=obs0)
    q1 = get_q(gmm=pair.gmm1, obs=obs1)
    cost_matrix = pair.get_cost_matrix()
    sinkhorn_output = pair.get_sinkhorn(cost_matrix=cost_matrix)
    transport_penalty = sinkhorn_output.reg_ot_cost
    return q0 + q1 - weight_transport * transport_penalty

  return _objective_fn


def print_losses(
    iteration: int, weight_transport: float,
    pair: gaussian_mixture_pair.GaussianMixturePair, obs0: Observations,
    obs1: Observations
):
  """Print the loss components for diagnostic purposes."""
  q0 = get_q(gmm=pair.gmm0, obs=obs0)
  q1 = get_q(gmm=pair.gmm1, obs=obs1)
  cost_matrix = pair.get_cost_matrix()
  sinkhorn_output = pair.get_sinkhorn(cost_matrix=cost_matrix)
  transport_penalty = sinkhorn_output.reg_ot_cost
  objective = q0 + q1 - weight_transport * transport_penalty

  print(  # noqa: T201
      f"{iteration:3d} {q0:.3f} {q1:.3f} "
      f"transport:{transport_penalty:.3f} "
      f"objective:{objective:.3f}"
  )


# The E-step for a single GMM


def do_e_step(  # noqa: D103
    e_step_fn: Callable[[gaussian_mixture.GaussianMixture, jnp.ndarray],
                        jnp.ndarray],
    gmm: gaussian_mixture.GaussianMixture,
    points: jnp.ndarray,
    point_weights: jnp.ndarray,
) -> Observations:
  assignment_probs = e_step_fn(gmm, points)
  return Observations(
      points=points,
      point_weights=point_weights,
      assignment_probs=assignment_probs
  )


# The M-step


def get_m_step_fn(learning_rate: float, objective_fn, jit: bool):
  """Get a function that performs the M-step of the EM algorithm.

  We precompile and precompute a few quantities that we put into a closure.

  Args:
    learning_rate: learning rate to use for the Adam optimizer
    objective_fn: the objective function to maximize
    jit: if True, precompile key methods

  Returns:
    A function that performs the M-step of EM.
  """
  import optax

  def _m_step_fn(
      pair: gaussian_mixture_pair.GaussianMixturePair,
      obs0: Observations,
      obs1: Observations,
      steps: int,
  ) -> gaussian_mixture_pair.GaussianMixturePair:
    """Perform the M-step on a pair of Gaussian mixtures.

    Args:
      pair: GMM parameters to optimize
      obs0: first set of observations
      obs1: second set of observations
      steps: number of optimization steps to use when maximizing the objective

    Returns:
      A GaussianMixturePair with updated parameters.
    """
    state = opt_init((pair,))

    for _ in range(steps):
      grad_objective = grad_objective_fn(pair, obs0, obs1)
      updates, state = opt_update(grad_objective, state, (pair,))
      (pair,) = optax.apply_updates((pair,), updates)
      for j, gmm in enumerate((pair.gmm0, pair.gmm1)):
        if gmm.has_nans():
          raise ValueError(f"NaN in gmm{j}")
    return pair

  grad_objective_fn = jax.grad(objective_fn, argnums=(0,))
  if jit:
    grad_objective_fn = jax.jit(grad_objective_fn)

  opt_init, opt_update = optax.chain(
      # Set the parameters of Adam. Note the learning_rate is not here.
      optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
      optax.scale(learning_rate)
  )

  return _m_step_fn


[docs] def get_fit_model_em_fn( weight_transport: float, learning_rate: float = 0.001, jit: bool = True, ): """Get a function that performs penalized EM. We precompile and precompute a few quantities that we put into a closure. Args: weight_transport: weight for the transportation loss in the total loss learning_rate: learning rate to use for the Adam optimizer jit: if True, precompile key methods Returns: A function that performs generalized, penalized EM. """ objective_fn = get_objective_fn(weight_transport=weight_transport) e_step_fn = fit_gmm.get_assignment_probs if jit: objective_fn = jax.jit(objective_fn) e_step_fn = jax.jit(e_step_fn) m_step_fn = get_m_step_fn( learning_rate=learning_rate, objective_fn=objective_fn, jit=jit ) def _fit_model_em( pair: gaussian_mixture_pair.GaussianMixturePair, points0: jnp.ndarray, points1: jnp.ndarray, point_weights0: Optional[jnp.ndarray], point_weights1: Optional[jnp.ndarray], em_steps: int, m_steps: int = 50, verbose: bool = False, ) -> Tuple[gaussian_mixture_pair.GaussianMixturePair, float]: """Optimize a GaussianMixturePair using penalized EM. Args: pair: GaussianMixturePair to optimize points0: observations associated with pair.gmm0 points1: observations associated with pair.gmm1 point_weights0: weights for points0 point_weights1: weights for points1 em_steps: number of EM steps to perform m_steps: number of gradient descent steps to perform in the M-step verbose: if True, print status messages Returns: An updated GaussianMixturePair and the final loss. """ if point_weights0 is None: point_weights0 = jnp.ones(points0.shape[0]) if point_weights1 is None: point_weights1 = jnp.ones(points1.shape[0]) if pair.lock_gmm1: obs1 = do_e_step( e_step_fn=e_step_fn, gmm=pair.gmm1, points=points1, point_weights=point_weights1 ) for i in range(em_steps): # E-step obs0 = do_e_step( e_step_fn=e_step_fn, gmm=pair.gmm0, points=points0, point_weights=point_weights0 ) if not pair.lock_gmm1: obs1 = do_e_step( e_step_fn=e_step_fn, gmm=pair.gmm1, points=points1, point_weights=point_weights1 ) # print current losses if verbose: print_losses( iteration=i, weight_transport=weight_transport, pair=pair, obs0=obs0, obs1=obs1 ) # the M-step pair = m_step_fn(pair=pair, obs0=obs0, obs1=obs1, steps=m_steps) # final E-step before computing the loss obs0 = do_e_step( e_step_fn=e_step_fn, gmm=pair.gmm0, points=points0, point_weights=point_weights0 ) if not pair.lock_gmm1: obs1 = do_e_step( e_step_fn=e_step_fn, gmm=pair.gmm1, points=points1, point_weights=point_weights1 ) loss = objective_fn(pair=pair, obs0=obs0, obs1=obs1) return pair, loss return _fit_model_em