ott.experimental.mmsinkhorn.MMSinkhorn

class ott.experimental.mmsinkhorn.MMSinkhorn(threshold=0.001, norm_error=1.0, inner_iterations=10, min_iterations=0, max_iterations=2000, use_danskin=True)

Multimarginal Sinkhorn solver that aligns \(k\) point clouds in dimension \(d\).
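To make "aligning \(k\) point clouds" concrete: a multimarginal problem works on a cost tensor with one axis per point cloud, whose entry \(C[i_1, \ldots, i_k]\) scores the \(k\)-tuple formed by picking point \(i_j\) from cloud \(j\). The following is a minimal JAX sketch that builds such a tensor, assuming, purely for illustration, a cost equal to the sum of pairwise squared Euclidean distances. The helper names are hypothetical, and this is not necessarily the cost MMSinkhorn uses.

```python
import jax
import jax.numpy as jnp


def pairwise_sq_dists(x, y):
    # (n, d), (m, d) -> (n, m) matrix of squared Euclidean distances
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def mm_cost_tensor(xs):
    # Hypothetical cost: C[i_1, ..., i_k] = sum_{j < m} ||x_j[i_j] - x_m[i_m]||^2
    k = len(xs)
    C = jnp.zeros(tuple(x.shape[0] for x in xs))
    for j in range(k):
        for m in range(j + 1, k):
            shape = [1] * k
            shape[j], shape[m] = xs[j].shape[0], xs[m].shape[0]
            # Broadcast the (j, m) pairwise block over all other axes.
            C = C + pairwise_sq_dists(xs[j], xs[m]).reshape(tuple(shape))
    return C


# Three 2-d point clouds of sizes 4, 5 and 3.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
xs = [jax.random.normal(key, (n, 2)) for key, n in zip(keys, (4, 5, 3))]
C = mm_cost_tensor(xs)  # shape (4, 5, 3): one axis per point cloud
```

The tensor has \(n_1 n_2 \cdots n_k\) entries, which is why multimarginal problems are usually kept to a few, moderately sized point clouds.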

This solver implements the entropic multimarginal solver presented in [Benamou et al., 2015] and described in [Piran et al., 2024], Algorithm 1. The current implementation largely follows the template of the Sinkhorn solver, but with a much reduced set of hyperparameters, which control the number of iterations and the convergence threshold. It also relies on the [Danskin, 1967] theorem to instantiate the OT cost. By default, iterations are carried out in log-space.
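The iterations can be sketched as block-coordinate updates of one dual potential per marginal, each computed with a log-sum-exp over the other \(k-1\) axes of the cost tensor. The sketch below is a minimal JAX illustration of that scheme, not OTT's implementation: all function names are hypothetical, a random tensor stands in for the cost, and the dual objective used to read off the cost is the standard entropic multimarginal dual, assumed here. It also shows the Danskin step: the cost is evaluated at potentials passed through jax.lax.stop_gradient, so its gradient with respect to \(C\) is the optimal plan itself.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_plan(fs, C, log_as, eps):
    # log P = (f_1 (+) ... (+) f_k - C) / eps + log(a_1 (x) ... (x) a_k)
    k = C.ndim
    out = -C / eps
    for i in range(k):
        shape = [1] * k
        shape[i] = -1  # broadcast potential i along axis i
        out = out + (fs[i] / eps + log_as[i]).reshape(tuple(shape))
    return out


def log_marginal(logP, i):
    # Sum the plan over every axis except i, in log-space.
    axes = tuple(j for j in range(logP.ndim) if j != i)
    return logsumexp(logP, axis=axes)


@jax.jit
def sweep(fs, C, log_as, eps):
    # One cycle of block updates: right after f_i is updated,
    # marginal i of the plan equals a_i exactly.
    fs = list(fs)
    for i in range(len(fs)):
        logm = log_marginal(log_plan(fs, C, log_as, eps), i)
        fs[i] = fs[i] + eps * (log_as[i] - logm)
    return fs


def dual_objective(fs, C, as_, eps):
    # Assumed entropic dual: sum_i <f_i, a_i> - eps * (total mass of P - 1).
    log_as = [jnp.log(a) for a in as_]
    linear = sum(jnp.dot(f, a) for f, a in zip(fs, as_))
    mass = jnp.sum(jnp.exp(log_plan(fs, C, log_as, eps)))
    return linear - eps * (mass - 1.0)


def reg_ot_cost(fs, C, as_, eps):
    # Danskin: evaluate the objective at frozen (stop-gradient) potentials.
    frozen = [jax.lax.stop_gradient(f) for f in fs]
    return dual_objective(frozen, C, as_, eps)


# Toy problem: k = 3 uniform marginals, random cost tensor.
shape = (4, 5, 3)
C = jax.random.uniform(jax.random.PRNGKey(0), shape)
as_ = [jnp.full((n,), 1.0 / n) for n in shape]
log_as = [jnp.log(a) for a in as_]
eps = 0.5

fs = [jnp.zeros_like(a) for a in as_]
for _ in range(300):
    fs = sweep(fs, C, log_as, eps)

P = jnp.exp(log_plan(fs, C, log_as, eps))
cost = reg_ot_cost(fs, C, as_, eps)
# Only the explicit dependence on C is differentiated, so grad_C equals P.
grad_C = jax.grad(reg_ot_cost, argnums=1)(fs, C, as_, eps)
```

Because the potentials are frozen, differentiation does not backpropagate through the 300 sweeps; the resulting gradient is only faithful once the marginals are matched to a small tolerance, which is the caveat attached to use_danskin.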

Parameters:
  • threshold (float) – tolerance used to stop the Sinkhorn iterations. The error compared against it is typically the deviation between a target marginal and the corresponding marginal of the current primal solution.

  • norm_error (float) – power \(p\) used to define the \(p\)-norm of the error between the marginals of the current solution and their targets.

  • inner_iterations (int) – the Sinkhorn error is not recomputed at every iteration, but only once every inner_iterations iterations.

  • min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.

  • max_iterations (int) – the maximum number of Sinkhorn iterations. If max_iterations is equal to min_iterations, the iterations are run by default using a scan() loop rather than a custom, unrollable while_loop() that monitors convergence. In that case, the error is not monitored and, as a consequence, the converged flag is False.

  • use_danskin (bool) – when True, the entropy-regularized cost is assumed to be evaluated at optimal potentials that are frozen, i.e. whose gradients have been stopped. This is useful for first-order differentiation, and is only mathematically valid when the algorithm has converged to a low tolerance.
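The interplay of threshold, norm_error, inner_iterations, min_iterations and max_iterations can be sketched as the plain-Python driver below. It is a hedged approximation of the stopping rule described in the parameter list, not OTT's code: run_loop and marginal_error are hypothetical names, the error is assumed to be the \(p\)-norm deviation summed over marginals, and a two-marginal Sinkhorn update stands in for the multimarginal step to keep the example short.

```python
import numpy as np


def marginal_error(marginals, targets, norm_error=1.0):
    # Assumed aggregation: sum over marginals of the p-norm of the deviation.
    p = norm_error
    return sum(np.sum(np.abs(m - t) ** p) ** (1.0 / p)
               for m, t in zip(marginals, targets))


def run_loop(step, marginals_of, state, targets, threshold=1e-3,
             norm_error=1.0, inner_iterations=10, min_iterations=0,
             max_iterations=2000):
    if min_iterations == max_iterations:
        # Fixed-length loop (cf. scan()): the error is never monitored,
        # so the converged flag is False.
        for _ in range(max_iterations):
            state = step(state)
        return state, False, max_iterations, []
    it, errors = 0, []
    while it < max_iterations:
        for _ in range(inner_iterations):  # no error computed in between
            state = step(state)
        it += inner_iterations
        if it >= min_iterations:  # monitoring starts after min_iterations
            err = marginal_error(marginals_of(state), targets, norm_error)
            errors.append(err)
            if err < threshold:
                return state, True, it, errors
    return state, False, it, errors


# Stand-in step: two-marginal Sinkhorn scaling on a Gibbs kernel (eps = 0.5).
rng = np.random.default_rng(0)
K = np.exp(-rng.random((4, 5)) / 0.5)
a, b = np.full(4, 1 / 4), np.full(5, 1 / 5)


def step(state):
    u, v = state
    u = a / (K @ v)
    v = b / (K.T @ u)
    return u, v


def marginals_of(state):
    u, v = state
    return [u * (K @ v), v * (K.T @ u)]


init = (np.ones(4), np.ones(5))
state, converged, n_iter, errors = run_loop(
    step, marginals_of, init, [a, b], threshold=1e-8)
```

Since the error is only checked after each block of inner_iterations steps, the monitored loop always stops at a multiple of inner_iterations, and can overshoot max_iterations by less than inner_iterations when the latter does not divide the former.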

Methods

init_state(n_s)

Return the initial state of the loop.

Attributes

outer_iterations

Upper bound on the number of times inner_iterations are carried out.
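Assuming outer_iterations is computed as in OTT's Sinkhorn solver, i.e. as the ceiling of max_iterations / inner_iterations, the bound works out as follows. The helper below is a hypothetical stand-in for the attribute, shown with the default hyperparameters.

```python
import math

import numpy as np


def outer_iterations(max_iterations: int, inner_iterations: int) -> int:
    # Each outer iteration runs inner_iterations steps and evaluates the error
    # once, so this many outer iterations suffice to reach max_iterations.
    return math.ceil(max_iterations / inner_iterations)


# Defaults: max_iterations=2000, inner_iterations=10.
n_outer = outer_iterations(2000, 10)
# A fixed-size buffer, e.g. for recorded errors, can be preallocated from it.
errors = np.full(n_outer, -1.0)
```

A static bound of this kind is convenient under JAX, where array shapes must be known at trace time.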