ott.solvers.linear.semidiscrete.SemidiscreteSolver


class ott.solvers.linear.semidiscrete.SemidiscreteSolver(*, num_iterations, batch_size, optimizer, error_eval_every=1000, error_batch_size=None, error_num_repeats=16, threshold=0.001, potential_ema=0.99, epsilon_scheduler=constant_epsilon_scheduler, callback=None)

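Semidiscrete optimal transport solver. In a semidiscrete problem, one measure is continuous and accessed only through samples, while the other is discrete. The dual potential on the discrete support is optimized with stochastic gradients computed on sampled mini-batches.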
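For orientation, below is the entropic semi-dual objective that stochastic semidiscrete solvers commonly maximize. This page does not state the objective, so treat the formula and its notation as assumptions: μ is the sampled continuous measure, y_j and b_j (j = 1, …, m) are the discrete support points and weights, c is the cost, and ε is the entropic regularization:

$$
\max_{g \in \mathbb{R}^m} \; \mathbb{E}_{x \sim \mu}\big[ g^{c,\varepsilon}(x) \big] + \sum_{j=1}^{m} b_j g_j,
\qquad
g^{c,\varepsilon}(x) = -\varepsilon \log \sum_{j=1}^{m} b_j \exp\!\Big( \frac{g_j - c(x, y_j)}{\varepsilon} \Big)
$$

$$
\frac{\partial}{\partial g_j} = b_j - \mathbb{E}_{x \sim \mu}\big[\pi_j(x)\big],
\qquad
\pi_j(x) = \frac{b_j \exp\big((g_j - c(x, y_j))/\varepsilon\big)}{\sum_{k=1}^{m} b_k \exp\big((g_k - c(x, y_k))/\varepsilon\big)}
$$

The gradient only involves expectations over x ~ μ, so it can be estimated from mini-batches. At the optimum, the marginal condition E[π_j(x)] = b_j holds, which is the quantity a marginal error check would monitor.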

Parameters:
  • num_iterations (int) – Number of optimization iterations.

  • batch_size (int) – Number of points to sample from the continuous measure at each iteration.

  • optimizer (GradientTransformation) – Optax optimizer used to update the dual potential.

  • error_eval_every (int) – Compute the marginal chi-squared error every error_eval_every iterations.

  • error_batch_size (Optional[int]) – Batch size to use when computing the marginal chi-squared error. If None, use batch_size.

  • error_num_repeats (int) – Number of repeats used to estimate the marginal chi-squared error (default: 16).

  • threshold (float) – Convergence threshold for the marginal chi-squared error.

  • potential_ema (float) – Decay factor of the exponential moving average applied to the dual potential.

  • epsilon_scheduler (Callable[[Array, Array], Array]) – Schedule for epsilon over the iterations, with signature (step, target_epsilon) -> epsilon. Defaults to constant_epsilon_scheduler().

  • callback (Optional[Callable[[SemidiscreteState], None]]) – Optional callback with signature (state) -> None, called at every iteration.
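The error-related parameters (error_batch_size, error_num_repeats, threshold) can be illustrated with a small standalone estimate of a marginal chi-squared error. This is a minimal sketch and not OTT's implementation. It assumes a squared-Euclidean cost and the entropic semidiscrete formulation, and it assumes the repeats are averaged. The names `marginal_chi2_error`, `estimate_error`, and `sample_gaussian` are hypothetical.

```python
import jax
import jax.numpy as jnp


def marginal_chi2_error(g, x, y, b, epsilon):
    """Chi-squared error between target weights b and the marginal that the
    dual potential g induces on a batch x of source samples.

    Assumption: squared-Euclidean cost and entropic semidiscrete formulation.
    """
    # Cost matrix of shape (batch, m): c(x_i, y_j) = ||x_i - y_j||^2.
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # pi_j(x_i) is a row-wise softmax of log b_j + (g_j - c(x_i, y_j)) / epsilon.
    logits = jnp.log(b)[None, :] + (g[None, :] - cost) / epsilon
    pi = jax.nn.softmax(logits, axis=-1)
    # Monte Carlo estimate of the marginal on the discrete support.
    marginal = jnp.mean(pi, axis=0)
    return jnp.sum((marginal - b) ** 2 / b)


def estimate_error(rng, g, sample_fn, y, b, epsilon, batch_size, num_repeats):
    """Average the error over num_repeats independent batches (assumed scheme)."""
    keys = jax.random.split(rng, num_repeats)
    errors = jax.vmap(
        lambda key: marginal_chi2_error(g, sample_fn(key, batch_size), y, b, epsilon)
    )(keys)
    return jnp.mean(errors)


def sample_gaussian(key, n):
    # Hypothetical continuous source measure: standard 2-D Gaussian.
    return jax.random.normal(key, (n, 2))


y = jnp.array([[-1.0, 0.0], [1.0, 0.0]])  # discrete support
b = jnp.array([0.5, 0.5])                 # discrete weights
g = jnp.zeros(2)                          # current dual potential
err = estimate_error(jax.random.PRNGKey(0), g, sample_gaussian, y, b,
                     epsilon=0.1, batch_size=256, num_repeats=16)
converged = bool(err < 1e-3)  # compare against a threshold such as 0.001
```

Averaging several independent batches reduces the variance of the error estimate. This matters because the estimate is compared against a small fixed threshold.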

Methods

epsilon_scheduler(step, target_epsilon)

Constant epsilon scheduler (the default): returns target_epsilon at every step.
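Any callable with the documented signature (step, target_epsilon) -> epsilon can serve as a scheduler. The sketch below shows a constant schedule, mirroring the documented default, and a hypothetical annealing schedule that starts from a larger epsilon and decays geometrically towards the target. `constant_scheduler`, `make_exp_decay_scheduler`, and the constants `init_scale` and `decay` are illustrative assumptions, not part of OTT.

```python
import jax.numpy as jnp


def constant_scheduler(step, target_epsilon):
    """Return target_epsilon unchanged at every step (constant schedule)."""
    del step  # the step counter is ignored
    return target_epsilon


def make_exp_decay_scheduler(init_scale=10.0, decay=0.99):
    """Hypothetical factory for a (step, target_epsilon) -> epsilon scheduler.

    Epsilon starts at init_scale * target_epsilon, shrinks geometrically by
    `decay` per step, and is clipped so it never falls below target_epsilon.
    """

    def scheduler(step, target_epsilon):
        step = jnp.asarray(step, dtype=jnp.float32)
        return jnp.maximum(target_epsilon, target_epsilon * init_scale * decay**step)

    return scheduler


annealed = make_exp_decay_scheduler(init_scale=10.0, decay=0.5)
```

Because `annealed` matches the documented signature, it could be passed as `epsilon_scheduler=annealed` when constructing the solver. Using `jnp` operations keeps the schedule traceable if the step counter is a JAX array.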

step(rng, state, prob, *[, compute_error, ...])

Perform one optimization step.
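As a rough illustration of what one optimization step involves, here is a generic sketch of stochastic gradient ascent on the entropic semi-dual, followed by an exponential moving average of the potential. It is not OTT's `step` implementation. Plain gradient ascent stands in for the Optax optimizer, the cost is assumed to be squared-Euclidean, and `potential_ema` is assumed to act as the EMA decay factor. The names `semidual_objective`, `sgd_step`, `sample_gaussian`, and `step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def semidual_objective(g, x, y, b, epsilon):
    """Mini-batch estimate of the entropic semi-dual (squared-Euclidean cost):
    mean_i[-eps * logsumexp_j(log b_j + (g_j - c(x_i, y_j)) / eps)] + <b, g>.
    """
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # (batch, m)
    logits = jnp.log(b)[None, :] + (g[None, :] - cost) / epsilon
    f = -epsilon * jax.nn.logsumexp(logits, axis=-1)  # soft c-transform at each x_i
    return jnp.mean(f) + jnp.dot(b, g)


def sgd_step(rng, g, g_ema, sample_fn, y, b, epsilon, batch_size, lr, ema_decay=0.99):
    """One stochastic gradient-ascent step on g, then an EMA update (assumed)."""
    x = sample_fn(rng, batch_size)  # fresh samples from the continuous measure
    # The gradient equals b - mean_i pi(x_i), with pi the row-wise softmax of the logits.
    grad = jax.grad(semidual_objective)(g, x, y, b, epsilon)
    g = g + lr * grad  # ascent: the semi-dual is concave in g
    g_ema = ema_decay * g_ema + (1.0 - ema_decay) * g
    return g, g_ema


def sample_gaussian(key, n):
    # Hypothetical continuous source measure: standard 2-D Gaussian.
    return jax.random.normal(key, (n, 2))


y = jnp.array([[-1.0, 0.0], [1.0, 0.0]])  # discrete support
b = jnp.array([0.25, 0.75])               # discrete weights


@jax.jit
def step(rng, g, g_ema):
    return sgd_step(rng, g, g_ema, sample_gaussian, y, b,
                    epsilon=0.5, batch_size=512, lr=1.0)


g = g_ema = jnp.zeros(2)
for key in jax.random.split(jax.random.PRNGKey(0), 1000):
    g, g_ema = step(key, g, g_ema)
```

The averaged potential `g_ema` smooths out the mini-batch noise in `g`. This is one common reason to keep an exponential moving average of a stochastically optimized quantity.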

Attributes