ott.neural.data.semidiscrete_dataloader.SemidiscreteDataloader


class ott.neural.data.semidiscrete_dataloader.SemidiscreteDataloader(rng, sd_out, batch_size, epsilon=None, subset_size_threshold=None, subset_size=None, return_indices=False, out_shardings=None)

Semidiscrete dataloader.

This dataloader samples from the continuous source distribution and couples it with the discrete target distribution. It returns aligned tuples of (source, target) arrays of shape [batch, ...].

Parameters:
  • rng (Array) – Random key used for sampling from the source distribution.

  • sd_out (SemidiscreteOutput) – Semidiscrete output object storing a precomputed OT solution between the (continuous) source distribution and the dataset of interest.

  • batch_size (int) – Batch size.

  • epsilon (Optional[float]) – Epsilon regularization. In the context of this class, this value can be interpreted exclusively as a softmax temperature. If None, use the one stored in the geometry that was used to compute the potential stored in sd_out.

  • subset_size_threshold (Optional[int]) – Threshold above which to sample from a subset of the coupling matrix. Only applicable when the problem is entropically regularized. If None, don’t subset the coupling matrix.

  • subset_size (Optional[int]) – Size of the subset of the coupling matrix. If m > subset_size_threshold, a coupling of shape [batch, m] is reduced to [batch, subset_size] by keeping the top_k() values in each row.

  • return_indices (bool) – Whether to return, in addition to paired source and target data points, the indices corresponding to the selected target data points.

  • out_shardings (Optional[Sharding]) – Output shardings for the aligned batch.
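
The roles of epsilon, subset_size_threshold, subset_size and return_indices can be illustrated with a minimal sketch in plain Python. This is not the library's implementation and does not call the OTT API. It assumes a standard Gaussian source, a squared Euclidean cost, uniform target weights, and a precomputed dual potential with one value per target point. The helper names coupling_row, top_k_subset and sample_pairs are hypothetical and exist only for this illustration.

```python
import math
import random


def coupling_row(x, targets, potential, epsilon):
    """Softmax over targets of (g_j - c(x, y_j)) / epsilon.

    Assumes a squared Euclidean cost and uniform target weights;
    epsilon acts purely as a softmax temperature.
    """
    scores = []
    for y, g in zip(targets, potential):
        cost = sum((a - b) ** 2 for a, b in zip(x, y))
        scores.append((g - cost) / epsilon)
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def top_k_subset(probs, subset_size):
    """Keep the subset_size largest entries of a row and renormalize them."""
    order = sorted(range(len(probs)), key=lambda j: probs[j], reverse=True)
    idx = order[:subset_size]
    total = sum(probs[j] for j in idx)
    return idx, [probs[j] / total for j in idx]


def sample_pairs(rng, targets, potential, batch_size, epsilon,
                 subset_size_threshold=None, subset_size=None,
                 return_indices=False):
    """Draw source points and pair each with a sampled target point."""
    m, dim = len(targets), len(targets[0])
    sources, picked = [], []
    for _ in range(batch_size):
        # Assumption: the continuous source is a standard Gaussian.
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        probs = coupling_row(x, targets, potential, epsilon)
        candidates = list(range(m))
        # Only subset the row when there are more targets than the threshold.
        if subset_size_threshold is not None and m > subset_size_threshold:
            candidates, probs = top_k_subset(probs, subset_size)
        j = rng.choices(candidates, weights=probs, k=1)[0]
        sources.append(x)
        picked.append(j)
    paired = [targets[j] for j in picked]
    if return_indices:
        return sources, paired, picked
    return sources, paired


# Toy data: three target points in 2D and a zero potential (illustrative only).
targets = [[0.0, 0.0], [3.0, 3.0], [-3.0, 1.0]]
potential = [0.0, 0.0, 0.0]
src, tgt, idx = sample_pairs(
    random.Random(0), targets, potential, batch_size=4, epsilon=0.5,
    subset_size_threshold=2, subset_size=2, return_indices=True,
)
```

In this sketch a smaller epsilon sharpens each row towards the single best-scoring target, while a larger one spreads mass across more targets. Truncating each row to its largest entries bounds the per-sample work when the number of targets is large, at the price of ignoring low-probability targets.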

Methods

Attributes