ott.neural.methods.flows.genot.GENOT

class ott.neural.methods.flows.genot.GENOT(vf, flow, data_match_fn, *, source_dim, target_dim, condition_dim=None, time_sampler=<function uniform_sampler>, latent_noise_fn=None, latent_match_fn=None, n_samples_per_src=1, **kwargs)

Generative Entropic Neural Optimal Transport [Klein et al., 2023].

GENOT is a framework for learning neural optimal transport plans between two distributions. It can learn linear couplings as well as quadratic couplings (Gromov-Wasserstein and fused Gromov-Wasserstein), in both the balanced and the unbalanced settings.

Parameters:
  • vf (VelocityField) – Vector field parameterized by a neural network.

  • flow (BaseFlow) – Flow between the latent and the target distributions.

  • data_match_fn (Union[Callable[[Tuple[Array, Array]], Array], Callable[[Tuple[Array, Array, Optional[Array], Optional[Array]]], Array]]) –

    Function to match samples from the source and the target distributions. Depending on the data passed to __call__(), it has one of the following signatures:

    • (src_lin, tgt_lin) -> matching - linear matching.

    • (src_quad, tgt_quad, src_lin, tgt_lin) -> matching - quadratic (fused) GW matching. In the pure GW setting, both src_lin and tgt_lin will be set to None.

  • source_dim (int) – Dimensionality of the source distribution.

  • target_dim (int) – Dimensionality of the target distribution.

  • condition_dim (Optional[int]) – Dimensionality of the conditions. If None, the underlying velocity field has no conditions.

  • time_sampler (Callable[[Array, int], Array]) – Time sampler with a (rng, n_samples) -> time signature.

  • latent_noise_fn (Optional[Callable[[Array, Tuple[int, ...]], Array]]) – Function to sample from the latent distribution in the target space with a (rng, shape) -> noise signature. If None, a multivariate normal distribution is used.

  • latent_match_fn (Optional[Callable[[Array, Array], Array]]) – Function to match samples from the latent distribution with samples from the conditional distribution, with a (latent, samples) -> matching signature. If None, no matching is performed.

  • n_samples_per_src (int) – Number of samples drawn from the conditional distribution per source sample.

  • kwargs (Any) – Keyword arguments for create_train_state().
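
The matching callables above are described only by their signatures. As a hedged illustration, the sketch below assumes that a linear data_match_fn returns an (n, m) coupling matrix between a source batch and a target batch (an assumption; check what your version expects). The hypothetical sinkhorn_match_linear solves entropic OT with a squared Euclidean cost and uniform weights using log-domain Sinkhorn iterations. The hypothetical sample_conditional_indices shows one way to draw n_samples_per_src target indices per source sample from such a coupling, by treating each row as a conditional distribution. Neither function is part of the library's API.

```python
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_match_linear(
    src_lin: jax.Array,
    tgt_lin: jax.Array,
    epsilon: float = 0.1,
    n_iters: int = 200,
) -> jax.Array:
  """Hypothetical linear data_match_fn: (src_lin, tgt_lin) -> coupling matrix."""
  n, m = src_lin.shape[0], tgt_lin.shape[0]
  # Pairwise squared Euclidean cost of shape (n, m), rescaled by its mean.
  cost = jnp.sum((src_lin[:, None, :] - tgt_lin[None, :, :]) ** 2, axis=-1)
  cost = cost / jnp.mean(cost)
  log_a = jnp.full(n, -jnp.log(n))  # uniform weights on the source batch
  log_b = jnp.full(m, -jnp.log(m))  # uniform weights on the target batch

  def sinkhorn_step(_, fg):
    f, g = fg
    # Log-domain updates: f fits the row marginals, then g the column marginals.
    f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
    g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return f, g

  f, g = jax.lax.fori_loop(
      0, n_iters, sinkhorn_step, (jnp.zeros(n), jnp.zeros(m))
  )
  # Entropic coupling P_ij = exp((f_i + g_j - C_ij) / epsilon).
  return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def sample_conditional_indices(
    rng: jax.Array, tmat: jax.Array, k: int
) -> Tuple[jax.Array, jax.Array]:
  """Hypothetical helper: draw k target indices per source row of a coupling."""
  n = tmat.shape[0]
  # Row i, used as logits, is the conditional pi(. | x_i) up to normalization;
  # zero entries become -inf and are never drawn.
  logits = jnp.log(tmat)[:, None, :]  # shape (n, 1, m)
  tgt_ixs = jax.random.categorical(rng, logits, axis=-1, shape=(n, k))
  src_ixs = jnp.repeat(jnp.arange(n)[:, None], k, axis=1)  # shape (n, k)
  return src_ixs, tgt_ixs
```

For example, functools.partial(sinkhorn_match_linear, epsilon=0.05) still has the (src_lin, tgt_lin) -> matching signature. A smaller epsilon gives a sharper coupling that is closer to a deterministic matching, at the cost of slower Sinkhorn convergence.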
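
The time_sampler and latent_noise_fn arguments accept any callables with the documented signatures. The sketch below shows two hypothetical examples. stratified_time_sampler draws one uniform offset and spreads the batch of times evenly over [0, 1). scaled_normal_noise is an isotropic Gaussian whose dimensionality is bound with functools.partial. The (n_samples, 1) output shape of the time sampler and the (*shape, target_dim) output shape of the noise function are assumptions, not documented behavior.

```python
import functools
from typing import Tuple

import jax
import jax.numpy as jnp


def stratified_time_sampler(rng: jax.Array, n_samples: int) -> jax.Array:
  """Hypothetical time_sampler with the (rng, n_samples) -> time signature."""
  # One shared uniform offset u; t_k = (u + k / n) mod 1 spreads the batch
  # evenly over [0, 1). The (n_samples, 1) column shape is an assumption.
  u = jax.random.uniform(rng, (1, 1))
  k = jnp.arange(n_samples, dtype=jnp.float32)[:, None]
  return (u + k / n_samples) % 1.0


def scaled_normal_noise(
    rng: jax.Array, shape: Tuple[int, ...], dim: int, scale: float = 1.0
) -> jax.Array:
  """Hypothetical latent noise: isotropic Gaussian N(0, scale^2 I) in R^dim."""
  # Assumes `shape` is the batch shape; the output has shape (*shape, dim).
  return scale * jax.random.normal(rng, (*shape, dim))


# Binding `dim` (the target dimensionality) yields the (rng, shape) -> noise
# signature expected by `latent_noise_fn`.
latent_noise_fn = functools.partial(scaled_normal_noise, dim=3, scale=0.5)
```

Each stratified time t_k is still marginally uniform on [0, 1), but the batch covers the interval more evenly than independent draws, which typically lowers the variance of the per-batch loss estimate.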

Methods

transport(source[, condition, t0, t1, rng])

Transport data with the learned plan.
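
Conceptually, transporting a source sample means drawing a latent noise point and integrating the learned, source-conditioned velocity field from t0 to t1. The sketch below illustrates only this idea and is not the library's implementation. euler_transport and velocity_fn are hypothetical stand-ins for the method and the trained vf, the latent distribution is assumed to be standard normal, and a fixed-step Euler scheme replaces whatever ODE solver the library actually uses.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def euler_transport(
    velocity_fn: Callable[[jax.Array, jax.Array, jax.Array], jax.Array],
    source: jax.Array,
    rng: jax.Array,
    target_dim: int,
    t0: float = 0.0,
    t1: float = 1.0,
    n_steps: int = 100,
) -> jax.Array:
  """Hypothetical sketch: push latent noise to the target space with Euler steps."""
  # One latent point per source sample, drawn from a standard normal.
  x = jax.random.normal(rng, (source.shape[0], target_dim))
  dt = (t1 - t0) / n_steps

  def euler_step(i, x):
    t = t0 + i * dt
    # The velocity is conditioned on the source sample each point belongs to.
    return x + dt * velocity_fn(t, x, source)

  return jax.lax.fori_loop(0, n_steps, euler_step, x)
```

Because the starting point is random, calling the sketch with different rng values yields different target samples for the same source point, i.e. draws from a conditional distribution rather than a single deterministic image.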