ott.neural.data.semidiscrete_dataloader.SemidiscreteDataloader
class ott.neural.data.semidiscrete_dataloader.SemidiscreteDataloader(rng, sd_out, batch_size, epsilon=None, subset_size_threshold=None, subset_size=None, return_indices=False, out_shardings=None)
Semidiscrete dataloader.
This dataloader samples from the continuous source distribution and couples each sample with a point of the discrete target distribution, using the precomputed OT solution stored in sd_out. It returns aligned tuples of (source, target) arrays of shape [batch, ...].

Parameters:
- rng (Array) – Random number seed used for sampling from the source distribution.
- sd_out (SemidiscreteOutput) – Semidiscrete output object storing a precomputed OT solution between the (continuous) source distribution and the dataset of interest.
- batch_size (int) – Batch size.
- epsilon (Optional[float]) – Epsilon regularization. In the context of this class, this value can be interpreted exclusively as a softmax temperature. If None, use the one stored in the geometry that was used to compute the potential stored in sd_out.
- subset_size_threshold (Optional[int]) – Threshold above which to sample from a subset of the coupling matrix. Only applicable when the problem is entropically regularized. If None, don't subset the coupling matrix.
- subset_size (Optional[int]) – Size of the subset of the coupling matrix. If m > subset_size_threshold, a coupling of shape [batch, m] is subset to [batch, subset_size] using the top_k() values.
- return_indices (bool) – Whether to return, in addition to the paired source and target data points, the indices of the selected target data points.
- out_shardings (Optional[Sharding]) – Output shardings for the aligned batch.
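To make the "aligned tuples" idea concrete, here is a minimal NumPy sketch of how a batch could be produced from a semidiscrete solution. It is not the library's implementation. It assumes a dual potential g over the m target points, a squared-Euclidean cost, and uniform target weights, so that the conditional coupling is pi(j | x) proportional to exp((g_j - c(x, y_j)) / epsilon). The function sample_aligned_batch, the source_sampler callable, and the toy data are hypothetical.

```python
import numpy as np

def sample_aligned_batch(rng, source_sampler, target, g, batch_size, epsilon,
                         return_indices=False):
    # Hypothetical sketch, not ott's code. Draw continuous source points, shape [batch, d].
    x = source_sampler(rng, batch_size)
    # Squared-Euclidean cost to every target point, shape [batch, m] (assumed cost).
    cost = ((x[:, None, :] - target[None, :, :]) ** 2).sum(axis=-1)
    # Row-wise logits of pi(j | x), up to a per-row constant (uniform weights assumed).
    logits = (g[None, :] - cost) / epsilon
    # Gumbel-max trick: one categorical sample per row from softmax(logits).
    idx = np.argmax(logits + rng.gumbel(size=logits.shape), axis=1)
    # Gather the paired target points so that row i of x matches row i of y.
    y = target[idx]
    return (x, y, idx) if return_indices else (x, y)


rng = np.random.default_rng(0)
target = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])  # toy dataset, m = 3
g = np.zeros(3)                                            # toy potential
x, y = sample_aligned_batch(rng, lambda r, n: r.standard_normal((n, 2)),
                            target, g, batch_size=4, epsilon=0.1)
print(x.shape, y.shape)  # (4, 2) (4, 2)
```

Row i of the returned source array is paired with row i of the target array, and with return_indices=True the sketch also returns the sampled target indices, mirroring the return_indices parameter described above.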
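The epsilon parameter is described as a softmax temperature. The following NumPy sketch illustrates what that means for a single source point, under the same assumptions as before (squared-Euclidean cost, uniform target weights, a given potential g). The helpers conditional_probs and hard_assignment are hypothetical. A large epsilon spreads probability over several targets, while a small epsilon concentrates it on the target minimizing c(x, y_j) - g_j, which is the unregularized semidiscrete assignment.

```python
import numpy as np

def conditional_probs(x, target, g, epsilon):
    # pi(j | x) as a softmax over targets with temperature epsilon.
    cost = ((x[None, :] - target) ** 2).sum(axis=-1)  # [m], assumed squared-Euclidean cost
    logits = (g - cost) / epsilon
    logits = logits - logits.max()                     # stabilize exp
    p = np.exp(logits)
    return p / p.sum()

def hard_assignment(x, target, g):
    # epsilon -> 0 limit: index j minimizing c(x, y_j) - g_j.
    cost = ((x[None, :] - target) ** 2).sum(axis=-1)
    return int(np.argmin(cost - g))


x = np.array([0.4, 0.0])
target = np.array([[0.0, 0.0], [1.0, 0.0]])
g = np.zeros(2)
print(conditional_probs(x, target, g, epsilon=1.0))   # soft: roughly [0.55, 0.45]
print(conditional_probs(x, target, g, epsilon=0.01))  # nearly one-hot on index 0
print(hard_assignment(x, target, g))                  # 0
```

Raising g_j makes target j more likely for every source point, which is how the potential stored in sd_out could shift mass between targets in this sketch.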
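The subset_size_threshold and subset_size parameters restrict sampling to the largest entries of each coupling row. A minimal NumPy sketch of that idea follows, assuming the coupling is given as a [batch, m] matrix of logits. The function sample_with_subset is hypothetical, and np.argpartition stands in for a top-k routine such as jax.lax.top_k. Indices are sampled inside the subset and then mapped back to positions in the full dataset.

```python
import numpy as np

def sample_with_subset(rng, logits, subset_size_threshold=None, subset_size=None):
    # Hypothetical sketch: one column index per row of a [batch, m] logit matrix.
    batch, m = logits.shape
    if subset_size_threshold is not None and m > subset_size_threshold:
        # Column indices of the subset_size largest logits per row (unordered).
        top = np.argpartition(-logits, subset_size - 1, axis=1)[:, :subset_size]
        sub_logits = np.take_along_axis(logits, top, axis=1)  # [batch, subset_size]
        # Gumbel-max sampling within the subset, then map back to global indices.
        local = np.argmax(sub_logits + rng.gumbel(size=sub_logits.shape), axis=1)
        return top[np.arange(batch), local]
    # No subsetting: sample from the full row.
    return np.argmax(logits + rng.gumbel(size=logits.shape), axis=1)


logits = np.array([[0.0, 5.0, 4.0, -1.0, -2.0],
                   [3.0, -1.0, 2.9, 0.0, 1.0]])
idx = sample_with_subset(np.random.default_rng(0), logits,
                         subset_size_threshold=3, subset_size=2)
print(idx.shape)  # (2,)
```

In this sketch the subset is an approximation: any probability mass outside the top subset_size entries is dropped, in exchange for working with a [batch, subset_size] matrix instead of [batch, m] when m is large.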