ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein


class ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein(rank, gamma=10.0, gamma_rescale=True, epsilon=0.0, initializer='random', lse_mode=True, use_danskin=True, implicit_diff=False, inner_iterations=2000, min_iterations=10000, max_iterations=100000, kwargs_dys=None, kwargs_init=None, progress_fn=None, **kwargs)

Low-rank Gromov-Wasserstein solver [Scetbon et al., 2023].

The algorithm minimizes a non-convex objective. It therefore requires special care with initialization and with convergence monitoring. Convergence is assessed by comparing successive evaluations of the objective.
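Because convergence is judged on successive objective values, a stopping rule of this kind compares the most recent evaluations. The sketch below assumes a simple relative-tolerance test between the last two recorded values; the name `has_converged` and the `rtol` default are hypothetical and not part of the solver's API.

```python
import math
from typing import Sequence


def has_converged(costs: Sequence[float], rtol: float = 1e-3) -> bool:
    """Hypothetical stopping rule based on successive objective values.

    Returns True once the last two recorded objective values agree up to the
    relative tolerance `rtol`; fewer than two values never count as converged.
    """
    if len(costs) < 2:
        return False
    previous, current = costs[-2], costs[-1]
    # math.isclose tests |previous - current| <= rtol * max(|previous|, |current|).
    return math.isclose(previous, current, rel_tol=rtol)


# Objective values recorded after each block of inner iterations (made-up numbers).
history = [10.0, 5.0, 4.0, 3.999]
print(has_converged(history[:2]), has_converged(history))  # prints: False True
```

Since the objective is non-convex, a small change between evaluations only signals that progress has stalled, not that a global minimum was reached; this is why the choice of initializer matters.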

Warning

This solver only handles the unbalanced case. The balanced case is implemented in GromovWasserstein and will be unified here in a future release.

Parameters:
  • rank (int) – Rank constraint on the coupling used to minimize the GW problem.

  • gamma (float) – Inverse of the gradient step size used by mirror descent.

  • gamma_rescale (bool) – Whether to rescale \(\gamma\) every iteration as described in [Scetbon and Cuturi, 2022].

  • epsilon (float) – Entropic regularization added on top of the low-rank problem.

  • initializer (Union[Literal['random', 'rank2', 'k-means', 'generalized-k-means'], LRInitializer]) – How to initialize the \(Q\), \(R\) and \(g\) factors.

  • lse_mode (bool) – Whether to run computations in LSE or kernel mode.

  • inner_iterations (int) – Number of inner iterations used by the algorithm before re-evaluating progress.

  • min_iterations (int) – The minimum number of low-rank Sinkhorn iterations carried out before the error is computed and monitored.

  • max_iterations (int) – The maximum number of low-rank Sinkhorn iterations.

  • use_danskin (bool) – Use Danskin's theorem to evaluate the gradient of the objective w.r.t. input parameters. Only True is currently supported.

  • implicit_diff (bool) – Whether to use implicit differentiation. Currently, only implicit_diff = False is implemented.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, LRGWState]], None]]) – Callback function called during the GW iterations, so that the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs_dys (Optional[Mapping[str, Any]]) – Keyword arguments passed to dykstra_update_lse(), dykstra_update_kernel() or one of the functions defined in ott.solvers.linear, depending on lse_mode.

  • kwargs_init (Optional[Mapping[str, Any]]) – Keyword arguments for LRInitializer.

  • kwargs (Any) – Keyword arguments for Sinkhorn.
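The `rank` and `initializer` parameters refer to a factorization of the coupling into \(Q\), \(R\) and \(g\). The NumPy sketch below assumes the parametrization \(P = Q\,\mathrm{diag}(1/g)\,R^\top\), where \(Q\) has marginals \((a, g)\) and \(R\) has marginals \((b, g)\). The helpers `scale_to_marginals` and `low_rank_coupling` are hypothetical, and the random factors merely stand in for what an LRInitializer would produce.

```python
import numpy as np


def scale_to_marginals(k, row, col, n_iter=500):
    """Sinkhorn scaling: rescale positive matrix `k` so that its row sums
    approach `row` and its column sums equal `col` (both sum to 1)."""
    u = np.ones_like(row)
    v = np.ones_like(col)
    for _ in range(n_iter):
        u = row / (k @ v)
        v = col / (k.T @ u)
    return u[:, None] * k * v[None, :]


def low_rank_coupling(q, r, g):
    """Assemble the n x m coupling P = Q diag(1/g) R^T from its factors."""
    return (q / g[None, :]) @ r.T


rng = np.random.default_rng(0)
n, m, rank = 6, 5, 2
a = np.full(n, 1.0 / n)          # source marginal
b = np.full(m, 1.0 / m)          # target marginal
g = np.full(rank, 1.0 / rank)    # inner marginal shared by Q and R

# Hypothetical factors: random positive matrices scaled so that
# Q has marginals (a, g) and R has marginals (b, g).
q = scale_to_marginals(rng.uniform(0.5, 1.5, size=(n, rank)), a, g)
r = scale_to_marginals(rng.uniform(0.5, 1.5, size=(m, rank)), b, g)

p = low_rank_coupling(q, r, g)
# P is described by n*rank + m*rank + rank numbers and has rank <= `rank`.
print(p.shape, np.linalg.matrix_rank(p))
print(np.allclose(p.sum(axis=1), a), np.allclose(p.sum(axis=0), b))
```

Because \(R^\top \mathbf{1} = g\), the row sums of \(P\) reduce to \(Q\mathbf{1} = a\) (and symmetrically for the columns), so the factorized coupling keeps the required marginals while never storing an \(n \times m\) matrix during the solve.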

Methods

  • create_initializer(prob) – Create a low-rank GW initializer.

  • dykstra_update_kernel(k_q, k_r, k_g, gamma, ...) – Run Dykstra's algorithm (kernel mode).

  • dykstra_update_lse(c_q, c_r, h, gamma, ot_prob) – Run Dykstra's algorithm (LSE mode).

  • init_state(ot_prob, init) – Return the initial state of the loop.

  • kernel_step(ot_prob, state, iteration) – Low-rank GW kernel update.

  • lse_step(ot_prob, state, iteration) – Low-rank GW LSE update.

  • one_iteration(ot_prob, state, iteration, ...) – Carry out one low-rank GW iteration.

  • output_from_state(ot_prob, state) – Create an output from a loop state.
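The `lse_step`/`kernel_step` pair mirrors the `lse_mode` switch. The following NumPy illustration shows why log-domain updates are usually preferred when the cost is large relative to the regularization. It assumes a generic Sinkhorn-style scaling step \(u = a / (K\mathbf{1})\) with \(K = e^{-C/\varepsilon}\); this is a stand-in for the idea, not the exact update these methods perform.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


cost = np.array([[800.0, 900.0], [850.0, 950.0]])
epsilon = 1.0
a = np.array([0.5, 0.5])

# Kernel mode: build K = exp(-C / epsilon) explicitly. Every entry is below the
# smallest positive float64, so K (and hence K @ 1) is exactly zero.
kernel = np.exp(-cost / epsilon)
with np.errstate(divide="ignore"):
    u_kernel = a / (kernel @ np.ones(2))  # 0.5 / 0.0 -> inf

# LSE mode: the same update written as log u = log a - logsumexp(-C / epsilon).
log_u = np.log(a) - logsumexp(-cost / epsilon, axis=1)

print(u_kernel, log_u)
```

Kernel mode avoids the cost of exponentials and logarithms at each step, which can make it faster when \(\varepsilon\) is large enough for \(K\) to be well represented; LSE mode trades that speed for stability.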

Attributes

  • norm_error – Powers used to compute the p-norm between marginals and their targets.

  • outer_iterations – Upper bound on the number of times inner_iterations are carried out.
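Both attributes can be illustrated with a short sketch. It assumes that `outer_iterations` rounds `max_iterations / inner_iterations` up, and that each power \(p\) in `norm_error` yields one \(\ell_p\) distance between a marginal of the coupling and its target. The functions below are hypothetical helpers, not members of the class.

```python
import math

import numpy as np


def outer_iterations(max_iterations: int, inner_iterations: int) -> int:
    """Upper bound on how many blocks of `inner_iterations` fit in `max_iterations`."""
    return math.ceil(max_iterations / inner_iterations)


def marginal_errors(coupling, target, norm_error=(1,)):
    """One l_p error per power p in `norm_error`, between the coupling's
    row marginal and the `target` marginal."""
    marginal = coupling.sum(axis=1)
    diff = np.abs(marginal - target)
    return tuple(float(np.sum(diff ** p) ** (1.0 / p)) for p in norm_error)


# Default budget from the signature above: 100000 / 2000 -> 50 outer blocks.
print(outer_iterations(100_000, 2_000))

coupling = np.array([[0.3, 0.2], [0.1, 0.4]])  # row sums: [0.5, 0.5]
target = np.array([0.6, 0.4])
print(marginal_errors(coupling, target, norm_error=(1, 2)))  # ~ (0.2, 0.1414)
```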