ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein


class ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein(*args, warm_start=None, relative_epsilon=None, quad_initializer=None, progress_fn=None, kwargs_init=None, **kwargs)

Gromov-Wasserstein solver [Peyré et al., 2016].
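The entropic scheme from Peyré et al. (2016) repeatedly linearizes the quadratic objective around the current coupling and solves the resulting entropic OT problem with Sinkhorn. The following is a minimal NumPy sketch of that idea, not OTT's implementation. It assumes the square loss, dense intra-domain distance matrices cx and cy, and strictly positive weights a and b. The names entropic_gw and sinkhorn_log are hypothetical. Its warm_start flag imitates the documented parameter by reusing the previous Sinkhorn potentials. OTT's exact scaling of the linearized cost and of epsilon may differ.

```python
import numpy as np


def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log-sum-exp along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_log(cost, a, b, eps, f=None, g=None, tol=1e-8, max_iter=5000):
    """Log-domain Sinkhorn; `f`, `g` are optional warm-start dual potentials."""
    f = np.zeros_like(a) if f is None else f
    g = np.zeros_like(b) if g is None else g
    log_a, log_b = np.log(a), np.log(b)
    for it in range(1, max_iter + 1):
        f = eps * (log_a - _logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (log_b - _logsumexp((f[:, None] - cost) / eps, axis=0))
        coupling = np.exp((f[:, None] + g[None, :] - cost) / eps)
        # After the `g` update the column sums equal `b`; check the rows.
        if np.abs(coupling.sum(axis=1) - a).sum() < tol:
            break
    return coupling, f, g, it


def entropic_gw(cx, cy, a, b, eps=0.1, outer_iter=20, warm_start=True):
    """Entropic GW (square loss), linearizing around the current coupling."""
    # (x - y)^2 = x^2 + y^2 - 2xy: the first two terms only depend on the
    # marginals, the cross term depends on the current coupling.
    const = (cx ** 2) @ a[:, None] + ((cy ** 2) @ b[:, None]).T
    coupling = np.outer(a, b)  # start from the independent coupling
    f = g = None
    total_sinkhorn_iters = 0
    for _ in range(outer_iter):
        # Linearized cost: sum_{k,l} (cx[i,k] - cy[j,l])^2 * coupling[k,l].
        cost = const - 2.0 * cx @ coupling @ cy.T
        coupling, f_new, g_new, n_it = sinkhorn_log(cost, a, b, eps, f, g)
        total_sinkhorn_iters += n_it
        if warm_start:  # reuse the previous potentials for the next solve
            f, g = f_new, g_new
    return coupling, total_sinkhorn_iters


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x, y = rng.normal(size=(8, 2)), rng.normal(size=(10, 3))
    cx = np.linalg.norm(x[:, None] - x[None], axis=-1)
    cy = np.linalg.norm(y[:, None] - y[None], axis=-1)
    cx, cy = cx / cx.max(), cy / cy.max()  # normalize distances to [0, 1]
    a, b = np.full(8, 1 / 8), np.full(10, 1 / 10)
    coupling, n_iters = entropic_gw(cx, cy, a, b)
    print(coupling.shape, n_iters)
```

Normalizing the distance matrices keeps the linearized cost in [0, 1], so a moderate epsilon such as 0.1 lets each inner Sinkhorn solve converge quickly. Comparing n_iters with warm_start=False gives a rough sense of how much warm starting saves on a given problem.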

See also

Low-rank Gromov-Wasserstein [Scetbon et al., 2023] is implemented in LRGromovWasserstein.

Parameters:
  • args (Any) – Positional arguments for WassersteinSolver.

  • warm_start (Optional[bool]) – Whether to initialize Sinkhorn calls using values from the previous iteration. If None, warm starts are not used for standard Sinkhorn.

  • relative_epsilon (Optional[bool]) – Whether to use relative epsilon in the linearized geometry.

  • quad_initializer (Union[Literal['random', 'rank2', 'k-means', 'generalized-k-means'], BaseQuadraticInitializer, None]) – Quadratic initializer. If the solver is entropic, QuadraticInitializer is always used.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, GWState]], None]]) – Callback function called during the Gromov-Wasserstein iterations, e.g., so the user can display the error at each iteration with a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs_init (Optional[Mapping[str, Any]]) – Keyword arguments passed when creating the initializer.

  • kwargs (Any) – Keyword arguments for WassersteinSolver.
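The progress_fn parameter receives a single tuple of three arrays plus a GWState, according to its type hint. Below is a minimal sketch of such a callback that logs a line per outer iteration. The meaning assigned to the three arrays (current 0-based iteration, an inner-iteration count, maximum number of iterations) is an assumption modeled on default_progress_fn(), not something this page specifies, and the GWState is accepted but not inspected. make_progress_fn is a hypothetical helper.

```python
from typing import Any, List, Tuple

import numpy as np


def make_progress_fn(log: List[str]):
    """Build a callback shaped like the documented `progress_fn`.

    Assumption (hypothetical): status = (iteration, inner_iterations,
    max_iterations, state), with `iteration` counted from 0.
    """

    def progress_fn(
        status: Tuple[np.ndarray, np.ndarray, np.ndarray, Any], *args: Any
    ) -> None:
        # Extra positional arguments, if any are passed, are ignored.
        iteration, _inner_iterations, max_iterations, _state = status
        # Convert 0-d arrays to Python ints before formatting.
        msg = f"GW iteration {int(iteration) + 1} / {int(max_iterations)}"
        log.append(msg)
        print(msg)

    return progress_fn
```

Such a function would then be passed as progress_fn=make_progress_fn(log) when constructing the solver; check default_progress_fn() in the installed OTT version to confirm the tuple layout before relying on it.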

Methods

create_initializer(prob)

Create quadratic, possibly low-rank initializer.

init_state(prob, init, rng)

Initialize the state of the Gromov-Wasserstein iterations.

output_from_state(state)

Create an output from a loop state.

Attributes

is_low_rank

Whether the solver is low-rank.

warm_start

Whether to initialize Sinkhorn using previous solutions.