ott.neural.methods.flow_matching.gaussian_nll
- ott.neural.methods.flow_matching.gaussian_nll(model, x1, cond=None, *, noise=None, stddev=1.0, **kwargs)
Compute the Gaussian negative log-likelihood.
- Parameters:
model (Module) – Velocity model with a signature (t, x_t, cond) -> v_t.
x1 (Array) – Initial point of shape [*dims].
noise (Optional[Array]) – Array of shape [num_noise_samples, ...] used for Hutchinson's trace estimate of the divergence of the velocity field. If None, compute the exact divergence using jax.jacrev().
stddev (float) – Standard deviation of the Gaussian distribution.
kwargs (Any) – Keyword arguments for evaluate_velocity_field().
- Return type:
- Returns:
The Gaussian negative log-likelihood in bits-per-dimension.
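
The two divergence options described under noise, and the bits-per-dimension unit of the return value, can be illustrated with a minimal standalone JAX sketch. It does not call gaussian_nll and is not the library's implementation. The names velocity, exact_divergence, hutchinson_divergence and gaussian_nll_bpd are hypothetical, the velocity field is a toy linear map on a flat vector (no time or condition argument), and the Rademacher noise is one common choice for the probe vectors rather than something the documentation prescribes. The conversion assumes the standard change-of-variables relation log p1(x1) = log p0(x0) - ∫ div v dt, with p0 an isotropic Gaussian of standard deviation stddev.

```python
import jax
import jax.numpy as jnp


def velocity(x):
    # Hypothetical toy velocity field v(x) = A x; its divergence is trace(A) = 5.
    a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    return a @ x


def exact_divergence(fn, x):
    # Exact divergence: trace of the full Jacobian, built with jax.jacrev.
    return jnp.trace(jax.jacrev(fn)(x))


def hutchinson_divergence(fn, x, noise):
    # Hutchinson's estimate: mean over probes eps of eps^T J eps,
    # using one vector-Jacobian product per probe instead of the full Jacobian.
    # noise has shape [num_noise_samples, dim].
    _, vjp_fn = jax.vjp(fn, x)

    def single_probe(eps):
        (eps_t_jac,) = vjp_fn(eps)
        return jnp.dot(eps_t_jac, eps)

    return jnp.mean(jax.vmap(single_probe)(noise))


def gaussian_nll_bpd(x0, div_integral, stddev=1.0):
    # Assumed relation: log p1(x1) = log p0(x0) - integral of div v over time,
    # where p0 is an isotropic Gaussian with standard deviation stddev.
    d = x0.size
    log_p0 = (
        -0.5 * jnp.sum((x0 / stddev) ** 2)
        - d * jnp.log(stddev)
        - 0.5 * d * jnp.log(2.0 * jnp.pi)
    )
    log_p1 = log_p0 - div_integral
    # Convert nats to bits (divide by ln 2) and normalize by dimension.
    return -log_p1 / (d * jnp.log(2.0))


x = jnp.array([0.5, -1.0])
noise = jax.random.rademacher(jax.random.PRNGKey(0), (1024, 2), dtype=jnp.float32)

exact = exact_divergence(velocity, x)
estimate = hutchinson_divergence(velocity, x, noise)
bpd = gaussian_nll_bpd(jnp.zeros(2), div_integral=0.0, stddev=1.0)
```

The exact path materializes the full Jacobian, which becomes expensive as the dimension grows. The Hutchinson path needs only one vector-Jacobian product per noise sample, at the cost of variance that shrinks as num_noise_samples increases.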