ott.tools.gaussian_mixture.fit_gmm.fit_model_em

ott.tools.gaussian_mixture.fit_gmm.fit_model_em(gmm, points, point_weights, steps, jit=True, verbose=False)

Fit a GMM using the EM algorithm.
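For reference, the standard weighted EM updates for a mixture of K Gaussians with mixing weights \pi_k, means \mu_k and covariances \Sigma_k are shown below. Here w_i is the weight of point x_i (all equal when point_weights is None). This is the textbook formulation, given as background only. The library's implementation may differ in details such as numerical safeguards.

```latex
% E-step: responsibility of component k for point x_i
\gamma_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}
                   {\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}

% M-step, with N_k = \sum_{i=1}^{n} w_i \gamma_{ik}
\pi_k = \frac{N_k}{\sum_{i=1}^{n} w_i}, \qquad
\mu_k = \frac{1}{N_k} \sum_{i=1}^{n} w_i \gamma_{ik} \, x_i, \qquad
\Sigma_k = \frac{1}{N_k} \sum_{i=1}^{n} w_i \gamma_{ik} \, (x_i - \mu_k)(x_i - \mu_k)^\top
```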

Parameters:
  • gmm (GaussianMixture) – initial GMM, used as the starting point for EM

  • points (Array) – set of samples to fit, shape (n, n_dimensions)

  • point_weights (Optional[Array]) – optional set of weights for points, shape (n,). If None, uses equal weights for all points.

  • steps (int) – number of EM iterations to perform

  • jit (bool) – if True, compile the functions used at each EM step with jax.jit

  • verbose (bool) – if True, print the loss at each step

Return type:

GaussianMixture

Returns:

A GMM with updated parameters.
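The loop that fit_model_em runs can be illustrated with a minimal NumPy sketch of weighted EM for a full-covariance GMM. This is not ott's implementation, which works on a GaussianMixture object and uses JAX. Several parts of the sketch are hypothetical and chosen for illustration: the helper names em_step and fit_gmm_em_sketch, the use of plain arrays for the parameters, the small ridge term reg, and the choice of loss (weighted mean negative log-likelihood).

```python
import numpy as np


def log_gaussian_pdf(x, mean, cov):
    """Log density of N(mean, cov) evaluated at each row of x; returns shape (n,)."""
    d = x.shape[1]
    chol = np.linalg.cholesky(cov)
    # Whiten the centered points: ||z||^2 is the squared Mahalanobis distance.
    z = np.linalg.solve(chol, (x - mean).T)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (np.sum(z**2, axis=0) + log_det + d * np.log(2.0 * np.pi))


def log_sum_exp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def em_step(points, weights, means, covs, comp_weights, reg=1e-6):
    """One weighted EM iteration (hypothetical helper).

    Also returns the loss of the *input* parameters, i.e. before the update.
    """
    k, d = means.shape
    # E-step: log pi_k + log N(x_i | mu_k, Sigma_k), shape (n, k).
    log_joint = np.stack(
        [np.log(comp_weights[j]) + log_gaussian_pdf(points, means[j], covs[j])
         for j in range(k)],
        axis=1,
    )
    log_px = log_sum_exp(log_joint, axis=1)             # log p(x_i), shape (n,)
    resp = np.exp(log_joint - log_px[:, None])          # responsibilities gamma_ik
    loss = -np.sum(weights * log_px) / np.sum(weights)  # weighted mean NLL (assumed loss)

    # M-step: weighted updates of mixing weights, means and covariances.
    wr = weights[:, None] * resp                        # w_i * gamma_ik, shape (n, k)
    nk = wr.sum(axis=0)                                 # N_k, shape (k,)
    new_comp_weights = nk / nk.sum()
    new_means = (wr.T @ points) / nk[:, None]
    new_covs = np.empty((k, d, d))
    for j in range(k):
        diff = points - new_means[j]
        # Small ridge term (assumption) keeps each covariance positive definite.
        new_covs[j] = (wr[:, j, None] * diff).T @ diff / nk[j] + reg * np.eye(d)
    return new_means, new_covs, new_comp_weights, loss


def fit_gmm_em_sketch(points, point_weights, means, covs, comp_weights,
                      steps, verbose=False):
    """Run `steps` EM iterations, loosely mirroring fit_model_em's arguments."""
    if point_weights is None:
        # Equal weights for all points, as documented for fit_model_em.
        point_weights = np.ones(points.shape[0])
    for step in range(steps):
        means, covs, comp_weights, loss = em_step(
            points, point_weights, means, covs, comp_weights)
        if verbose:
            print(f"step {step}: loss {loss:.4f}")
    return means, covs, comp_weights


# Usage: two well-separated 2-D clusters centred at x = -5 and x = +5.
rng = np.random.default_rng(0)
points = np.concatenate([
    rng.normal(size=(200, 2)) + np.array([-5.0, 0.0]),
    rng.normal(size=(200, 2)) + np.array([5.0, 0.0]),
])
init_means = np.array([[-1.0, 0.0], [1.0, 0.0]])
init_covs = np.stack([np.eye(2), np.eye(2)])
init_weights = np.array([0.5, 0.5])
means, covs, comp_weights = fit_gmm_em_sketch(
    points, None, init_means, init_covs, init_weights, steps=50)
```

In this sketch the point weights enter the updates only through ratios. Scaling all of point_weights by a constant therefore leaves the fit unchanged, and passing None is equivalent to uniform weights.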