ott.solvers.linear.sinkhorn_lr.LRSinkhorn
- class ott.solvers.linear.sinkhorn_lr.LRSinkhorn(rank, gamma=10.0, gamma_rescale=True, epsilon=0.0, initializer='random', lse_mode=True, inner_iterations=10, use_danskin=True, implicit_diff=False, kwargs_dys=None, kwargs_init=None, progress_fn=None, **kwargs)
A Low-Rank Sinkhorn solver for linear reg-OT problems.
The algorithm is described in [Scetbon et al., 2021]; the implementation here is adapted from LOT.
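As a point of reference, [Scetbon et al., 2021] parametrize a coupling of rank at most r as \(P = Q\,\mathrm{diag}(1/g)\,R^\top\), where \(Q\) has marginals \((a, g)\), \(R\) has marginals \((b, g)\), and \(g\) is a probability vector of size r. The NumPy sketch below is not the OTT API; sinkhorn_scale and low_rank_coupling are hypothetical helpers. It builds feasible factors by Sinkhorn scaling of random positive matrices, then checks that \(P\) has the prescribed marginals and rank, and that the linear cost \(\langle C, P\rangle\) can be evaluated from the factors without forming \(P\).

```python
import numpy as np


def sinkhorn_scale(k: np.ndarray, row: np.ndarray, col: np.ndarray, n_iter: int = 500) -> np.ndarray:
    """Scale a positive matrix so that its row/column sums approach `row`/`col`."""
    u = np.ones_like(row)
    v = np.ones_like(col)
    for _ in range(n_iter):
        u = row / (k @ v)
        v = col / (k.T @ u)  # after this update, column sums equal `col`
    return u[:, None] * k * v[None, :]


def low_rank_coupling(q: np.ndarray, r: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Assemble P = Q diag(1/g) R^T (hypothetical helper)."""
    return (q / g) @ r.T


rng = np.random.default_rng(0)
n, m, rank = 6, 5, 2
x, y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean cost

a = np.full(n, 1.0 / n)   # source marginal (balanced: sums to 1)
b = np.full(m, 1.0 / m)   # target marginal
g = np.array([0.4, 0.6])  # inner marginal shared by Q and R

q = sinkhorn_scale(rng.random((n, rank)) + 0.1, a, g)  # Q with marginals (a, g)
r = sinkhorn_scale(rng.random((m, rank)) + 0.1, b, g)  # R with marginals (b, g)
p = low_rank_coupling(q, r, g)

# The linear cost <C, P>, once densely and once from the factors only.
dense_cost = float(np.sum(cost * p))
factored_cost = float(np.sum((cost @ r) * q / g))
print(dense_cost, factored_cost)
```

The factored evaluation costs O((n + m) r) memory for P, which is the point of the low-rank constraint when \(C\) itself is cheap to multiply.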
The algorithm minimizes a non-convex objective. It therefore requires special care with initialization and convergence. Convergence is assessed by comparing successive evaluations of the objective. The algorithm is only provided for the balanced case.
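The convergence rule stated above (re-evaluate the objective periodically and stop once successive values agree) can be sketched generically. In this sketch, the helper run_with_objective_checks, the relative-change criterion, threshold and max_iterations are assumptions for illustration; inner_iterations plays the role of the solver parameter of the same name. The toy step is plain gradient descent on a 1-D quadratic, not a low-rank Sinkhorn update.

```python
from typing import Callable, List, Tuple


def run_with_objective_checks(
    step: Callable[[float], float],
    objective: Callable[[float], float],
    x0: float,
    inner_iterations: int = 10,
    max_iterations: int = 1000,
    threshold: float = 1e-6,
) -> Tuple[float, List[float], bool]:
    """Apply `step` repeatedly, evaluating `objective` every `inner_iterations` steps.

    Returns the final iterate, the recorded objective values, and whether the
    relative change between two successive evaluations dropped below `threshold`.
    """
    x = x0
    costs: List[float] = []
    for it in range(1, max_iterations + 1):
        x = step(x)
        if it % inner_iterations == 0:
            costs.append(objective(x))
            if len(costs) >= 2:
                prev, curr = costs[-2], costs[-1]
                if abs(prev - curr) <= threshold * abs(prev):
                    return x, costs, True
    return x, costs, False


# Toy problem (not low-rank OT): minimize f(x) = (x - 3)^2 + 1 by gradient descent.
def f(x: float) -> float:
    return (x - 3.0) ** 2 + 1.0


def gd_step(x: float) -> float:
    return x - 0.1 * 2.0 * (x - 3.0)


x_final, costs, converged = run_with_objective_checks(gd_step, f, x0=0.0)
print(converged, len(costs), x_final)
```

Checking only every inner_iterations steps keeps the objective evaluation, which can be as expensive as an update, off the hot path.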
- Parameters:
  - rank (int) – Rank constraint on the coupling used to minimize the linear OT problem.
  - gamma (float) – The (inverse of the) gradient step size used by mirror descent.
  - gamma_rescale (bool) – Whether to rescale \(\gamma\) at every iteration, as described in [Scetbon and Cuturi, 2022].
  - epsilon (float) – Entropic regularization added on top of the low-rank problem.
  - initializer (Union[Literal['random', 'rank2', 'k-means', 'generalized-k-means'], LRInitializer, None]) – How to initialize the \(Q\), \(R\) and \(g\) factors. Valid options are ‘random’, ‘rank2’, ‘k-means’, and ‘generalized-k-means’. If None, KMeansInitializer is used when the linear problem’s geometry is PointCloud or LRCGeometry; otherwise, RandomInitializer is used.
  - lse_mode (bool) – Whether to run computations in lse or kernel mode. At the moment, only lse_mode = True is implemented.
  - inner_iterations (int) – Number of inner iterations run by the algorithm before re-evaluating progress.
  - use_danskin (bool) – Whether to use Danskin’s theorem to evaluate the gradient of the objective w.r.t. input parameters. Only True is handled at the moment.
  - implicit_diff (bool) – Whether to use implicit differentiation. Currently, only implicit_diff = False is implemented.
  - progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, LRSinkhornState]], None]]) – Callback function called during the Sinkhorn iterations, so that the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.
  - kwargs_dys (Optional[Mapping[str, Any]]) – Keyword arguments passed to dykstra_update().
  - kwargs_init (Optional[Mapping[str, Any]]) – Keyword arguments for LRInitializer.
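To make the role of gamma concrete: for the linear objective \(\langle C, Q\,\mathrm{diag}(1/g)\,R^\top\rangle\), the gradients are \(\nabla_Q = C R\,\mathrm{diag}(1/g)\), \(\nabla_R = C^\top Q\,\mathrm{diag}(1/g)\) and \(\nabla_g = -\mathrm{diag}(Q^\top C R)/g^2\) (elementwise). A KL mirror-descent step with step size \(1/\gamma\) multiplies each factor by \(\exp(-\nabla/\gamma)\); the resulting kernels must then be projected back onto the marginal constraints. This sketch ignores epsilon and gamma_rescale, and its helper names are hypothetical; OTT's actual update may scale these terms differently.

```python
import numpy as np


def lr_objective(c, q, r, g):
    """<C, Q diag(1/g) R^T>, evaluated without forming the n x m coupling."""
    return float(np.sum((c @ r) * q / g))


def lr_gradients(c, q, r, g):
    """Gradients of lr_objective w.r.t. Q, R and g."""
    cr = c @ r
    grad_q = cr / g                          # C R diag(1/g)
    grad_r = (c.T @ q) / g                   # C^T Q diag(1/g)
    grad_g = -np.sum(q * cr, axis=0) / g**2  # -diag(Q^T C R) / g^2
    return grad_q, grad_r, grad_g


def mirror_descent_kernels(q, r, g, grads, gamma):
    """KL mirror-descent step with step size 1/gamma, before any projection."""
    grad_q, grad_r, grad_g = grads
    return (
        q * np.exp(-grad_q / gamma),
        r * np.exp(-grad_r / gamma),
        g * np.exp(-grad_g / gamma),
    )


rng = np.random.default_rng(0)
c = rng.random((4, 3))
q, r, g = rng.random((4, 2)), rng.random((3, 2)), rng.random(2) + 0.5
grads = lr_gradients(c, q, r, g)
k_q, k_r, k_g = mirror_descent_kernels(q, r, g, grads, gamma=10.0)
```

A larger gamma means a smaller step: the kernels stay closer to the current factors.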
Methods
  - create_initializer(prob) – Create a low-rank Sinkhorn initializer.
  - dykstra_update(c_q, c_r, h, gamma, ot_prob) – Run Dykstra's algorithm.
  - init_state(ot_prob, init) – Return the initial state of the loop.
  - kernel_step(ot_prob, state, iteration) – Not implemented.
  - lse_step(ot_prob, state, iteration) – LR Sinkhorn LSE update.
  - one_iteration(ot_prob, state, iteration, ...) – Carries out one low-rank Sinkhorn iteration.
  - output_from_state(ot_prob, state) – Create an output from a loop state.
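Dykstra's algorithm projects a point onto an intersection of convex sets by alternating projections plus correction terms. The sketch below shows the generic Euclidean version on two simple 2-D sets (a disk and a half-plane); it is only an illustration of the technique, not OTT's dykstra_update(), which works with KL projections onto the low-rank marginal constraints. All helper names are hypothetical. Without the corrections, plain alternating projections land on a feasible point that is not the closest one.

```python
import numpy as np


def project_disk(x: np.ndarray, radius: float = 1.0) -> np.ndarray:
    """Euclidean projection onto {x : ||x|| <= radius}."""
    nrm = np.linalg.norm(x)
    return x if nrm <= radius else x * (radius / nrm)


def project_halfplane(x: np.ndarray, bound: float = 0.5) -> np.ndarray:
    """Euclidean projection onto {x : x[0] <= bound}."""
    y = x.copy()
    y[0] = min(y[0], bound)
    return y


def dykstra(z: np.ndarray, proj_a, proj_b, n_iter: int = 1000) -> np.ndarray:
    """Project z onto the intersection of sets A and B (Dykstra's algorithm)."""
    x = z.astype(float)
    p = np.zeros_like(x)  # correction term for A
    q = np.zeros_like(x)  # correction term for B
    for _ in range(n_iter):
        y = proj_a(x + p)
        p = x + p - y
        x_new = proj_b(y + q)
        q = y + q - x_new
        x = x_new
    return x


z = np.array([2.0, 2.0])
x_star = dykstra(z, project_disk, project_halfplane)
x_naive = project_halfplane(project_disk(z))  # one round of plain alternating projections
print(x_star, x_naive)
```

Here the exact projection is \((0.5, \sqrt{3}/2)\), while plain alternating projections stop at roughly \((0.5, 0.707)\).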
Attributes
  - Whether entropy regularization is used.
  - Powers used to compute the p-norm between marginal/target.
  - Upper bound on the number of times inner_iterations are carried out.
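The last two attributes reduce to simple arithmetic. In the sketch below, the helper names are hypothetical, and both the ceiling rule and the max_iterations argument (assumed to arrive through **kwargs) are assumptions rather than details stated on this page.

```python
import math
from typing import List, Sequence

import numpy as np


def marginal_errors(marginal: np.ndarray, target: np.ndarray,
                    powers: Sequence[float] = (1.0,)) -> List[float]:
    """p-norm of (marginal - target) for each power p."""
    diff = np.abs(marginal - target)
    return [float(np.sum(diff**p) ** (1.0 / p)) for p in powers]


def outer_iterations_bound(max_iterations: int, inner_iterations: int) -> int:
    """Assumed bound: number of inner_iterations blocks within max_iterations."""
    return math.ceil(max_iterations / inner_iterations)


coupling = np.array([[0.2, 0.1], [0.3, 0.4]])
row_marginal = coupling.sum(axis=1)  # [0.3, 0.7]
errs = marginal_errors(row_marginal, np.array([0.5, 0.5]), powers=(1.0, 2.0))
print(errs, outer_iterations_bound(2000, 10))
```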