ott.solvers.linear.sinkhorn_lr.LRSinkhorn


class ott.solvers.linear.sinkhorn_lr.LRSinkhorn(rank, gamma=10.0, gamma_rescale=True, epsilon=0.0, initializer='random', lse_mode=True, inner_iterations=10, use_danskin=True, kwargs_dys=None, kwargs_init=None, progress_fn=None, **kwargs)

Low-rank Sinkhorn solver for linear regularized optimal transport (OT) problems.

The algorithm is described in [Scetbon et al., 2021] and the implementation contained here is adapted from LOT.

The algorithm minimizes a non-convex problem, so both initialization and convergence monitoring require special care. Convergence is assessed by comparing successive evaluations of the objective.
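The idea of monitoring successive objective values can be sketched with a toy problem. This is a minimal illustration, not the solver's actual stopping rule: the helper names `has_converged` and `run_until_stable`, the relative-change criterion, and the `threshold` value are assumptions made for the example, and a simple gradient step on a scalar function stands in for the low-rank update.

```python
def has_converged(costs, threshold=1e-3):
    """True once the last two objective values are close in relative terms."""
    if len(costs) < 2:
        return False
    prev, curr = costs[-2], costs[-1]
    return abs(curr - prev) <= threshold * max(abs(prev), 1e-12)


def run_until_stable(step, objective, x0, inner_iterations=10,
                     outer_iterations=200, threshold=1e-3):
    """Run `inner_iterations` updates, then re-evaluate the objective."""
    x = x0
    costs = [objective(x)]
    for _ in range(outer_iterations):
        for _ in range(inner_iterations):
            x = step(x)
        costs.append(objective(x))
        if has_converged(costs, threshold):
            return x, costs, True
    return x, costs, False


# Toy stand-in: gradient descent on f(x) = (x - 3)^2 + 1.
x_final, costs, converged = run_until_stable(
    step=lambda x: x - 0.1 * 2.0 * (x - 3.0),
    objective=lambda x: (x - 3.0) ** 2 + 1.0,
    x0=0.0,
)
print(converged, round(x_final, 3), len(costs))
```

Because the real problem is non-convex, a stable objective only indicates that the iterates have stopped moving, not that a global minimum was found, which is one reason the choice of initializer matters.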

Parameters:
  • rank (int) – Rank constraint imposed on the coupling that minimizes the linear OT problem.

  • gamma (float) – The (inverse of) gradient step size used by mirror descent.

  • gamma_rescale (bool) – Whether to rescale \(\gamma\) every iteration as described in [Scetbon and Cuturi, 2022].

  • epsilon (float) – Entropic regularization added on top of the low-rank problem.

  • initializer (Union[Literal['random', 'rank2', 'k-means', 'generalized-k-means'], LRInitializer]) – How to initialize the \(Q\), \(R\) and \(g\) factors.

  • lse_mode (bool) – Whether to run computations in LSE or kernel mode.

  • inner_iterations (int) – Number of inner iterations used by the algorithm before re-evaluating progress.

  • use_danskin (bool) – Whether to use Danskin's theorem to evaluate the gradient of the objective w.r.t. input parameters. Only True is currently handled.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, LRSinkhornState]], None]]) – Callback function called during the Sinkhorn iterations, so that the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs_dys (Optional[Mapping[str, Any]]) – Keyword arguments passed to dykstra_update_lse(), dykstra_update_kernel() or one of the functions defined in ott.solvers.linear, depending on whether the problem is balanced and on the lse_mode.

  • kwargs_init (Optional[Mapping[str, Any]]) – Keyword arguments for LRInitializer.

  • kwargs (Any) – Keyword arguments for Sinkhorn.
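To make the roles of the \(Q\), \(R\) and \(g\) factors concrete, the sketch below builds a coupling of the form \(P = Q \operatorname{diag}(1/g) R^\top\), the low-rank parameterization used in [Scetbon et al., 2021]. It uses plain NumPy rather than the library, and the helper `apply_lr_coupling` and the trivially feasible factors (outer products with the marginals) are assumptions chosen only to keep the example short. A solver would produce non-trivial factors.

```python
import numpy as np


def apply_lr_coupling(q, r, g, vec):
    """Compute (Q diag(1/g) R^T) @ vec without forming the n x m coupling."""
    return q @ ((r.T @ vec) / g)


rng = np.random.default_rng(0)
n, m, rank = 6, 5, 3

a = np.full(n, 1.0 / n)          # source marginal
b = np.full(m, 1.0 / m)          # target marginal
g = rng.random(rank) + 0.1
g /= g.sum()                     # inner marginal, a point on the simplex

# Trivially feasible factors: Q has marginals (a, g), R has marginals (b, g).
q = np.outer(a, g)               # shape (n, rank)
r = np.outer(b, g)               # shape (m, rank)

# Dense coupling, materialized here only for checking.
coupling = q @ np.diag(1.0 / g) @ r.T

vec = rng.random(m)
print(np.allclose(coupling.sum(axis=1), a))                      # row marginal
print(np.allclose(coupling.sum(axis=0), b))                      # column marginal
print(np.allclose(coupling @ vec, apply_lr_coupling(q, r, g, vec)))
```

Applying the coupling through its factors costs on the order of (n + m) * rank operations instead of n * m, which is the main practical benefit of the rank constraint.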

Methods

create_initializer(prob)

Create a low-rank Sinkhorn initializer.

dykstra_update_kernel(k_q, k_r, k_g, gamma, ...)

Run Dykstra's algorithm.

dykstra_update_lse(c_q, c_r, h, gamma, ot_prob)

Run Dykstra's algorithm.

init_state(ot_prob, init)

Return the initial state of the loop.

kernel_step(ot_prob, state, iteration)

LR Sinkhorn Kernel update.

lse_step(ot_prob, state, iteration)

LR Sinkhorn LSE update.

one_iteration(ot_prob, state, iteration, ...)

Carries out one low-rank Sinkhorn iteration.

output_from_state(ot_prob, state)

Create an output from a loop state.
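The kernel and LSE variants of the update differ in whether quantities are manipulated directly or in the log domain. The following NumPy sketch illustrates that distinction on a single kernel application. It is a generic illustration under stated assumptions, not the code behind kernel_step() or lse_step(): the helpers `logsumexp`, `kernel_apply` and `lse_apply`, and the cost matrix, are hypothetical.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(
        np.sum(np.exp(x - x_max), axis=axis)
    )


def kernel_apply(cost, eps, v):
    """Kernel mode: form K = exp(-C / eps) explicitly and multiply."""
    return np.exp(-cost / eps) @ v


def lse_apply(cost, eps, log_v):
    """LSE mode: the same quantity, computed entirely in the log domain."""
    return logsumexp(-cost / eps + log_v[None, :], axis=1)


rng = np.random.default_rng(0)
cost = rng.random((4, 3)) + 0.5   # all entries at least 0.5
v = rng.random(3) + 0.1

# With a moderate eps both modes agree.
agree = np.allclose(np.log(kernel_apply(cost, 1.0, v)),
                    lse_apply(cost, 1.0, np.log(v)))

# With a tiny eps the kernel underflows to zero, the log-domain value does not.
tiny_kernel = kernel_apply(cost, 1e-4, v)
tiny_lse = lse_apply(cost, 1e-4, np.log(v))
print(agree, tiny_kernel, np.isfinite(tiny_lse).all())
```

Kernel mode is typically cheaper per iteration, while the log-domain form is more robust for small regularization values.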

Attributes

norm_error

Powers used to compute the p-norm between marginal/target.

outer_iterations

Upper bound on number of times inner_iterations are carried out.
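As a rough illustration of what a p-norm between a marginal and its target measures, the sketch below compares the row or column sums of a small coupling with a prescribed marginal. The function `marginal_error` and its plain p-norm formula are assumptions made for this example. The library's own error computation may normalize or combine terms differently.

```python
import numpy as np


def marginal_error(coupling, target, axis, norm_error=1):
    """p-norm of the gap between a marginal of `coupling` and `target`.

    axis=1 sums over columns (row marginal), axis=0 sums over rows.
    """
    marginal = coupling.sum(axis=axis)
    gap = np.abs(marginal - target) ** norm_error
    return np.sum(gap) ** (1.0 / norm_error)


coupling = np.array([[0.3, 0.2],
                     [0.1, 0.4]])
uniform = np.array([0.5, 0.5])

row_err = marginal_error(coupling, uniform, axis=1)                 # rows match
col_err_l1 = marginal_error(coupling, uniform, axis=0, norm_error=1)
col_err_l2 = marginal_error(coupling, uniform, axis=0, norm_error=2)
print(row_err, col_err_l1, col_err_l2)
```

Here the row sums already equal the target, while the column sums (0.4, 0.6) miss it by 0.1 each, so the error depends on the chosen power.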