class ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein(rank, gamma=10.0, gamma_rescale=True, epsilon=0.0, initializer='random', lse_mode=True, inner_iterations=10, use_danskin=True, implicit_diff=False, kwargs_dys=None, kwargs_init=None, progress_fn=None, **kwargs)

Low-rank Gromov-Wasserstein solver [Scetbon et al., 2023].

The algorithm minimizes a non-convex problem. It therefore requires particular care with initialization and convergence. Convergence is assessed by comparing successive evaluations of the objective.
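To make the convergence criterion concrete, here is a minimal, library-independent sketch of a loop that re-evaluates an objective after every block of inner iterations and stops once two successive evaluations differ by less than a tolerance. The names run_until_converged, step, objective, threshold and max_outer are hypothetical and not part of OTT, and the toy gradient-descent update merely stands in for the actual low-rank GW iteration.

```python
from typing import Callable, List, Tuple


def run_until_converged(
    step: Callable[[float], float],
    objective: Callable[[float], float],
    x0: float,
    inner_iterations: int = 10,
    max_outer: int = 100,
    threshold: float = 1e-6,
) -> Tuple[float, List[float], bool]:
    """Iterate `step`, checking the objective every `inner_iterations` steps."""
    x = x0
    costs = [objective(x)]
    for _ in range(max_outer):
        for _ in range(inner_iterations):
            x = step(x)
        costs.append(objective(x))
        # Converged when two successive objective evaluations barely differ.
        if abs(costs[-1] - costs[-2]) < threshold:
            return x, costs, True
    return x, costs, False


def toy_objective(x: float) -> float:
    """Non-convex toy objective with two minimizers, x = -1 and x = +1."""
    return x**4 - 2.0 * x**2


def toy_step(x: float) -> float:
    """One gradient-descent step on the toy objective (step size 0.05)."""
    return x - 0.05 * (4.0 * x**3 - 4.0 * x)


# Different initializations end up in different minimizers.
x_pos, costs_pos, ok_pos = run_until_converged(toy_step, toy_objective, x0=0.5)
x_neg, costs_neg, ok_neg = run_until_converged(toy_step, toy_objective, x0=-0.5)
```

Because the toy objective is non-convex, the two starting points reach different minimizers, which is the reason initialization matters for this kind of problem.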


This solver only handles the unbalanced case. The balanced case is implemented in GromovWasserstein and will be unified here in a future release.

  • rank (int) – Rank constraint on the coupling to minimize the linear OT problem

  • gamma (float) – The (inverse of) gradient step size used by mirror descent.

  • gamma_rescale (bool) – Whether to rescale \(\gamma\) every iteration as described in [Scetbon and Cuturi, 2022].

  • epsilon (float) – Entropic regularization added on top of the low-rank problem.

  • initializer (Union[Literal['random', 'rank2', 'k-means', 'generalized-k-means'], LRInitializer]) – How to initialize the \(Q\), \(R\) and \(g\) factors.

  • lse_mode (bool) – Whether to run computations in LSE or kernel mode.

  • inner_iterations (int) – Number of inner iterations used by the algorithm before re-evaluating progress.

  • use_danskin (bool) – Use Danskin's theorem to evaluate the gradient of the objective w.r.t. input parameters. Only True is handled at the moment.

  • implicit_diff (bool) – Whether to use implicit differentiation. Currently, only implicit_diff = False is implemented.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, LRGWState]], None]]) – Callback function that is called during the GW iterations, so the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs_dys (Optional[Mapping[str, Any]]) – Keyword arguments passed to dykstra_update_lse(), dykstra_update_kernel() or one of the functions defined in ott.solvers.linear, depending on the lse_mode.

  • kwargs_init (Optional[Mapping[str, Any]]) – Keyword arguments for LRInitializer.

  • kwargs (Any) – Keyword arguments for Sinkhorn.
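As an illustration of what the \(Q\), \(R\) and \(g\) factors represent, the sketch below assembles a coupling \(P = Q \operatorname{diag}(1/g) R^\top\) from hand-built factors and checks its marginals. It assumes the usual low-rank parameterization of the coupling; the helper lr_coupling and the trivial factors \(Q = a g^\top\), \(R = b g^\top\) are hypothetical and are not one of the initializers listed above.

```python
from typing import List

Matrix = List[List[float]]


def lr_coupling(q: Matrix, r: Matrix, g: List[float]) -> Matrix:
    """Assemble P = Q diag(1/g) R^T from its low-rank factors."""
    return [
        [sum(qi[k] * rj[k] / g[k] for k in range(len(g))) for rj in r]
        for qi in q
    ]


a = [0.2, 0.3, 0.5]            # source marginal (n = 3)
b = [0.25, 0.25, 0.25, 0.25]   # target marginal (m = 4)
g = [0.4, 0.6]                 # inner marginal (rank = 2)

# Trivial factors: rows of Q sum to a, rows of R sum to b,
# and the columns of both sum to g.
q = [[ai * gk for gk in g] for ai in a]
r = [[bj * gk for gk in g] for bj in b]

p = lr_coupling(q, r, g)
row_sums = [sum(row) for row in p]        # should match a
col_sums = [sum(col) for col in zip(*p)]  # should match b
```

The factors store n * rank + m * rank + rank numbers instead of the n * m entries of the full coupling, which is where the savings of a rank constraint come from.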



Create a low-rank GW initializer.

dykstra_update_kernel(k_q, k_r, k_g, gamma, ...)

Run Dykstra's algorithm.

dykstra_update_lse(c_q, c_r, h, gamma, ot_prob)

Run Dykstra's algorithm.

init_state(ot_prob, init)

Return the initial state of the loop.

kernel_step(ot_prob, state, iteration)

Low-rank GW kernel update.

lse_step(ot_prob, state, iteration)

Low-rank GW LSE update.

one_iteration(ot_prob, state, iteration, ...)

Carries out one low-rank GW iteration.

output_from_state(ot_prob, state)

Create an output from a loop state.
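The practical difference between LSE and kernel computations can be illustrated with a small, library-independent sketch. Both hypothetical helpers below evaluate the log of a sum of exp(-cost / epsilon) terms: the kernel variant exponentiates first, while the LSE variant factors out the largest exponent. This is only a generic illustration of the two computation styles, not the code used by lse_step() or kernel_step().

```python
import math
from typing import Sequence


def kernel_mode_log_sum(costs: Sequence[float], epsilon: float) -> float:
    """Exponentiate first, then sum; cheap but can underflow to log(0)."""
    total = sum(math.exp(-c / epsilon) for c in costs)
    return math.log(total) if total > 0.0 else -math.inf


def lse_mode_log_sum(costs: Sequence[float], epsilon: float) -> float:
    """Log-sum-exp with the maximum factored out; stable for tiny epsilon."""
    scaled = [-c / epsilon for c in costs]
    top = max(scaled)
    return top + math.log(sum(math.exp(s - top) for s in scaled))


costs = [1.0, 2.0, 3.0]

# With a moderate scale both variants agree.
gap = abs(kernel_mode_log_sum(costs, 1.0) - lse_mode_log_sum(costs, 1.0))

# With a tiny scale every kernel entry underflows to 0.0,
# while the log-sum-exp variant remains finite.
underflowed = kernel_mode_log_sum(costs, 1e-3)
stable = lse_mode_log_sum(costs, 1e-3)
```

Kernel-style computations reduce to plain products and are typically cheaper, whereas log-domain computations trade some speed for robustness at small regularization scales.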



Powers used to compute the p-norm between marginal/target.


Upper bound on the number of times inner_iterations are carried out.