ott.neural.methods.flow_matching.gaussian_nll

ott.neural.methods.flow_matching.gaussian_nll(model, x1, cond=None, *, noise=None, stddev=1.0, **kwargs)

Compute the Gaussian negative log-likelihood.

Parameters:
  • model (Module) – Velocity model with a signature (t, x_t, cond) -> v_t.

  • x1 (Array) – Initial point of shape [*dims].

  • cond (Optional[Array]) – Condition [*cond_dims].

  • noise (Optional[Array]) – Array of shape [num_noise_samples, ...] used for the Hutchinson’s trace estimate of the divergence of the velocity field. If None, compute the exact divergence using jax.jacrev().

  • stddev (float) – Standard deviation of the Gaussian distribution.

  • kwargs (Any) – Keyword arguments for evaluate_velocity_field().
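The choice controlled by noise, exact divergence versus Hutchinson's trace estimate, can be illustrated with a minimal JAX sketch. It does not reproduce ott's internals. The names velocity, exact_divergence and hutchinson_divergence are hypothetical, and the toy field v(t, x) = A x + t is used only because its exact divergence, trace(A), is known in closed form. The exact path builds the full Jacobian with jax.jacrev(), which needs one backward pass per dimension. The Hutchinson path averages eps^T (dv/dx) eps over the rows of a noise array of shape [num_noise_samples, ...], needing one vector-Jacobian product per sample. Any noise with zero mean and identity covariance gives an unbiased estimate; Rademacher noise is an assumption made here.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy field v(t, x) = A x + t; its exact divergence is trace(A).
A = jnp.array([[0.5, -1.0, 0.0],
               [2.0, 0.3, 1.0],
               [0.0, 0.4, -0.2]])


def velocity(t, x):
  return A @ x + t


def exact_divergence(v_fn, t, x):
  # Full Jacobian dv/dx via reverse mode (one VJP per output dimension), then its trace.
  jac = jax.jacrev(v_fn, argnums=1)(t, x)
  return jnp.trace(jac)


def hutchinson_divergence(v_fn, t, x, noise):
  # Unbiased estimate tr(J) = E[eps^T J eps] for noise with zero mean and
  # identity covariance; noise has shape [num_noise_samples, *x.shape].
  _, vjp_fn = jax.vjp(lambda y: v_fn(t, y), x)

  def quadratic_form(eps):
    (eps_jac,) = vjp_fn(eps)  # eps^T J, one backward pass per sample
    return jnp.vdot(eps_jac, eps)

  return jnp.mean(jax.vmap(quadratic_form)(noise))


x = jnp.array([1.0, -2.0, 0.5])
t = 0.3
# Rademacher (+1/-1) noise; standard Gaussian noise would also be unbiased.
noise = jax.random.rademacher(jax.random.PRNGKey(0), (4096, 3)).astype(jnp.float32)

exact = exact_divergence(velocity, t, x)  # trace(A) = 0.5 + 0.3 - 0.2
approx = hutchinson_divergence(velocity, t, x, noise)
```

The estimate's variance comes from the off-diagonal entries of the Jacobian, so more noise samples trade compute for accuracy, while the exact path's cost grows with the number of dimensions.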

Return type:

Tuple[Array, Solution]

Returns:

A tuple of the Gaussian negative log-likelihood in bits-per-dimension and the corresponding Solution object.
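To make the bits-per-dimension output concrete, here is a minimal sketch of the instantaneous change-of-variables computation under stated assumptions: the base distribution at t=0 is N(0, stddev^2 I), x1 is integrated backward from t=1 to t=0, and fixed-step explicit Euler stands in for whatever ODE solver ott actually uses. The helper gaussian_nll_bpd_sketch is hypothetical, x1 is assumed to be 1-D, only the exact-divergence path is shown, and no dequantization offset is applied. It uses log p1(x1) = log N(x0; 0, stddev^2 I) - ∫_0^1 div v(t, x_t) dt and converts nats to bits per dimension by dividing by D ln 2.

```python
import math

import jax
import jax.numpy as jnp


def gaussian_nll_bpd_sketch(v_fn, x1, *, stddev=1.0, num_steps=1000):
  """Hypothetical helper: NLL of a 1-D point x1, in bits/dim, under the flow
  dx/dt = v_fn(t, x) whose t=0 marginal is N(0, stddev^2 I)."""
  dim = x1.size
  dt = 1.0 / num_steps

  def divergence(t, x):
    # Exact divergence (trace of the Jacobian), analogous to noise=None.
    return jnp.trace(jax.jacrev(v_fn, argnums=1)(t, x))

  def euler_step(carry, i):
    x, int_div = carry
    t = 1.0 - i * dt  # integrate backward from t=1 to t=0
    int_div = int_div + divergence(t, x) * dt
    x = x - v_fn(t, x) * dt
    return (x, int_div), None

  init = (x1, jnp.zeros((), dtype=x1.dtype))
  (x0, int_div), _ = jax.lax.scan(euler_step, init, jnp.arange(num_steps))

  # Log-density of the recovered x0 under the Gaussian base distribution.
  log_p0 = (-0.5 * jnp.sum((x0 / stddev) ** 2)
            - 0.5 * dim * jnp.log(2.0 * jnp.pi * stddev**2))
  # Instantaneous change of variables: log p1(x1) = log p0(x0) - int_0^1 div v dt.
  log_p1 = log_p0 - int_div
  # Convert nats to bits and average over dimensions.
  return -log_p1 / (dim * math.log(2.0))
```

Because a linear field v(t, x) = a x maps N(0, stddev^2 I) at t=0 to N(0, stddev^2 e^{2a} I) at t=1, the sketch's output can be checked against that closed-form density.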