ott.neural.methods.flows.otfm.OTFlowMatching



class ott.neural.methods.flows.otfm.OTFlowMatching(vf, flow, match_fn=None, time_sampler=uniform_sampler, **kwargs)

(Optimal transport) flow matching [Lipman et al., 2022].

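Also supports the OT-FM extension [Pooladian et al., 2023, Tong et al., 2023], in which source and target samples are paired by an optimal transport matching (see match_fn) instead of being drawn independently.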
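A minimal pure-JAX sketch of the flow matching objective, assuming a straight-line path x_t = (1 - t) x0 + t x1 whose target velocity is x1 - x0. In the class, the path is determined by the flow argument and the network by vf; the linear velocity_field below and the name flow_matching_loss are hypothetical stand-ins, not the library's API. With OT-FM, x0 and x1 would come from a matching rather than independent draws.

```python
import jax
import jax.numpy as jnp


def velocity_field(params, t, x):
    """Hypothetical toy velocity field v(t, x) = x W + b + t c (stand-in for vf)."""
    return x @ params["W"] + params["b"] + t * params["c"]


def flow_matching_loss(params, rng, x0, x1):
    """Flow matching loss assuming the straight-line path x_t = (1 - t) x0 + t x1."""
    t = jax.random.uniform(rng, (x0.shape[0], 1))  # one time per pair, in [0, 1)
    x_t = (1.0 - t) * x0 + t * x1                  # point on the path at time t
    target = x1 - x0                               # d x_t / dt along that path
    pred = velocity_field(params, t, x_t)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


dim = 2
params = {"W": jnp.zeros((dim, dim)), "b": jnp.zeros(dim), "c": jnp.zeros(dim)}
k_src, k_tgt, rng_t = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(k_src, (128, dim))        # source samples
x1 = jax.random.normal(k_tgt, (128, dim)) + 3.0  # target samples (independent pairing)
loss, grads = jax.value_and_grad(flow_matching_loss)(params, rng_t, x0, x1)
```

Regressing onto the path's time derivative at random times is what lets training avoid simulating the ODE; only the learned field is integrated, and only at inference.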

Parameters:
  • vf (VelocityField) – Vector field parameterized by a neural network.

  • flow (BaseFlow) – Flow between the source and the target distributions.

  • match_fn (Optional[Callable[[Array, Array], Array]]) – Function to match samples from the source and the target distributions. It has a (src, tgt) -> matching signature.

  • time_sampler (Callable[[Array, int], Array]) – Time sampler with a (rng, n_samples) -> time signature.

  • kwargs (Any) – Keyword arguments for create_train_state().
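The two callable parameters can be illustrated with a hedged pure-JAX sketch. sinkhorn_match_fn follows the (src, tgt) -> matching signature by returning an entropic OT coupling, and uniform_time_sampler follows the (rng, n_samples) -> time signature. Both names, the (n_samples, 1) output shape, the log-domain Sinkhorn solver, and the hypothetical sample_pairs helper (which draws minibatch pairs from the coupling) are assumptions for illustration, not the library's own implementations.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_match_fn(src, tgt, epsilon=1.0, num_iters=300):
    """Hypothetical match_fn: (src, tgt) -> entropic OT coupling of shape (n, m)."""
    n, m = src.shape[0], tgt.shape[0]
    # Squared Euclidean cost between every source and target sample.
    cost = jnp.sum((src[:, None, :] - tgt[None, :, :]) ** 2, axis=-1)
    log_a = jnp.full(n, -jnp.log(n))  # uniform source weights (log)
    log_b = jnp.full(m, -jnp.log(m))  # uniform target weights (log)

    def sinkhorn_step(_, potentials):
        f, g = potentials
        # Log-domain Sinkhorn updates: fit row marginals, then column marginals.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros(n), jnp.zeros(m))
    f, g = jax.lax.fori_loop(0, num_iters, sinkhorn_step, init)
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def uniform_time_sampler(rng, n_samples):
    """Hypothetical time_sampler: (rng, n_samples) -> times of shape (n_samples, 1)."""
    return jax.random.uniform(rng, (n_samples, 1))


def sample_pairs(rng, coupling, n_pairs):
    """Draw (i, j) index pairs with probability proportional to the coupling."""
    m = coupling.shape[1]
    flat = jax.random.categorical(rng, jnp.log(coupling).ravel(), shape=(n_pairs,))
    return flat // m, flat % m


k_src, k_tgt, k_pair, k_time = jax.random.split(jax.random.PRNGKey(0), 4)
src = jax.random.normal(k_src, (64, 2))
tgt = jax.random.normal(k_tgt, (64, 2)) + 3.0
coupling = sinkhorn_match_fn(src, tgt)
i, j = sample_pairs(k_pair, coupling, n_pairs=64)
x0, x1 = src[i], tgt[j]               # OT-paired minibatch
t = uniform_time_sampler(k_time, 64)  # one time per pair
```

Sampling pairs from the coupling, rather than keeping the whole matrix, yields an ordinary minibatch of (x0, x1) pairs that a flow matching loss can consume unchanged.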

Methods

transport(x[, condition, t0, t1])

Transport data with the learned map.
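Transporting data amounts to integrating the learned ODE dx/dt = v(t, x) from t0 to t1. The sketch below is a hedged illustration: the hypothetical euler_transport uses fixed-step Euler integration as a stand-in for whatever ODE solver the method actually uses, ignores the condition argument, and assumes the velocity field is a plain (t, x) -> velocity callable. In this sketch, swapping t0 and t1 integrates backward.

```python
import jax
import jax.numpy as jnp


def euler_transport(velocity_fn, x, t0=0.0, t1=1.0, num_steps=100):
    """Push x along dx/dt = velocity_fn(t, x) from time t0 to t1 (fixed-step Euler)."""
    dt = (t1 - t0) / num_steps  # negative when t1 < t0, i.e. integrating backward

    def step(i, x_t):
        t = t0 + i * dt
        return x_t + dt * velocity_fn(t, x_t)

    return jax.lax.fori_loop(0, num_steps, step, x)


def shift_field(t, x):
    """Hypothetical toy velocity field: constant drift of +3 in every coordinate."""
    return jnp.full_like(x, 3.0)


x = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
y = euler_transport(shift_field, x)                       # forward: t0=0 -> t1=1
x_back = euler_transport(shift_field, y, t0=1.0, t1=0.0)  # backward: t0=1 -> t1=0
```

A fixed step count keeps the sketch jit-friendly; an adaptive solver would trade that simplicity for accuracy on stiffer learned fields.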