ott.neural.methods.flow_matching.curvature

ott.neural.methods.flow_matching.curvature(model, x0, cond=None, *, ts, drop_last_velocity=None, loss_fn=<function squared_error>, **kwargs)

Compute the curvature [Lee et al., 2023] of the ODE trajectory that starts at x0 and follows the velocity field model.

Also known as straightness in [Liu et al., 2022].
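In both references the quantity measures how far the velocity along a trajectory deviates from the straight-line displacement x1 - x0 between its endpoints. A perfectly straight path has zero curvature. The NumPy sketch below illustrates a discretized version under stated assumptions. It is not the OTT implementation: it uses explicit Euler steps instead of an ODE solver, omits cond, returns the trajectory array as a stand-in for the ODE Solution, and averages the per-time-point losses, which approximates the time integral only on a uniform grid. The names euler_curvature and squared_error here are hypothetical.

```python
from typing import Callable, Tuple

import numpy as np


def squared_error(pred: np.ndarray, target: np.ndarray) -> float:
    # Hypothetical stand-in loss: squared Euclidean distance, summed over all dims.
    return float(np.sum((pred - target) ** 2))


def euler_curvature(
    velocity: Callable[[float, np.ndarray], np.ndarray],
    x0: np.ndarray,
    ts: np.ndarray,
    loss_fn: Callable[[np.ndarray, np.ndarray], float] = squared_error,
    drop_last_velocity: bool = True,
) -> Tuple[float, np.ndarray]:
    """Hypothetical sketch: discretized curvature of an explicit-Euler trajectory."""
    xs = [np.asarray(x0, dtype=float)]
    vs = []
    # Integrate with explicit Euler, storing the velocity at each time point used.
    for t, t_next in zip(ts[:-1], ts[1:]):
        v = velocity(t, xs[-1])
        vs.append(v)
        xs.append(xs[-1] + (t_next - t) * v)
    # Optionally also evaluate the velocity at the final time point ts[-1].
    if not drop_last_velocity:
        vs.append(velocity(ts[-1], xs[-1]))
    # Straight-line displacement between the trajectory's endpoints.
    displacement = xs[-1] - xs[0]
    # Average the deviation of each stored velocity from that displacement.
    curvature = float(np.mean([loss_fn(v, displacement) for v in vs]))
    return curvature, np.stack(xs)
```

For a constant field such as `lambda t, x: np.ones_like(x)` on `np.linspace(0.0, 1.0, 5)`, the sketch returns a curvature of 0.0, since every velocity equals the displacement. A time-dependent field such as `lambda t, x: t * np.ones_like(x)` yields a positive value.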

Parameters:
  • model (Module) – Velocity field with a signature (t, x_t, cond) -> v_t.

  • x0 (Array) – Initial point of shape [*dims].

  • cond (Optional[Array]) – Condition of shape [*cond_dims].

  • ts (Union[int, Array, Sequence[float]]) – Time points at which velocities are computed and stored. If an int, use a linearly spaced grid over the interval [t0, t1] with ts steps.

  • drop_last_velocity (Optional[bool]) – Whether to exclude the velocity at ts[-1] when computing the curvature. If None, it is excluded when ts[-1] == 1.0 and included otherwise.

  • loss_fn (Callable[[Array, Array], Array]) – Loss function with a signature (pred, target) -> loss.

  • kwargs (Any) – Keyword arguments for evaluate_velocity_field().
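The rules for ts and drop_last_velocity can be made concrete with two small helpers. This is a hedged sketch, not library code: resolve_ts and resolve_drop_last are hypothetical names, the defaults t0=0.0 and t1=1.0 are assumptions (this page does not say where t0 and t1 come from), and an int ts is treated here as the number of grid points.

```python
from typing import Optional, Sequence, Union

import numpy as np


def resolve_ts(
    ts: Union[int, Sequence[float], np.ndarray], t0: float = 0.0, t1: float = 1.0
) -> np.ndarray:
    # Assumption: an int is expanded into that many linearly spaced points on [t0, t1].
    if isinstance(ts, int):
        return np.linspace(t0, t1, ts)
    # Otherwise use the given time points as-is.
    return np.asarray(ts, dtype=float)


def resolve_drop_last(ts: np.ndarray, drop_last_velocity: Optional[bool]) -> bool:
    # None: drop the final velocity only when the grid ends exactly at t = 1.0.
    if drop_last_velocity is None:
        return bool(ts[-1] == 1.0)
    # An explicit True/False always wins.
    return drop_last_velocity
```

With these helpers, `resolve_drop_last(resolve_ts(5), None)` is `True` because `np.linspace` places the endpoint exactly at 1.0, while a grid ending at 0.9 keeps its last velocity.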

Return type:

Tuple[Array, Solution]

Returns:

The curvature and the ODE solution.