ott.neural.methods.flow_matching.flow_matching_step
- ott.neural.methods.flow_matching.flow_matching_step(model, optimizer, batch, *, loss_fn=<function squared_error>, model_callback_fn=None, rngs=None)
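Perform a single flow matching training step: the model's velocity prediction at (t, x_t) is compared with the target velocity v_t using loss_fn, and the optimizer updates the model's parameters.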
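Written as a formula, one step fits the velocity field v_θ (the model) to the batch's target velocities. The sketch below assumes that the per-example losses ℓ (loss_fn) are averaged over a batch of size B, and it uses update as a placeholder for whatever rule the optimizer applies; neither detail is stated on this page.

```latex
\mathcal{L}(\theta) = \frac{1}{B} \sum_{i=1}^{B}
  \ell\big(v_\theta(t_i, x_{t,i}, c_i),\, v_{t,i}\big),
\qquad
\theta \leftarrow \operatorname{update}\big(\theta,\, \nabla_\theta \mathcal{L}(\theta)\big)
```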
- Parameters:
  - model (Module) – Velocity field with the signature (t, x_t, cond, rngs=...) -> v_t.
  - optimizer (Optimizer) – Optimizer used to apply the gradient update to the model's parameters.
  - batch (Dict[Literal['t', 'x_t', 'v_t', 'cond'], Array]) – Batch containing the following elements:
    - 't' – time, array of shape [batch,].
    - 'x_t' – position, array of shape [batch, ...].
    - 'v_t' – target velocity, array of shape [batch, ...].
    - 'cond' – condition (optional), array of shape [batch, ...].
  - loss_fn (Callable[[Array, Array], Array]) – Loss function with the signature (pred, target) -> loss.
  - model_callback_fn (Optional[Callable[[Module], None]]) – Function with the signature (model) -> None, e.g., to update an EMA (exponential moving average) of the model.
  - rngs (Optional[Rngs]) – Random number generator passed to the model, e.g., for dropout.
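The batch dictionary is assembled by the caller. One common choice, assumed here and not prescribed by this page, is the straight-line path x_t = (1 - t) x0 + t x1 between paired source and target samples, whose target velocity is v_t = x1 - x0. A minimal sketch of building such a batch with the keys and shapes listed above; make_linear_path_batch is a hypothetical helper, not part of the library:

```python
import jax
import jax.numpy as jnp


def make_linear_path_batch(rng, x0, x1, cond=None):
  """Hypothetical helper: build a batch on the straight-line path.

  Assumes x_t = (1 - t) * x0 + t * x1, whose time derivative is x1 - x0.
  """
  batch_size = x0.shape[0]
  # One time per example, sampled uniformly in [0, 1); shape [batch,].
  t = jax.random.uniform(rng, (batch_size,))
  # Reshape t to [batch, 1, ..., 1] so it broadcasts over the sample axes.
  t_b = t.reshape((batch_size,) + (1,) * (x0.ndim - 1))
  batch = {
      "t": t,
      "x_t": (1.0 - t_b) * x0 + t_b * x1,  # position, shape [batch, ...]
      "v_t": x1 - x0,  # target velocity, shape [batch, ...]
  }
  if cond is not None:
    batch["cond"] = cond  # optional condition, shape [batch, ...]
  return batch


key0, key1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(key0, (8, 3))  # e.g., samples from a source distribution
x1 = jax.random.normal(key1, (8, 3))  # e.g., paired samples from a target
batch = make_linear_path_batch(key_t, x0, x1)
```

How x0 and x1 are paired (independently, or via an optimal transport coupling) is left to the caller in this sketch.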
- Return type:
- Returns:
Updates the model parameters in place and returns the loss and the gradient norm.
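To make the documented behaviour concrete (a loss on predicted vs. target velocity, a gradient update, a gradient norm, and an optional callback such as an EMA), here is a pure-JAX analogue. It is a sketch under stated assumptions, not the library's implementation: it uses a plain parameter dictionary and hand-written SGD instead of the Module and Optimizer objects the real function takes, and velocity_field, toy_flow_matching_step, and EMA are hypothetical names. Because JAX arrays are immutable, it returns new parameters rather than updating them in place.

```python
import jax
import jax.numpy as jnp


def squared_error(pred, target):
  # Per-example squared error summed over non-batch axes (an assumption; the
  # library's default loss may reduce differently).
  return jnp.sum((pred - target) ** 2, axis=tuple(range(1, pred.ndim)))


def velocity_field(params, t, x_t):
  # Hypothetical toy model: linear in x_t plus a time-scaled bias.
  return x_t @ params["w"] + t[:, None] * params["b"]


def toy_flow_matching_step(params, batch, *, lr=1e-2, loss_fn=squared_error,
                           callback_fn=None):
  """Hypothetical pure-JAX analogue of one flow matching step (plain SGD)."""

  def batch_loss(p):
    v_pred = velocity_field(p, batch["t"], batch["x_t"])
    return jnp.mean(loss_fn(v_pred, batch["v_t"]))

  loss, grads = jax.value_and_grad(batch_loss)(params)
  # Global L2 norm over all gradient leaves.
  leaves = jax.tree_util.tree_leaves(grads)
  grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
  # Gradient-descent update; returns new parameters instead of mutating.
  params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  if callback_fn is not None:
    callback_fn(params)  # plays the role of model_callback_fn
  return params, {"loss": loss, "grad_norm": grad_norm}


class EMA:
  """Hypothetical callback keeping an exponential moving average of params."""

  def __init__(self, params, decay=0.99):
    self.decay = decay
    self.params = params

  def __call__(self, params):
    self.params = jax.tree_util.tree_map(
        lambda e, p: self.decay * e + (1.0 - self.decay) * p,
        self.params, params)


k0, k1, kt = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(k0, (64, 2))
x1 = jax.random.normal(k1, (64, 2)) + 3.0
t = jax.random.uniform(kt, (64,))
batch = {"t": t, "x_t": (1 - t[:, None]) * x0 + t[:, None] * x1, "v_t": x1 - x0}

params = {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}
ema = EMA(params)
history = []
for _ in range(20):
  params, metrics = toy_flow_matching_step(params, batch, callback_fn=ema)
  history.append(float(metrics["loss"]))
```

Since the toy model is linear and the loss is quadratic, a small enough learning rate makes the recorded loss decrease from step to step.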