ott.neural.methods.flow_matching.interpolate_samples

ott.neural.methods.flow_matching.interpolate_samples(rng, x0, x1, cond=None, *, time_sampler=None)

Sample a time \(t\) for each source/target pair and interpolate between \(x_0\) and \(x_1\) at that time.

Parameters:
  • rng (Array) – Random number generator key used to sample the time.

  • x0 (Array) – Source samples at \(t_0\), array of shape [batch, ...].

  • x1 (Array) – Target samples at \(t_1\), array of shape [batch, ...].

  • cond (Optional[Array]) – Optional condition, array of shape [batch, ...]; returned under the 'cond' key.

  • time_sampler (Optional[Callable[[Array, Tuple[int], dtype], Array]]) – Time sampler with signature (rng, shape, dtype) -> time.
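
As an illustration of the (rng, shape, dtype) -> time signature, a custom sampler might look like the sketch below. The name logit_normal_time_sampler and the choice of a logit-normal distribution are assumptions made for this example; they are not part of the library.

```python
import jax
import jax.numpy as jnp


def logit_normal_time_sampler(rng, shape, dtype):
    """Hypothetical sampler: squash standard-normal draws into (0, 1)."""
    z = jax.random.normal(rng, shape, dtype)
    # The sigmoid maps the real line to the unit interval, so the
    # sampled times concentrate around t = 0.5.
    return jax.nn.sigmoid(z)


rng = jax.random.PRNGKey(0)
t = logit_normal_time_sampler(rng, (16,), jnp.float32)
```

A callable of this form could presumably be passed as time_sampler=logit_normal_time_sampler; how the library samples times when time_sampler is None is not stated here.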

Returns:

  A dictionary with the following keys:

  • 't' - time, array of shape [batch,].

  • 'x_t' - position \(x_t\), array of shape [batch, ...].

  • 'v_t' - target velocity \(x_1 - x_0\), array of shape [batch, ...].

  • 'cond' - condition (optional), array of shape [batch, ...].

Return type:

Dict[Literal['t', 'x_t', 'v_t', 'cond'], Array]
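
The documented inputs and outputs can be sketched as follows. This is a hypothetical re-implementation written for illustration, not the library source: the function name interpolate_samples_sketch, the uniform default time sampler, the linear path \(x_t = (1 - t)\,x_0 + t\,x_1\) (chosen to match the constant velocity \(x_1 - x_0\)), and the omission of 'cond' when no condition is given are all assumptions.

```python
import jax
import jax.numpy as jnp


def interpolate_samples_sketch(rng, x0, x1, cond=None, *, time_sampler=None):
    """Illustrative stand-in for the documented contract (not the OTT code)."""
    if time_sampler is None:
        # Assumption: fall back to uniform times in [0, 1).
        time_sampler = jax.random.uniform
    n = x0.shape[0]
    t = time_sampler(rng, (n,), x0.dtype)  # shape [batch]
    # Reshape t to [batch, 1, ..., 1] so it broadcasts against [batch, ...].
    t_b = t.reshape((n,) + (1,) * (x0.ndim - 1))
    x_t = (1.0 - t_b) * x0 + t_b * x1  # assumed linear interpolation
    v_t = x1 - x0  # target velocity, constant along the linear path
    out = {"t": t, "x_t": x_t, "v_t": v_t}
    if cond is not None:
        # Assumption: the condition is passed through unchanged.
        out["cond"] = cond
    return out


rng = jax.random.PRNGKey(0)
rng_x0, rng_x1, rng_t = jax.random.split(rng, 3)
x0 = jax.random.normal(rng_x0, (8, 2))        # source samples
x1 = jax.random.normal(rng_x1, (8, 2)) + 3.0  # shifted target samples
sample = interpolate_samples_sketch(rng_t, x0, x1)
```

Under these assumptions, sample['t'] has shape (8,), while sample['x_t'] and sample['v_t'] have shape (8, 2), matching the shapes listed under Returns.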