ott.solvers.linear.semidiscrete.SemidiscreteSolver
- class ott.solvers.linear.semidiscrete.SemidiscreteSolver(*, num_iterations, batch_size, optimizer, error_eval_every=1000, error_batch_size=None, error_num_repeats=16, threshold=0.001, potential_ema=0.99, epsilon_scheduler=<function constant_epsilon_scheduler>, callback=None)
Semidiscrete optimal transport solver.
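As background, and assuming the standard entropic semi-dual formulation (the exact objective and error that OTT computes may differ), semidiscrete OT between a continuous source measure mu and a discrete target nu = sum_j b_j delta_{y_j} can be solved by stochastic gradient ascent on a dual potential g in R^m. The gradient is the gap between the target weights b and the marginal b_hat induced by g, and the chi-squared expression on the last line is one natural reading of the "marginal chi-squared error" used in the parameters:

```latex
\begin{aligned}
G(g) &= \mathbb{E}_{x \sim \mu}\Big[-\varepsilon \log \sum_{j=1}^{m} b_j
        \exp\Big(\frac{g_j - c(x, y_j)}{\varepsilon}\Big)\Big] + \sum_{j=1}^{m} b_j g_j, \\
\pi_j(x) &= \frac{b_j \exp\big((g_j - c(x, y_j))/\varepsilon\big)}
               {\sum_{k=1}^{m} b_k \exp\big((g_k - c(x, y_k))/\varepsilon\big)}, \\
\frac{\partial G}{\partial g_j} &= b_j - \hat b_j, \qquad
\hat b_j = \mathbb{E}_{x \sim \mu}\big[\pi_j(x)\big], \\
\chi^2(\hat b \,\|\, b) &= \sum_{j=1}^{m} \frac{(\hat b_j - b_j)^2}{b_j}.
\end{aligned}
```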
- Parameters:
  - num_iterations (int) – Number of iterations.
  - batch_size (int) – Number of points to sample at each iteration.
  - optimizer (GradientTransformation) – Optax optimizer used to update the dual potential.
  - error_eval_every (int) – Evaluate the marginal chi-squared error every error_eval_every iterations.
  - error_batch_size (Optional[int]) – Batch size used when computing the marginal chi-squared error. If None, batch_size is used.
  - error_num_repeats (int) – Number of repeats used to estimate the marginal chi-squared error. Defaults to 16.
  - threshold (float) – Convergence threshold for the marginal chi-squared error.
  - potential_ema (float) – Coefficient of the exponential moving average (EMA) applied to the dual potential.
  - epsilon_scheduler (Callable[[Array, Array], Array]) – Schedule for epsilon along the iterations, with signature (step, target_epsilon) -> epsilon. By default, constant_epsilon_scheduler() is used.
  - callback (Optional[Callable[[SemidiscreteState], None]]) – Callback with signature (state) -> None, called at every iteration.
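Any callable with the (step, target_epsilon) -> epsilon signature can stand in for the default constant_epsilon_scheduler. The sketch below is a hypothetical linear-decay schedule, not part of OTT: it starts at init_scale * target_epsilon and decays linearly to target_epsilon over decay_steps steps. init_scale and decay_steps are illustrative extra keyword arguments with defaults, so the required positional signature is preserved; only jax.numpy ops are used, assuming the solver may trace the scheduler under jit.

```python
import jax.numpy as jnp


def linear_decay_epsilon_scheduler(step, target_epsilon,
                                   init_scale=10.0, decay_steps=1000):
    """Hypothetical scheduler: linear decay from init_scale * target_epsilon
    to target_epsilon over decay_steps steps, constant afterwards."""
    # Fraction of the decay completed, clipped to [0, 1].
    frac = jnp.clip(step / decay_steps, 0.0, 1.0)
    # Interpolate the multiplier from init_scale down to 1.
    scale = init_scale + (1.0 - init_scale) * frac
    return scale * target_epsilon
```

It could then be passed as epsilon_scheduler=linear_decay_epsilon_scheduler. Starting with a larger epsilon gives smoother gradients early on, at the cost of extra iterations spent far from the target regularization.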
Methods
  - epsilon_scheduler(target_epsilon) – Constant epsilon scheduler.
  - step(rng, state, prob, *[, compute_error, ...]) – Perform one optimization step.
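To make "one optimization step" concrete, here is a minimal plain-JAX sketch of a single stochastic ascent step on the entropic semi-dual, followed by the EMA update controlled by a coefficient like potential_ema and a chi-squared marginal error. It rests on stated assumptions: a squared Euclidean cost, a standard Gaussian source measure, and plain SGD in place of the optax optimizer. Every function name here is hypothetical; this is not OTT's step implementation and does not use SemidiscreteState.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _sq_euclidean_cost(x, y):
    # Assumption: squared Euclidean cost, c[i, j] = ||x_i - y_j||^2.
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def semi_dual_objective(g, x, y, b, epsilon):
    # Monte Carlo estimate of the entropic semi-dual on the batch x.
    cost = _sq_euclidean_cost(x, y)
    # Soft c-transform: f(x_i) = -eps * log sum_j b_j exp((g_j - c_ij) / eps).
    f = -epsilon * logsumexp((g[None, :] - cost) / epsilon, b=b[None, :], axis=1)
    return jnp.mean(f) + jnp.dot(b, g)


def marginal_estimate(g, x, y, b, epsilon):
    # pi[i, j]: conditional mass sent from x_i to y_j; each row sums to 1.
    logits = (g[None, :] - _sq_euclidean_cost(x, y)) / epsilon + jnp.log(b)[None, :]
    pi = jax.nn.softmax(logits, axis=1)
    return jnp.mean(pi, axis=0)


def chi_squared_error(b_hat, b):
    # Assumption: chi-squared divergence between estimated and target marginals.
    return jnp.sum((b_hat - b) ** 2 / b)


def hypothetical_step(rng, g, g_ema, y, b, epsilon,
                      lr=0.1, batch_size=256, ema=0.99):
    # Assumption: the source measure is a standard Gaussian in R^d.
    x = jax.random.normal(rng, (batch_size, y.shape[1]))
    # The gradient of the semi-dual w.r.t. g equals b - b_hat, estimated on x.
    grad = jax.grad(semi_dual_objective)(g, x, y, b, epsilon)
    g = g + lr * grad  # plain SGD ascent, standing in for an optax optimizer
    g_ema = ema * g_ema + (1.0 - ema) * g  # EMA of the dual potential
    err = chi_squared_error(marginal_estimate(g, x, y, b, epsilon), b)
    return g, g_ema, err
```

Because b and b_hat both sum to one, each gradient sums to zero, so the update only redistributes potential between support points. A real solver would evaluate the error only periodically (compare error_eval_every) rather than at every step.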
Attributes