ott.tools.gaussian_mixture.fit_gmm_pair.get_fit_model_em_fn

ott.tools.gaussian_mixture.fit_gmm_pair.get_fit_model_em_fn(weight_transport, learning_rate=0.001, jit=True)

Get a function that performs penalized EM.

Key methods are precompiled (when jit is True) and a few quantities are precomputed once; both are captured in the closure of the returned function.

Parameters:
  • weight_transport (float) – weight for the transportation loss in the total loss

  • learning_rate (float) – learning rate to use for the Adam optimizer

  • jit (bool) – if True, precompile key methods

Returns:

A function that performs generalized, penalized EM.
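
The closure pattern described above can be sketched with a toy example. This is not the OTT implementation: the names get_fit_fn, weight_penalty, target, and anchor are hypothetical, the loss is a one-dimensional quadratic rather than a GMM likelihood with a transport cost, and plain gradient descent stands in for the Adam optimizer. It only illustrates how a factory can fix a penalty weight and a learning rate once and return a fitting function that reuses them.

```python
def get_fit_fn(weight_penalty, learning_rate=0.001):
    """Hypothetical toy factory (not part of OTT).

    Returns a function minimizing
        (x - target)**2 + weight_penalty * (x - anchor)**2
    with plain gradient descent.
    """
    # Precomputed once and captured in the closure: the gradient of the
    # total loss is 2 * ((x - target) + weight_penalty * (x - anchor)),
    # so the constant factor 2 is folded into the step size.
    scale = 2.0 * learning_rate

    def total_loss(x, target, anchor):
        # Data term plus weighted penalty term.
        return (x - target) ** 2 + weight_penalty * (x - anchor) ** 2

    def fit(x, target, anchor, steps=1000):
        for _ in range(steps):
            half_grad = (x - target) + weight_penalty * (x - anchor)
            x = x - scale * half_grad
        return x, total_loss(x, target, anchor)

    return fit


# Build the fitting function once, then call it on data.
fit = get_fit_fn(weight_penalty=1.0, learning_rate=0.1)
x_opt, loss = fit(0.0, target=2.0, anchor=4.0)
```

With equal weight on both terms, the minimizer sits midway between target and anchor; raising weight_penalty pulls it toward anchor, which mirrors the role weight_transport plays in the total loss.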