ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein

class ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein(linear_solver, epsilon=1.0, relative_epsilon=None, initializer=None, warm_start=False, progress_fn=None, **kwargs)

Entropic Gromov-Wasserstein solver [Peyré et al., 2016].

See also

Low-rank Gromov-Wasserstein [Scetbon et al., 2023] is implemented in LRGromovWasserstein.

Parameters:
  • linear_solver (Sinkhorn) – Linear OT solver used to solve the linearized problem at each Gromov-Wasserstein iteration.

  • epsilon (float) – Entropic regularization.

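  • relative_epsilon (Optional[Literal['mean', 'std']]) – Whether to use relative epsilon in the linearized geometry, i.e., to interpret epsilon relative to the 'mean' or 'std' of its cost matrix. If None, epsilon is used as given.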

  • initializer (Optional[BaseQuadraticInitializer]) – Quadratic initializer. If None, use QuadraticInitializer.

  • warm_start (bool) – Whether to initialize Sinkhorn calls with the values from the previous iteration.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, GWState]], None]]) – Callback function invoked during the Gromov-Wasserstein iterations, e.g., to display the error at each iteration with a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs (Any) – Keyword arguments for WassersteinSolver.
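
To make the roles of linear_solver, epsilon and warm_start more concrete, the following is a minimal NumPy sketch of entropic Gromov-Wasserstein iterations for the squared loss, in the spirit of [Peyré et al., 2016]. It is not this library's implementation: the names sinkhorn_log and entropic_gw, the fixed iteration counts, and the choice to start from the product coupling are assumptions made for illustration only. Each outer iteration linearizes the quadratic objective around the current coupling and hands the resulting cost matrix to a linear entropic OT solver; with warm_start=True the dual potentials of the previous solve are reused.

```python
import numpy as np


def _logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along `axis`."""
    z_max = z.max(axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.exp(z - z_max).sum(axis=axis))


def sinkhorn_log(cost, a, b, epsilon, f=None, g=None, n_iters=200):
    """Hypothetical stand-in for the linear solver: log-domain Sinkhorn.

    Assumes strictly positive marginals `a` and `b`.
    Returns the coupling and the dual potentials (f, g).
    """
    f = np.zeros_like(a) if f is None else f
    g = np.zeros_like(b) if g is None else g
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(n_iters):
        # Alternate updates so that rows, then columns, match the marginals.
        f = epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    return plan, f, g


def entropic_gw(cx, cy, a, b, epsilon=1.0, warm_start=False, n_outer=20):
    """Illustrative entropic GW loop for the squared loss (not the OTT code)."""
    # Part of the linearized cost that does not depend on the coupling.
    const = ((cx**2) @ a)[:, None] + ((cy**2) @ b)[None, :]
    plan = np.outer(a, b)  # assumption: start from the product coupling
    f = g = None
    for _ in range(n_outer):
        # Linearize the quadratic objective around the current coupling.
        cost = const - 2.0 * cx @ plan @ cy.T
        if not warm_start:
            f = g = None  # cold start: discard the previous potentials
        plan, f, g = sinkhorn_log(cost, a, b, epsilon, f, g)
    return plan


# Toy data: two point clouds living in spaces of different dimension.
rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 2))
y = rng.uniform(size=(6, 3))
cx = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)  # intra-space costs
cy = ((y[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = np.full(5, 1 / 5)
b = np.full(6, 1 / 6)

plan = entropic_gw(cx, cy, a, b, epsilon=0.1, warm_start=True)
print(plan.shape)  # (5, 6)
```

Only the two intra-space cost matrices enter the computation, which is why the point clouds may have different dimensions. Smaller values of epsilon give sharper couplings but typically need more inner iterations to converge.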

Methods

init_state(prob, init)

Initialize the state of the Gromov-Wasserstein iterations.

output_from_state(state)

Create an output from a loop state.

Attributes

is_low_rank

Whether the solver is low-rank.

rank

Rank of the linear OT solver.