ott.neural.data.ot_dataloader.LinearOTDataloader

class ott.neural.data.ot_dataloader.LinearOTDataloader(rng, dataset, epsilon=None, relative_epsilon=None, cost_fn=None, threshold=0.001, max_iterations=2000, replace=True, shardings=None)

Linear OT dataloader.

This dataloader wraps an iterable dataset that yields (source, target) arrays of shape [batch, ...] and aligns each batch using the Sinkhorn algorithm.
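As a rough illustration of what aligning a batch with Sinkhorn can look like, the sketch below computes an entropic coupling between a source and a target batch and then resamples pairs from it. It uses only the Python standard library so that it is self-contained; it is not the class's implementation, which relies on JAX and OTT's own Sinkhorn solver. The helper names (sq_euclidean, sinkhorn_coupling, align_batch), the squared Euclidean cost, the uniform weights, and the resampling step with replacement are all assumptions made for illustration.

```python
import math
import random


def sq_euclidean(x, y):
    """Squared Euclidean cost between two points (sequences of floats)."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def sinkhorn_coupling(cost, epsilon, threshold=1e-3, max_iterations=2000):
    """Entropic OT coupling between uniform marginals (illustrative only)."""
    n, m = len(cost), len(cost[0])
    a, b = 1.0 / n, 1.0 / m  # assumed uniform weights on both batches
    kernel = [[math.exp(-c / epsilon) for c in row] for row in cost]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(max_iterations):
        # Rescale columns, then rows, to match the target marginals.
        v = [b / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
        u = [a / sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rows now match exactly, so measure the column marginal error.
        err = sum(
            abs(v[j] * sum(kernel[i][j] * u[i] for i in range(n)) - b)
            for j in range(m)
        )
        if err < threshold:
            break
    return [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]


def align_batch(rng, source, target, epsilon, threshold=1e-3, max_iterations=2000):
    """Resample (source, target) pairs in proportion to the coupling.

    Hypothetical helper: samples len(source) pairs with replacement.
    """
    cost = [[sq_euclidean(x, y) for y in target] for x in source]
    coupling = sinkhorn_coupling(cost, epsilon, threshold, max_iterations)
    n, m = len(source), len(target)
    pairs = [(i, j) for i in range(n) for j in range(m)]
    weights = [coupling[i][j] for i, j in pairs]
    chosen = rng.choices(pairs, weights=weights, k=n)
    return [source[i] for i, _ in chosen], [target[j] for _, j in chosen]


# Toy batch: each target is a source point shifted by 0.1, in shuffled order.
source = [[0.0], [1.0], [2.0]]
target = [[2.1], [0.1], [1.1]]
aligned_src, aligned_tgt = align_batch(random.Random(0), source, target, epsilon=0.05)
for s, t in zip(aligned_src, aligned_tgt):
    print(s, "->", t)
```

With a small epsilon the coupling concentrates on the cheapest matches, so the resampled pairs tend to link each source point to its nearest target. A larger epsilon spreads mass more evenly and makes the pairing closer to random.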

Parameters:
  • rng (Array) – JAX pseudo-random number generator (PRNG) key.

  • dataset (Iterable[Tuple[Array, Array]]) – Iterable dataset which yields a tuple of source and target arrays of shape [batch, ...].

  • epsilon (Optional[float]) – Epsilon regularization. See Geometry for more information.

  • relative_epsilon (Optional[Literal['mean', 'std']]) – Whether epsilon is interpreted as a fraction of the mean ('mean') or the standard deviation ('std') of the cost matrix, i.e. of mean_cost_matrix or std_cost_matrix.

  • cost_fn (Optional[CostFn]) – Cost function between two points.

  • threshold (float) – Convergence threshold for Sinkhorn.

  • max_iterations (int) – Maximum number of Sinkhorn iterations.

  • replace (bool) – Whether to sample with replacement.

  • shardings (Optional[Sharding]) – Input and output shardings for the source and target arrays.
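The interaction between epsilon and relative_epsilon can be made concrete with a small calculation. The sketch below scales epsilon by the mean or the standard deviation of a toy cost matrix. The helper effective_epsilon is hypothetical and not part of OTT. The sketch also assumes that leaving relative_epsilon unset means epsilon is used as an absolute value and that the population standard deviation is meant; see Geometry for the authoritative behavior.

```python
import statistics


def effective_epsilon(cost, epsilon, relative_epsilon=None):
    """Scale epsilon by a statistic of the cost matrix (hypothetical helper)."""
    entries = [c for row in cost for c in row]
    if relative_epsilon is None:
        return epsilon  # assumption: epsilon is taken as an absolute value
    if relative_epsilon == "mean":
        return epsilon * statistics.mean(entries)
    if relative_epsilon == "std":
        # Assumption: population standard deviation of the cost entries.
        return epsilon * statistics.pstdev(entries)
    raise ValueError(f"Unknown relative_epsilon: {relative_epsilon!r}")


# Toy 2x2 cost matrix with mean 4.0.
cost = [[1.0, 2.0], [3.0, 10.0]]
for mode in (None, "mean", "std"):
    print(mode, effective_epsilon(cost, 0.1, mode))
```

Relative scaling keeps the amount of regularization comparable when the magnitude of the costs changes, for example after rescaling the data.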
