ott.neural.methods.flows.otfm.OTFlowMatching
class ott.neural.methods.flows.otfm.OTFlowMatching(vf, flow, match_fn=None, time_sampler=<function uniform_sampler>, **kwargs)
Flow matching [Lipman et al., 2022], including its optimal transport extension, OT-FM [Pooladian et al., 2023, Tong et al., 2023].
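The training objective can be made concrete with a minimal NumPy sketch of the flow-matching loss. It assumes a noise-free straight-line path x_t = (1 - t) x0 + t x1 between paired samples. This is one common choice; the class itself receives its path through the `flow` argument. It also assumes a hypothetical `velocity_fn(t, x)` that stands in for the neural vector field. In OT-FM the pairs (x0, x1) would come from an OT matching of the batch instead of being given up front. This is an illustration, not the library's implementation.

```python
import numpy as np


def straight_path(x0, x1, t):
    """Point at time t on the straight line from x0 (t=0) to x1 (t=1)."""
    return (1.0 - t) * x0 + t * x1


def straight_path_velocity(x0, x1):
    """Time derivative of the straight-line path: the constant x1 - x0."""
    return x1 - x0


def flow_matching_loss(velocity_fn, x0, x1, t):
    """Mean squared error between predicted and target velocities.

    velocity_fn: hypothetical stand-in for the neural vector field, (t, x) -> v.
    x0, x1: paired source/target batches of shape (n, d).
    t: times of shape (n, 1) in [0, 1].
    """
    xt = straight_path(x0, x1, t)
    target = straight_path_velocity(x0, x1)
    pred = velocity_fn(t, xt)
    return np.mean(np.sum((pred - target) ** 2, axis=-1))


shift = np.array([3.0, -1.0])
rng = np.random.default_rng(0)
x0 = rng.normal(size=(128, 2))   # source samples
x1 = x0 + shift                  # target: a fixed translation of the source
t = rng.uniform(size=(128, 1))   # one time per pair


def exact_vf(t, x):
    # Happens to be the true velocity for this particular pairing.
    return np.broadcast_to(shift, x.shape)


def zero_vf(t, x):
    # A field that does not move anything.
    return np.zeros_like(x)


loss_exact = flow_matching_loss(exact_vf, x0, x1, t)
loss_zero = flow_matching_loss(zero_vf, x0, x1, t)
```

A perfectly fitted field drives the loss to (numerically) zero, while the zero field pays the full squared displacement of 3² + 1² = 10 per sample.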
Parameters:
- vf (VelocityField) – Vector field parameterized by a neural network.
- flow (BaseFlow) – Flow between the source and the target distributions.
- match_fn (Optional[Callable[[Array, Array], Array]]) – Function to match samples from the source and the target distributions. It has a (src, tgt) -> matching signature.
- time_sampler (Callable[[Array, int], Array]) – Time sampler with a (rng, n_samples) -> time signature.
- kwargs (Any) – Keyword arguments for create_train_state().
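The sketch below shows user-defined callables that follow the documented `(rng, n_samples) -> time` and `(src, tgt) -> matching` signatures. It rests on several assumptions. A NumPy `Generator` stands in for a JAX PRNG key. The "matching" is taken to be an (n, m) coupling matrix from which index pairs are then sampled. The helper names and the extra `epsilon`/`n_iters` keyword arguments are hypothetical. The matcher is plain entropic OT (Sinkhorn) with uniform weights, not ott's own solver.

```python
import numpy as np


def uniform_time_sampler(rng, n_samples):
    """(rng, n_samples) -> time: times drawn uniformly from [0, 1), shape (n, 1).

    Assumption: rng is a NumPy Generator standing in for a JAX PRNG key.
    """
    return rng.uniform(0.0, 1.0, size=(n_samples, 1))


def sinkhorn_match(src, tgt, epsilon=0.1, n_iters=500):
    """(src, tgt) -> matching: entropic OT coupling with uniform weights.

    src: (n, d), tgt: (m, d). Returns an (n, m) matrix whose columns sum to
    1/m exactly and whose rows sum to 1/n up to Sinkhorn convergence.
    epsilon and n_iters are hypothetical tuning knobs of this sketch.
    """
    n, m = src.shape[0], tgt.shape[0]
    a = np.full(n, 1.0 / n)
    b = np.full(m, 1.0 / m)
    # Squared Euclidean cost, rescaled to [0, 1] so epsilon has a stable meaning.
    cost = np.sum((src[:, None, :] - tgt[None, :, :]) ** 2, axis=-1)
    cost = cost / np.maximum(cost.max(), 1e-12)
    kernel = np.exp(-cost / epsilon)
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        u = a / (kernel @ v)      # fit row marginals
        v = b / (kernel.T @ u)    # fit column marginals
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
src = rng.normal(size=(8, 2))
tgt = rng.normal(size=(8, 2)) + 5.0
t = uniform_time_sampler(rng, 8)
coupling = sinkhorn_match(src, tgt)

# Draw (i, j) index pairs with probability proportional to the coupling,
# giving matched source/target batches for the flow-matching loss.
flat = rng.choice(coupling.size, size=8, p=coupling.ravel() / coupling.sum())
i, j = np.unravel_index(flat, coupling.shape)
x0_matched, x1_matched = src[i], tgt[j]
```

Sampling pairs from the coupling, instead of pairing points independently, is what gives OT-FM its straighter, lower-variance regression targets.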
Methods
- transport(x[, condition, t0, t1]) – Transport data with the learned map.
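Conceptually, transporting data with a learned flow-matching model means solving the ODE dx/dt = v(t, x) from `t0` to `t1`. The minimal sketch below uses fixed-step explicit Euler as a simple stand-in. It is not the library's implementation, which may use a different, possibly adaptive, ODE solver. `transport_euler` and `velocity_fn` are hypothetical names, and the `condition` argument is omitted.

```python
import numpy as np


def transport_euler(velocity_fn, x, t0=0.0, t1=1.0, n_steps=100):
    """Push x from time t0 to t1 along dx/dt = velocity_fn(t, x).

    velocity_fn: hypothetical stand-in for the trained vector field, (t, x) -> v.
    Choosing t0 > t1 integrates backwards, approximating the inverse map.
    """
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        x = x + dt * velocity_fn(t, x)   # explicit Euler step
        t = t + dt
    return x


shift = np.array([3.0, -1.0])


def constant_vf(t, x):
    # A field whose flow map is a pure translation by `shift` over [0, 1].
    return np.broadcast_to(shift, x.shape)


x = np.zeros((4, 2))
y = transport_euler(constant_vf, x)                        # forward: about x + shift
x_back = transport_euler(constant_vf, y, t0=1.0, t1=0.0)   # backward: about x
```

For a constant field, Euler integration is exact up to rounding. For a learned, nonlinear field, accuracy depends on the step count, which is one reason practical implementations tend to prefer higher-order or adaptive solvers.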