ott.core.sinkhorn_lr.LRSinkhorn

class ott.core.sinkhorn_lr.LRSinkhorn(rank=10, gamma=1.0, epsilon=0.0001, init_type='random', lse_mode=True, threshold=0.001, norm_error=1, inner_iterations=1, min_iterations=0, max_iterations=2000, use_danskin=True, implicit_diff=False, jit=True, rng_key=0, kwargs_dys=None)

A Low-Rank Sinkhorn solver for linear reg-OT problems.

A Low-Rank Sinkhorn solver takes a linear OT problem as input and returns an LRSinkhornOutput object.

The algorithm is described in Low-Rank Sinkhorn Factorization, Scetbon, Cuturi, Peyré, ICML’21: http://proceedings.mlr.press/v139/scetbon21a/scetbon21a.pdf

The implementation contained here is adapted from the one at: https://github.com/meyerscetbon/LOT

The algorithm minimizes a non-convex objective, so initialization and convergence checks require special care. Initialization is random by default, and convergence is assessed by comparing successive evaluations of the objective. The algorithm is only provided for the balanced case.
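
To make the low-rank structure concrete, here is a minimal NumPy sketch, assuming the factorization P = Q diag(1/g) R^T from the cited paper, where Q (n x r) and R (m x r) share an inner marginal g. The helper scale_to_marginals, the uniform g, and the Sinkhorn-scaled random factors are illustrative assumptions; they are not this class's init_state or its 'random' initialization.

```python
import numpy as np

def scale_to_marginals(K, row, col, n_iter=500):
    """Hypothetical helper: Sinkhorn-scale a positive matrix K so that its
    row sums equal `row` and its column sums equal `col`."""
    u = np.ones_like(row)
    v = np.ones_like(col)
    for _ in range(n_iter):
        u = row / (K @ v)
        v = col / (K.T @ u)
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
n, m, r = 5, 4, 2
a = np.full(n, 1.0 / n)  # source marginal
b = np.full(m, 1.0 / m)  # target marginal
g = np.full(r, 1.0 / r)  # inner marginal shared by both factors (assumed uniform)

# Feasible factors: Q 1 = a, Q^T 1 = g and R 1 = b, R^T 1 = g.
Q = scale_to_marginals(rng.uniform(0.5, 1.5, (n, r)), a, g)
R = scale_to_marginals(rng.uniform(0.5, 1.5, (m, r)), b, g)

# The coupling P = Q diag(1/g) R^T has rank at most r and marginals (a, b).
P = Q @ np.diag(1.0 / g) @ R.T
C = rng.uniform(size=(n, m))  # an arbitrary cost matrix

print(np.allclose(P.sum(axis=1), a), np.allclose(P.sum(axis=0), b))
print(np.linalg.matrix_rank(P) <= r)
# The linear OT objective <C, P> can be evaluated from the factors alone.
print(np.isclose(np.sum(C * P), np.sum((C.T @ Q) / g * R)))
```

The marginal check follows from the constraints on the factors: P 1 = Q diag(1/g) R^T 1 = Q diag(1/g) g = Q 1 = a, and symmetrically for b.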

Parameters
  • rank (int) – maximum rank of the coupling used to minimize the linear OT problem.

  • gamma (float) – inverse of the gradient step size used by mirror descent.

  • epsilon (float) – entropic regularization added on top of low-rank problem.

  • init_type (Literal[‘random’, ‘rank_2’]) – initialization scheme for the low-rank factors; ‘random’ by default.

  • lse_mode (bool) – whether to run computations in LSE or kernel mode. Currently, only lse_mode=True is implemented.

  • threshold (float) – convergence threshold, used to quantify whether two successive evaluations of the objective are (relatively) close enough to terminate.

  • norm_error (int) – norm used to quantify feasibility (deviation from the marginals).

  • inner_iterations (int) – number of inner iterations carried out before progress is re-evaluated.

  • min_iterations (int) – minimum number of iterations before the objective is evaluated.

  • max_iterations (int) – maximum number of iterations allowed.

  • use_danskin (bool) – whether to use Danskin's theorem to evaluate the gradient of the objective w.r.t. input parameters. Only True is currently supported.

  • implicit_diff (bool) – whether to use implicit differentiation. Not currently implemented.

  • jit (bool) – whether to jit the iterations loop (True by default).

  • rng_key (int) – seed of the random number generator used to initialize the LR factors.

  • kwargs_dys (Optional[Mapping[str, Any]]) – keyword arguments passed on to dysktra_update().
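
The roles of threshold and norm_error can be illustrated with a short sketch. The helpers has_converged and marginal_error below are hypothetical, and the exact relative-change formula and the sum of the two marginal errors are assumptions based on the parameter descriptions above, not the solver's internal code.

```python
import numpy as np

def has_converged(prev_cost, cost, threshold=1e-3):
    """Hypothetical helper: True when two successive objective values are
    relatively close, i.e. |prev - cost| / |prev| < threshold."""
    return abs(prev_cost - cost) / max(abs(prev_cost), 1e-12) < threshold

def marginal_error(P, a, b, norm_error=1):
    """Hypothetical helper: deviation of a coupling's marginals from the
    targets a and b, measured with the p-norm p = norm_error."""
    err_a = np.linalg.norm(P.sum(axis=1) - a, ord=norm_error)
    err_b = np.linalg.norm(P.sum(axis=0) - b, ord=norm_error)
    return err_a + err_b

# A made-up sequence of objective values, one per evaluation.
costs = [1.00, 0.60, 0.52, 0.5199, 0.51989]
stop = next(i for i in range(1, len(costs)) if has_converged(costs[i - 1], costs[i]))
print(stop)  # 3, since |0.52 - 0.5199| / 0.52 < 1e-3

a = np.array([0.5, 0.5])
b = np.array([0.25, 0.75])
# The independent coupling a b^T satisfies both marginal constraints.
print(marginal_error(np.outer(a, b), a, b))
```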

Methods

dysktra_update(c_q, c_r, h, ot_prob, state, ...)

Return type: Tuple[ndarray, ndarray, ndarray]

init_state(ot_prob, init)

Return the initial state of the loop.

kernel_step(ot_prob, state, iteration)

LR Sinkhorn multiplicative update.

lr_costs(ot_prob, state, iteration)

Return type: Tuple[ndarray, ndarray, ndarray]

lse_step(ot_prob, state, iteration)

LR Sinkhorn LSE update.

one_iteration(ot_prob, state, iteration, ...)

Carry out one LR Sinkhorn iteration.

output_from_state(ot_prob, state)

Create an output from a loop state.

recompute_couplings(f1, g1, c_q, f2, g2, c_r, h)

Return type: Tuple[ndarray, ndarray, ndarray]
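
kernel_step and lse_step correspond to the two lse_mode settings. As a generic illustration only (the standard full-rank Sinkhorn row update, not this class's LR update), the sketch below writes the same scaling step both ways: multiplicatively on u = a / (K v) with K = exp(-C / epsilon), and in the log domain on the potential f = epsilon * log(u). The helper names are hypothetical. With a small epsilon such as this class's default of 1e-4, the kernel underflows to zero, which is a common reason to prefer LSE mode.

```python
import numpy as np

def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))

def kernel_row_update(C, a, v, eps):
    """Kernel mode (hypothetical helper): u = a / (K v), K = exp(-C / eps)."""
    K = np.exp(-C / eps)
    return a / (K @ v)

def lse_row_update(C, a, g, eps):
    """LSE mode (hypothetical helper): the same update on f = eps * log(u),
    given the column potential g = eps * log(v)."""
    return eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)

rng = np.random.default_rng(0)
C = rng.uniform(size=(3, 4))
a = np.full(3, 1 / 3)
v = np.ones(4)  # initial column scaling, i.e. g = 0

# With a moderate epsilon, both modes agree: f == eps * log(u).
eps = 0.5
u = kernel_row_update(C, a, v, eps)
f = lse_row_update(C, a, eps * np.log(v), eps)
print(np.allclose(eps * np.log(u), f))

# With a small epsilon, exp(-C / eps) underflows to 0 and u blows up,
# while the log-domain potential stays finite.
eps = 1e-4
C_far = C + 1.0  # every cost >= 1
with np.errstate(divide="ignore"):
    u_small = kernel_row_update(C_far, a, v, eps)
f_small = lse_row_update(C_far, a, eps * np.log(v), eps)
print(np.all(np.isinf(u_small)), np.all(np.isfinite(f_small)))
```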

Attributes

norm_error

Return type: Tuple[int]

outer_iterations

Upper bound on the number of times inner_iterations are carried out.
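
A minimal sketch of how such a bound can be derived from the constructor arguments, assuming it is a ceiling division of max_iterations by inner_iterations; the standalone function below is hypothetical and only mirrors the attribute's description.

```python
import math

def outer_iterations(max_iterations: int, inner_iterations: int) -> int:
    """Hypothetical helper: number of blocks of `inner_iterations` needed to
    cover `max_iterations` (assumed to be a ceiling division)."""
    return math.ceil(max_iterations / inner_iterations)

print(outer_iterations(2000, 1))   # 2000, with this class's defaults
print(outer_iterations(2000, 30))  # 67, the last block is partial
```

A fixed bound like this is useful under jit, where arrays that track progress (for instance one objective value per outer iteration) need a size known in advance.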