ott.neural.methods.flows.genot.GENOT
- class ott.neural.methods.flows.genot.GENOT(vf, flow, data_match_fn, *, source_dim, target_dim, condition_dim=None, time_sampler=<function uniform_sampler>, latent_noise_fn=None, latent_match_fn=None, n_samples_per_src=1, **kwargs)
Generative Entropic Neural Optimal Transport [Klein et al., 2023].
GENOT is a framework for learning neural optimal transport plans between two distributions. It can learn linear couplings as well as quadratic Gromov-Wasserstein and fused Gromov-Wasserstein couplings, in both the balanced and the unbalanced settings.
- Parameters:
  - vf (VelocityField) – Vector field parameterized by a neural network.
  - flow (BaseFlow) – Flow between the latent and the target distributions.
  - data_match_fn (Union[Callable[[Tuple[Array, Array]], Array], Callable[[Tuple[Array, Array, Optional[Array], Optional[Array]]], Array]]) – Function to match samples from the source and the target distributions. Depending on the data passed in __call__(), it has one of the following signatures:
    - (src_lin, tgt_lin) -> matching: linear matching.
    - (src_quad, tgt_quad, src_lin, tgt_lin) -> matching: quadratic (fused) GW matching. In the pure GW setting, both src_lin and tgt_lin are set to None.
  - source_dim (int) – Dimensionality of the source distribution.
  - target_dim (int) – Dimensionality of the target distribution.
  - condition_dim (Optional[int]) – Dimension of the conditions. If None, the underlying velocity field has no conditions.
  - time_sampler (Callable[[Array, int], Array]) – Time sampler with a (rng, n_samples) -> time signature.
  - latent_noise_fn (Optional[Callable[[Array, Tuple[int, ...]], Array]]) – Function to sample from the latent distribution in the target space with a (rng, shape) -> noise signature. If None, a multivariate normal distribution is used.
  - latent_match_fn (Optional[Callable[[Array, Array], Array]]) – Function to match samples from the latent distribution with samples from the conditional distribution, with a (latent, samples) -> matching signature. If None, no matching is performed.
  - n_samples_per_src (int) – Number of samples drawn from the conditional distribution per source sample.
  - kwargs (Any) – Keyword arguments for create_train_state().
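To illustrate the (src_lin, tgt_lin) -> matching signature of data_match_fn, here is a minimal NumPy sketch of a linear matching function. It assumes the matching is returned as an (n, m) coupling matrix between uniformly weighted source and target batches under squared Euclidean cost, computed with a log-domain Sinkhorn loop. The name sinkhorn_data_match and its epsilon and n_iters defaults are hypothetical and not part of ott, and NumPy arrays stand in for JAX arrays.

```python
import numpy as np


def _logsumexp(x, axis):
    """Numerically stable log-sum-exp along `axis`."""
    x_max = x.max(axis=axis, keepdims=True)
    out = x_max + np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sinkhorn_data_match(src_lin, tgt_lin, epsilon=1.0, n_iters=500):
    """Hypothetical linear `data_match_fn` returning an entropic OT coupling.

    Assumption: the matching is an (n, m) coupling matrix whose marginals are
    uniform weights over the source and target batches.
    """
    n, m = src_lin.shape[0], tgt_lin.shape[0]
    log_a = np.full(n, -np.log(n))  # log of uniform source weights
    log_b = np.full(m, -np.log(m))  # log of uniform target weights
    # Pairwise squared Euclidean cost, shape (n, m).
    cost = ((src_lin[:, None, :] - tgt_lin[None, :, :]) ** 2).sum(axis=-1)
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iters):
        # Alternate updates so rows, then columns, match their marginals.
        f = epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    # Primal coupling recovered from the dual potentials.
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)
```

A quadratic (fused) GW matching function would follow the four-argument signature instead and is not covered by this sketch.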
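The time_sampler and latent_noise_fn parameters are plain sampler callables with (rng, n_samples) -> time and (rng, shape) -> noise signatures. The sketch below shows two functions with those shapes. It is an assumption-laden stand-in: in GENOT the rng argument is a JAX PRNG key (where jax.random.uniform and jax.random.normal would be the natural building blocks), whereas here a NumPy Generator is used, and both function names are hypothetical. Returning times as an (n_samples, 1) column is also an assumption, chosen so the times broadcast against (n_samples, dim) batches.

```python
import numpy as np


def uniform_time_sampler(rng, n_samples):
    """Hypothetical `time_sampler`: n_samples times drawn uniformly from [0, 1).

    Returned as a column of shape (n_samples, 1) (an assumed layout).
    """
    return rng.uniform(0.0, 1.0, size=(n_samples, 1))


def gaussian_latent_noise(rng, shape):
    """Hypothetical `latent_noise_fn`: standard normal noise of the given shape."""
    return rng.standard_normal(size=shape)
```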
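For latent_match_fn, the (latent, samples) -> matching signature pairs latent noise with samples drawn from the conditional distribution. A minimal sketch, assuming the matching is an (n, n) permutation matrix that minimizes total squared Euclidean cost, is an exact assignment by brute force over permutations. The name exact_latent_match is hypothetical, and because it enumerates all n! permutations it is only practical for very small batches.

```python
import itertools

import numpy as np


def exact_latent_match(latent, samples):
    """Hypothetical `latent_match_fn`: exact assignment for tiny batches.

    Returns an (n, n) permutation matrix P with P[i, j] = 1 when latent[i]
    is paired with samples[j], minimizing total squared Euclidean cost.
    """
    n = latent.shape[0]
    # Pairwise squared Euclidean cost, shape (n, n).
    cost = ((latent[:, None, :] - samples[None, :, :]) ** 2).sum(axis=-1)
    rows = np.arange(n)
    # Brute-force search over all n! assignments (small n only).
    best = min(itertools.permutations(range(n)),
               key=lambda perm: cost[rows, list(perm)].sum())
    matching = np.zeros((n, n))
    matching[rows, list(best)] = 1.0
    return matching
```

For larger batches a polynomial-time assignment solver or an entropic coupling would replace the brute-force search.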
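One way to read n_samples_per_src: each row of a source-target coupling, renormalized, is a conditional distribution over target samples, and k target indices are drawn from it per source sample. The sketch below illustrates that idea under stated assumptions. It is not ott's internal sampling routine, the name sample_conditional_indices is hypothetical, and a NumPy Generator stands in for a JAX PRNG key.

```python
import numpy as np


def sample_conditional_indices(rng, coupling, k):
    """Draw k target indices per source row from the conditional of a coupling.

    Row i of `coupling`, renormalized to sum to one, is treated as the
    conditional distribution of target samples given source sample i.
    Returns (src_ixs, tgt_ixs), each of shape (n, k).
    """
    n, m = coupling.shape
    conditionals = coupling / coupling.sum(axis=1, keepdims=True)
    # One categorical draw of size k per source sample.
    tgt_ixs = np.stack([rng.choice(m, size=k, p=row) for row in conditionals])
    # Each source index is repeated k times to line up with its targets.
    src_ixs = np.repeat(np.arange(n)[:, None], k, axis=1)
    return src_ixs, tgt_ixs
```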
Methods
- transport(source[, condition, t0, t1, rng]) – Transport data with the learned plan.
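Conceptually, transporting with a learned flow means integrating an ODE dx/dt = v(t, x, source) from t0 to t1, starting from latent noise and conditioning on the source batch. The sketch below shows that idea with a fixed-step Euler scheme chosen for simplicity; the library's actual ODE solver and argument handling may differ. velocity_fn is a hypothetical stand-in for the trained velocity field, and euler_transport is not an ott function.

```python
import numpy as np


def euler_transport(velocity_fn, source, latent, t0=0.0, t1=1.0, n_steps=100):
    """Integrate dx/dt = velocity_fn(t, x, source) from t0 to t1 with Euler steps.

    `latent` is the starting point (noise in the target space) and `source`
    is the conditioning source batch.
    """
    x = np.array(latent, dtype=float)
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        # Explicit Euler update using the velocity at the current time.
        x = x + dt * velocity_fn(t, x, source)
        t += dt
    return x
```

With a velocity field that simply returns the source, for example, the trajectory moves from zero noise to the source points over the unit time interval.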