ott.solvers.linear.sinkhorn_lr.LRSinkhorn
- class ott.solvers.linear.sinkhorn_lr.LRSinkhorn(rank, gamma=10.0, gamma_rescale=True, epsilon=0.0, initializer=None, lse_mode=True, inner_iterations=10, use_danskin=True, kwargs_dys=None, progress_fn=None, **kwargs)
Low-Rank Sinkhorn solver for linear reg-OT problems.
The algorithm tries to minimize the low-rank optimal transport problem, a constrained formulation of the Kantorovich problem in which the coupling matrix is constrained to have low rank (at most rank).
This problem is non-convex, so any algorithm that tries to solve it requires special attention to initialization and to the control of convergence. Convergence is assessed by comparing successive evaluations of the objective, while initializers are instances of the LRInitializer class. The algorithm is described in [Scetbon et al., 2021] and the implementation contained here is adapted from LOT.
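To make the low-rank constraint concrete: in [Scetbon et al., 2021] the coupling is parameterized as P = Q diag(1/g) Rᵀ, where Q (n × r) has row sums a, R (m × r) has row sums b, and both share the column sums g. The NumPy sketch below is an illustration only; the helper names are hypothetical and not part of the OTT API. It builds feasible factors with a plain Sinkhorn scaling and checks that P then has marginals a and b and rank at most r.

```python
import numpy as np

def scale_to_marginals(kernel, row_target, col_target, n_iter=500):
    """Hypothetical helper: Sinkhorn-scale a positive matrix to given row/column sums."""
    mat = kernel.copy()
    for _ in range(n_iter):
        mat *= (row_target / mat.sum(axis=1))[:, None]  # match row sums
        mat *= (col_target / mat.sum(axis=0))[None, :]  # match column sums (exact after this step)
    return mat

def low_rank_coupling(q, r, g):
    """P = Q diag(1/g) R^T, the factorization used by low-rank OT."""
    return (q / g[None, :]) @ r.T

rng = np.random.default_rng(0)
n, m, rank = 6, 5, 2
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
g = np.array([0.3, 0.7])  # inner marginal; total mass 1, like a and b

q = scale_to_marginals(rng.uniform(0.5, 1.5, (n, rank)), a, g)
r = scale_to_marginals(rng.uniform(0.5, 1.5, (m, rank)), b, g)
p = low_rank_coupling(q, r, g)

# P 1 = Q diag(1/g) (R^T 1) = Q diag(1/g) g = Q 1 = a, and symmetrically for b.
print(np.allclose(p.sum(axis=1), a), np.allclose(p.sum(axis=0), b))
print(np.linalg.matrix_rank(p) <= rank)
```

Because the inner marginal g cancels in P 1 and Pᵀ 1, any factors satisfying the three marginal constraints yield a valid coupling, which is why the solver can optimize Q, R and g directly.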
- Parameters:
rank (int) – Rank constraint on the coupling used to minimize the linear OT problem.
gamma (float) – The (inverse of) gradient step size used by mirror descent.
gamma_rescale (bool) – Whether to rescale \(\gamma\) every iteration as described in [Scetbon and Cuturi, 2022].
epsilon (float) – Entropic regularization added on top of the low-rank problem.
initializer (Optional[LRInitializer]) – How to initialize the \(Q\), \(R\) and \(g\) factors.
lse_mode (bool) – Whether to run computations in LSE or kernel mode.
inner_iterations (int) – Number of inner iterations used by the algorithm before re-evaluating progress.
use_danskin (bool) – Use Danskin's theorem to evaluate the gradient of the objective w.r.t. input parameters. Only True is handled at the moment.
progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, LRSinkhornState]], None]]) – Callback function called during the Sinkhorn iterations, so the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.
kwargs_dys (Optional[Mapping[str, Any]]) – Keyword arguments passed to dykstra_update_lse(), dykstra_update_kernel() or one of the functions defined in ott.solvers.linear, depending on whether the problem is balanced and on lse_mode.
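The lse_mode flag chooses between a kernel-space update and a log-domain one. As a generic illustration (a plain entropic Sinkhorn scaling, not the low-rank solver's actual update, which acts on the Q, R and g factors), the sketch below computes the scaling u = a / (K v) with K = exp(-C/ε) once with the kernel and once with a log-sum-exp. All function names are hypothetical.

```python
import numpy as np

def logsumexp_rows(z):
    """Numerically stable log(sum(exp(z), axis=1))."""
    zmax = z.max(axis=1, keepdims=True)
    return (zmax + np.log(np.exp(z - zmax).sum(axis=1, keepdims=True)))[:, 0]

def scaling_kernel(cost, a, v, eps):
    # Kernel mode: materialize K = exp(-C / eps), then divide.
    k = np.exp(-cost / eps)
    return a / (k @ v)

def log_scaling_lse(cost, a, log_v, eps):
    # LSE mode: log u = log a - logsumexp_j(-C_ij / eps + log v_j).
    return np.log(a) - logsumexp_rows(-cost / eps + log_v[None, :])

rng = np.random.default_rng(0)
cost = rng.uniform(0.0, 1.0, (4, 3))
a = np.full(4, 0.25)
v = rng.uniform(0.5, 1.5, 3)

# For a moderate eps both modes agree.
u_kernel = scaling_kernel(cost, a, v, eps=0.5)
u_lse = np.exp(log_scaling_lse(cost, a, np.log(v), eps=0.5))
print(np.allclose(u_kernel, u_lse))

# For a tiny eps, exp(-C / eps) underflows to 0 and the kernel update divides by 0,
# while the log-domain update stays finite.
with np.errstate(divide="ignore"):
    print(np.isfinite(scaling_kernel(cost + 1.0, a, v, eps=1e-3)).all())
print(np.isfinite(log_scaling_lse(cost + 1.0, a, np.log(v), eps=1e-3)).all())
```

Kernel mode is cheaper per step (a matrix-vector product), while LSE mode pays for the log-sum-exp to remain stable at small regularization.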
Methods
dykstra_update_kernel(k_q, k_r, k_g, gamma, ...) – Run Dykstra's algorithm.
dykstra_update_lse(c_q, c_r, h, gamma, ot_prob) – Run Dykstra's algorithm.
init_state(ot_prob, init) – Return the initial state of the loop.
kernel_step(ot_prob, state, iteration) – LR Sinkhorn Kernel update.
lse_step(ot_prob, state, iteration) – LR Sinkhorn LSE update.
one_iteration(ot_prob, state, iteration, ...) – Carries out one low-rank Sinkhorn iteration.
output_from_state(ot_prob, state) – Create an output from a loop state.
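The dykstra_update_* methods project the mirror-descent kernels for Q, R and g onto the set where Q 1 = a, R 1 = b and Qᵀ1 = Rᵀ1 = g. The NumPy sketch below is a simplified, hypothetical stand-in (alternating_kl_projection is not an OTT function): it assumes a balanced problem and drops the lower bound on g used in [Scetbon et al., 2021], so every constraint is affine and plain iterative KL (Bregman) projections suffice. Dykstra's correction terms, which matter once a non-affine constraint such as that lower bound is present, are therefore omitted.

```python
import numpy as np

def alternating_kl_projection(k_q, k_r, k_g, a, b, n_iter=1000):
    """Hypothetical sketch: KL-project (k_q, k_r, k_g) onto
    {Q 1 = a, R 1 = b, Q^T 1 = R^T 1 = g} by alternating projections.
    No lower bound on g, hence no Dykstra correction terms."""
    q, r, g = k_q.copy(), k_r.copy(), k_g.copy()
    for _ in range(n_iter):
        # Projection 1: rescale rows so that Q 1 = a and R 1 = b.
        q *= (a / q.sum(axis=1))[:, None]
        r *= (b / r.sum(axis=1))[:, None]
        # Projection 2: choose g jointly with Q and R. Minimizing the summed KL
        # divergences gives g = (g_prev * colsum(Q) * colsum(R)) ** (1/3),
        # then each factor's columns are rescaled to sum to g.
        c_q, c_r = q.sum(axis=0), r.sum(axis=0)
        g = (g * c_q * c_r) ** (1.0 / 3.0)
        q *= (g / c_q)[None, :]
        r *= (g / c_r)[None, :]
    return q, r, g

rng = np.random.default_rng(0)
n, m, rank = 5, 4, 3
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
q, r, g = alternating_kl_projection(
    rng.uniform(0.5, 1.5, (n, rank)),
    rng.uniform(0.5, 1.5, (m, rank)),
    rng.uniform(0.5, 1.5, rank),
    a, b,
)
print(np.allclose(q.sum(axis=1), a), np.allclose(r.sum(axis=1), b))
print(np.allclose(q.sum(axis=0), g), np.allclose(r.sum(axis=0), g))
```

The cube root comes from setting the derivative of the three KL terms with respect to each g_j to zero: log(g_j / c_q,j) + log(g_j / c_r,j) + log(g_j / g_prev,j) = 0.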
Attributes
Powers used to compute the p-norm between the marginals and their targets.
Upper bound on the number of times inner_iterations are carried out.
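The two attributes above are easiest to read as formulas: for a power p, the marginal error is (Σ_i |(P 1)_i − a_i|^p)^(1/p), and bounding the number of outer loops amounts to dividing an iteration budget by inner_iterations and rounding up. The sketch assumes a budget named max_iterations, which is not documented in this section, and both helper functions are hypothetical illustrations rather than the library's code path.

```python
import math
import numpy as np

def marginal_errors(coupling, target, powers=(1,)):
    """Hypothetical helper: p-norm of (row marginal - target) for each power p."""
    gap = coupling.sum(axis=1) - target
    return tuple(float(np.sum(np.abs(gap) ** p) ** (1.0 / p)) for p in powers)

def outer_iteration_bound(max_iterations, inner_iterations):
    """Hypothetical helper: blocks of inner_iterations needed to cover the budget."""
    return math.ceil(max_iterations / inner_iterations)

p = np.array([[0.2, 0.1], [0.1, 0.6]])  # row marginal is [0.3, 0.7]
target = np.array([0.25, 0.75])
print(marginal_errors(p, target, powers=(1, 2)))
print(outer_iteration_bound(max_iterations=2000, inner_iterations=10))
print(outer_iteration_bound(max_iterations=105, inner_iterations=10))
```

Rounding up ensures that a budget not divisible by inner_iterations still gets a final (partial) block, which is what makes the value an upper bound.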