ott.problems.nn.dataset.create_gaussian_mixture_samplers

ott.problems.nn.dataset.create_gaussian_mixture_samplers(name_source, name_target, train_batch_size=2048, valid_batch_size=2048, rng=None)

Create Gaussian mixture samplers for training and validating W2NeuralDual.

Parameters:
  • name_source (Literal['simple', 'circle', 'square_five', 'square_four']) – name of the source sampler

  • name_target (Literal['simple', 'circle', 'square_five', 'square_four']) – name of the target sampler

  • train_batch_size (int) – the training batch size

  • valid_batch_size (int) – the validation batch size

  • rng (Optional[PRNGKeyArray]) – initial PRNG key

Return type:

Tuple[Dataset, Dataset, int]

Returns:

The training dataset, the validation dataset, and the dimension of the data.
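The sketch below illustrates the shape of what this function returns: a training and a validation dataset, each pairing a source and a target sampler, plus the data dimension. It is a minimal stand-alone JAX approximation, not the library's implementation. The `Dataset` NamedTuple with `source_iter`/`target_iter` fields, the helper names `mixture_sampler` and `create_samplers_sketch`, the mixture centres, the default seed, and the `scale`/`std` values are all assumptions made for illustration. Each sampler splits its PRNG key on every draw, so successive batches differ while the whole stream stays reproducible from `rng`.

```python
from typing import Iterator, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

# Hypothetical unit-scale mixture centres; the library's actual layouts may differ.
_ANGLES = jnp.arange(8) * (2.0 * jnp.pi / 8)
CENTERS = {
    "simple": jnp.array([[0.0, 0.0]]),
    "circle": jnp.stack([jnp.cos(_ANGLES), jnp.sin(_ANGLES)], axis=1),
    "square_five": jnp.array(
        [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]),
    "square_four": jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]),
}


class Dataset(NamedTuple):
  """Assumed stand-in for the library's Dataset: paired batch iterators."""
  source_iter: Iterator[jnp.ndarray]
  target_iter: Iterator[jnp.ndarray]


def mixture_sampler(name: str, batch_size: int, rng: jax.Array,
                    scale: float = 5.0,
                    std: float = 0.5) -> Iterator[jnp.ndarray]:
  """Hypothetical helper: endless (batch_size, 2) batches from a 2-D mixture."""
  centers = CENTERS[name]
  while True:
    # Fresh subkeys per batch; carry `rng` forward so the stream is reproducible.
    rng, rng_center, rng_noise = jax.random.split(rng, 3)
    # Pick one mixture component per sample, uniformly at random.
    means = jax.random.choice(rng_center, centers, (batch_size,))
    noise = jax.random.normal(rng_noise, (batch_size, 2))
    yield scale * means + std * noise


def create_samplers_sketch(
    name_source: str,
    name_target: str,
    train_batch_size: int = 2048,
    valid_batch_size: int = 2048,
    rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
  """Hypothetical analogue of create_gaussian_mixture_samplers."""
  rng = jax.random.PRNGKey(0) if rng is None else rng  # assumed default seed
  k1, k2, k3, k4 = jax.random.split(rng, 4)  # independent key per iterator
  train = Dataset(mixture_sampler(name_source, train_batch_size, k1),
                  mixture_sampler(name_target, train_batch_size, k2))
  valid = Dataset(mixture_sampler(name_source, valid_batch_size, k3),
                  mixture_sampler(name_target, valid_batch_size, k4))
  return train, valid, 2  # every mixture in this sketch lives in R^2


train_ds, valid_ds, dim = create_samplers_sketch(
    "simple", "circle", train_batch_size=64, valid_batch_size=32,
    rng=jax.random.PRNGKey(1))
x, y = next(train_ds.source_iter), next(train_ds.target_iter)
print(x.shape, y.shape, dim)  # (64, 2) (64, 2) 2
```

Because the samplers in this sketch are infinite generators, a training loop would call `next(...)` on `source_iter` and `target_iter` once per step, while the returned integer tells the caller the dimensionality of each sample.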