ott.neural.methods.flow_matching.curvature
- ott.neural.methods.flow_matching.curvature(model, x0, cond=None, *, ts, drop_last_velocity=None, loss_fn=<function squared_error>, **kwargs)
Compute the curvature [Lee et al., 2023].
Also known as straightness in [Liu et al., 2022].
- Parameters:
  - model (Module) – Velocity field with a signature (t, x_t, cond) -> v_t.
  - x0 (Array) – Initial point of shape [*dims].
  - ts (Union[int, Array, Sequence[float]]) – Time points at which velocities are computed and stored. If an int, use a linearly spaced interval [t0, t1] with ts steps.
  - drop_last_velocity (Optional[bool]) – Whether to remove the velocity at ts[-1] when computing the curvature. If None, don't include it when ts[-1] == 1.0.
  - loss_fn (Callable[[Array, Array], Array]) – Loss function with a signature (pred, target) -> loss.
  - kwargs (Any) – Keyword arguments for evaluate_velocity_field().
- Return type:
- Returns:
The curvature and the ODE solution.
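To make the quantity concrete, here is a minimal pure-Python sketch of a straightness-style measure in the spirit of [Liu et al., 2022]: integrate the velocity field, then compare each stored velocity with the overall displacement x1 - x0. It is not the library's implementation. The name curvature_sketch, the fixed-step Euler integrator, the assumption that ts spans [0, 1], and the averaging over time points are all illustrative assumptions.

```python
from typing import Callable, List, Optional, Sequence, Tuple, Union

Vector = List[float]


def squared_error(pred: Vector, target: Vector) -> float:
    """Sum of squared differences between two vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))


def curvature_sketch(
    velocity: Callable[[float, Vector, object], Vector],
    x0: Vector,
    cond: object = None,
    *,
    ts: Union[int, Sequence[float]],
    drop_last_velocity: Optional[bool] = None,
    loss_fn: Callable[[Vector, Vector], float] = squared_error,
) -> Tuple[float, List[Vector]]:
    """Hypothetical illustration; assumes `ts` spans the interval [0, 1]."""
    if isinstance(ts, int):
        # Assumption: an integer means `ts` linearly spaced points on [0, 1].
        n = ts
        ts = [i / (n - 1) for i in range(n)]
    ts = list(ts)
    if len(ts) < 2:
        raise ValueError("Need at least two time points.")

    # Explicit Euler integration, storing the velocity at every time point.
    xs = [list(x0)]
    vs = []
    for i, t in enumerate(ts):
        v = velocity(t, xs[-1], cond)
        vs.append(v)
        if i + 1 < len(ts):
            dt = ts[i + 1] - t
            xs.append([x + dt * vi for x, vi in zip(xs[-1], v)])

    # Mirror the documented default: skip the velocity at ts[-1] when it is 1.0.
    if drop_last_velocity is None:
        drop_last_velocity = ts[-1] == 1.0
    if drop_last_velocity:
        vs = vs[:-1]

    # A perfectly straight path has a velocity equal to x1 - x0 at all times.
    displacement = [a - b for a, b in zip(xs[-1], xs[0])]
    value = sum(loss_fn(v, displacement) for v in vs) / len(vs)
    return value, xs


# A constant field gives a straight path; a time-dependent one bends it.
straight, path = curvature_sketch(lambda t, x, c: [1.0, -2.0], [0.0, 0.0], ts=5)
curved, _ = curvature_sketch(lambda t, x, c: [1.0, 2.0 * t], [0.0, 0.0], ts=5)
```

In this sketch a constant velocity field yields a value of zero, and any field whose direction or magnitude changes along the trajectory yields a positive value. The real function takes a Module and Array inputs and forwards extra keyword arguments to evaluate_velocity_field(), which this sketch does not model.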