ott.neural.methods.flow_matching.flow_matching_step


ott.neural.methods.flow_matching.flow_matching_step(model, optimizer, batch, *, loss_fn=<function squared_error>, model_callback_fn=None, rngs=None)

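Perform a single flow matching training step: predict the velocity at (t, x_t), compare it with the target velocity v_t using loss_fn, and update the model parameters with the optimizer.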

Parameters:
  • model (Module) – Velocity field with a signature (t, x_t, cond, rngs=...) -> v_t.

  • optimizer (Optimizer) – Optimizer used to update the model parameters.

  • batch (Dict[Literal['t', 'x_t', 'v_t', 'cond'], Array]) –

    Batch containing the following elements:

    • 't' - time, array of shape [batch,].

    • 'x_t' - position, array of shape [batch, ...].

    • 'v_t' - target velocity, array of shape [batch, ...].

    • 'cond' - condition (optional), array of shape [batch, ...].
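This page does not say how the batch is produced. A common choice in flow matching is the straight-line path x_t = (1 - t) x_0 + t x_1, whose target velocity is the constant x_1 - x_0. The sketch below builds a batch with the documented keys under that assumption; the helper `make_linear_path_batch` is hypothetical and not part of OTT, and sampling t uniformly in [0, 1) is likewise an assumption.

```python
import jax
import jax.numpy as jnp


def make_linear_path_batch(key, x0, x1, cond=None):
    """Hypothetical helper: batch for the straight-line path x_t = (1 - t) x0 + t x1."""
    # One time per sample, shape [batch,], drawn uniformly from [0, 1).
    t = jax.random.uniform(key, (x0.shape[0],))
    # Reshape t to [batch, 1, ..., 1] so it broadcasts over the non-batch dimensions.
    t_b = t.reshape((-1,) + (1,) * (x0.ndim - 1))
    x_t = (1.0 - t_b) * x0 + t_b * x1
    # Time derivative of the straight-line path; it does not depend on t.
    v_t = x1 - x0
    batch = {"t": t, "x_t": x_t, "v_t": v_t}
    if cond is not None:
        batch["cond"] = cond  # optional condition, shape [batch, ...]
    return batch
```

The 'cond' key is only added when a condition is given, mirroring its optional status in the list above.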

  • loss_fn (Callable[[Array, Array], Array]) – Loss function with a signature (pred, target) -> loss.
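Any callable with the (pred, target) -> loss signature can be passed as loss_fn. As an illustration, here is a hypothetical pseudo-Huber loss, sqrt(||d||^2 + c^2) - c per sample, averaged over the batch. It returns a scalar; we assume the step accepts a scalar return value (any further averaging would leave it unchanged), and the constant c = 0.1 is an arbitrary choice.

```python
import jax.numpy as jnp


def pseudo_huber(pred, target, c=0.1):
    """Hypothetical custom loss_fn with the (pred, target) -> loss signature."""
    # Flatten everything except the batch dimension: [batch, features].
    diff = (pred - target).reshape(pred.shape[0], -1)
    # Per-sample pseudo-Huber: behaves like a squared error near 0, like an L2 norm far away.
    per_sample = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + c ** 2) - c
    return jnp.mean(per_sample)
```

It could then be passed as `flow_matching_step(..., loss_fn=pseudo_huber)`.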

  • model_callback_fn (Optional[Callable[[Module], None]]) – Function with a signature (model) -> None, e.g., to update an exponential moving average (EMA) of the model.
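A minimal sketch of an EMA callback with the (model) -> None signature follows. It assumes a user-supplied `get_params` function that maps the model to a pytree of arrays; for an `nnx.Module` this might be something like `lambda m: nnx.state(m, nnx.Param)`, but that choice is an assumption about your setup. The class name `EMACallback` and the default decay are hypothetical, and this page does not say exactly when the callback is invoked relative to the parameter update.

```python
import jax
import jax.numpy as jnp


class EMACallback:
    """Hypothetical model_callback_fn that tracks an EMA of the model's parameters."""

    def __init__(self, model, get_params, decay=0.999):
        # get_params (assumption): maps the model to a pytree of arrays.
        self.get_params = get_params
        self.decay = decay
        # Initialize the running average with the current parameters.
        self.ema = jax.tree_util.tree_map(jnp.asarray, get_params(model))

    def __call__(self, model):
        params = self.get_params(model)
        d = self.decay
        # ema <- d * ema + (1 - d) * params, leaf by leaf.
        self.ema = jax.tree_util.tree_map(
            lambda e, p: d * e + (1.0 - d) * p, self.ema, params
        )
```

The callback keeps its own copy of the averaged parameters, so the model being trained is never modified by it.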

  • rngs (Optional[Rngs]) – Random number generators passed to the model, e.g., for dropout.

Return type:

Dict[Literal['loss', 'grad_norm'], Array]

Returns:

Updates the model parameters in place and returns a dictionary containing the loss ('loss') and the gradient norm ('grad_norm').
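To make the documented behavior concrete, here is a minimal, purely functional sketch of what one such step computes. It is not OTT's implementation: it assumes a toy linear velocity field, plain SGD in place of the optimizer, a loss averaged over all elements, and a gradient norm equal to the global L2 norm over all parameter leaves. Unlike the real function, it returns new parameters instead of updating them in place. The names `velocity` and `sketch_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def velocity(params, t, x_t):
    # Hypothetical toy velocity field: linear in x_t plus a time-scaled bias.
    return x_t @ params["W"] + t[:, None] * params["b"]


def squared_error(pred, target):
    # Elementwise squared error.
    return (pred - target) ** 2


def sketch_step(params, batch, lr=0.1, loss_fn=squared_error):
    """Functional sketch of one flow matching step with plain SGD."""

    def loss_of(p):
        pred = velocity(p, batch["t"], batch["x_t"])
        # Average the (possibly elementwise) loss into a scalar.
        return jnp.mean(loss_fn(pred, batch["v_t"]))

    loss, grads = jax.value_and_grad(loss_of)(params)
    # Global L2 norm of the gradient across all parameter leaves.
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # SGD update: p <- p - lr * g.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, {"loss": loss, "grad_norm": grad_norm}
```

Returning the loss computed before the update matches the usual convention of reporting the loss at the parameters used to compute the gradient.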