ott.neural.data.ot_dataloader.LinearOTDataloader
class ott.neural.data.ot_dataloader.LinearOTDataloader(rng, dataset, epsilon=None, relative_epsilon=None, cost_fn=None, threshold=0.001, max_iterations=2000, replace=True, shardings=None)
Linear OT dataloader.

This dataloader wraps an iterable that yields (source, target) arrays of shape [batch, ...] and aligns the source and target samples in each batch using the Sinkhorn algorithm.

Parameters:
- rng (Array) – Random number generator key.
- dataset (Iterable[Tuple[Array, Array]]) – Iterable dataset which yields a tuple of source and target arrays of shape [batch, ...].
- epsilon (Optional[float]) – Entropic regularization strength. See Geometry for more information.
- relative_epsilon (Optional[Literal['mean', 'std']]) – If set, epsilon is interpreted as a fraction of the mean (mean_cost_matrix) or of the standard deviation (std_cost_matrix) of the cost matrix.
- cost_fn (Optional[CostFn]) – Cost function between two points.
- threshold (float) – Convergence threshold for the Sinkhorn algorithm.
- max_iterations (int) – Maximum number of Sinkhorn iterations.
- replace (bool) – Whether to sample with replacement.
- shardings (Optional[Sharding]) – Input and output shardings for the source and target arrays.
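To make "aligns them using the Sinkhorn algorithm" concrete, here is a minimal NumPy sketch of one plausible per-batch procedure; it is not the library's code. It assumes a squared Euclidean cost (standing in for cost_fn), uniform weights on both batches, and that "aligning" means drawing (i, j) index pairs from the entropic coupling, which is where replace would come into play. The helper names sinkhorn_coupling and sample_aligned_pairs and the 0.05 default for epsilon are hypothetical. threshold and max_iterations play the same roles as the parameters above, but the stopping rule used here (L1 error of the row marginals) is this sketch's own choice.

```python
import numpy as np


def _logsumexp(z, axis):
    """Numerically stable log-sum-exp along a single axis."""
    z_max = z.max(axis=axis, keepdims=True)
    out = z_max + np.log(np.exp(z - z_max).sum(axis=axis, keepdims=True))
    return out.squeeze(axis)


def sinkhorn_coupling(x, y, epsilon=0.05, relative_epsilon=None,
                      threshold=1e-3, max_iterations=2000):
    """Entropic OT coupling between uniform batches x [n, ...] and y [m, ...].

    Hypothetical helper (not ott's API): squared Euclidean cost, uniform weights.
    """
    n, m = x.shape[0], y.shape[0]
    xf, yf = x.reshape(n, -1), y.reshape(m, -1)  # flatten trailing dims
    cost = ((xf[:, None, :] - yf[None, :, :]) ** 2).sum(axis=-1)  # [n, m]
    # relative_epsilon rescales epsilon by a statistic of the cost matrix.
    if relative_epsilon == "mean":
        epsilon = epsilon * cost.mean()
    elif relative_epsilon == "std":
        epsilon = epsilon * cost.std()
    log_a, log_b = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(max_iterations):
        # Log-domain Sinkhorn: update the dual potentials f, then g.
        f = epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
        # After the g update the column sums equal 1/m; check the row sums.
        row_sums = np.exp((f[:, None] + g[None, :] - cost) / epsilon).sum(axis=1)
        if np.abs(row_sums - 1.0 / n).sum() < threshold:
            break
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def sample_aligned_pairs(rng, x, y, coupling, replace=True):
    """Draw len(x) index pairs (i, j) with probability proportional to coupling."""
    n, m = coupling.shape
    p = (coupling / coupling.sum()).ravel()
    flat = rng.choice(n * m, size=n, replace=replace, p=p)
    i, j = np.divmod(flat, m)  # row index i, column index j
    return x[i], y[j]


rng = np.random.default_rng(0)
x = 3.0 * np.arange(8.0)[:, None]   # source batch, shape [8, 1]
y = x[rng.permutation(8)]           # target batch: the same points, shuffled
coupling = sinkhorn_coupling(x, y, epsilon=0.002, relative_epsilon="mean")
xs, ys = sample_aligned_pairs(rng, x, y, coupling, replace=False)
```

Scaling epsilon by the mean cost keeps the amount of smoothing comparable across batches whose costs differ in scale. With the small epsilon used here the coupling is close to a permutation, so the sampled rows of xs and ys match one-to-one; for a blurrier coupling, sampling without replacement only guarantees distinct (i, j) pairs, not distinct rows.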
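The dataset argument only has to be an iterable of (source, target) tuples with shape [batch, ...]. A minimal NumPy generator meeting that contract might look like the following. paired_batches is a hypothetical helper; because the type hint names Array, real use may need a jax.numpy.asarray conversion, and whether the dataloader expects a finite or an infinite iterable is not stated here (this sketch is infinite).

```python
import numpy as np


def paired_batches(source, target, batch_size, seed=0):
    """Yield (source_batch, target_batch) tuples, each of shape [batch_size, ...].

    Hypothetical helper: source and target rows are drawn independently, so the
    two batches are unpaired; aligning them is left to the OT dataloader.
    """
    rng = np.random.default_rng(seed)
    while True:  # infinite stream of batches
        src_idx = rng.choice(len(source), size=batch_size, replace=False)
        tgt_idx = rng.choice(len(target), size=batch_size, replace=False)
        yield source[src_idx], target[tgt_idx]


# Toy data: 100 source points and 50 target points in 3 dimensions.
dataset = paired_batches(np.zeros((100, 3)), np.ones((50, 3)), batch_size=16)
src, tgt = next(dataset)
```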