ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein

class ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein(*args, warm_start=None, unscale_last_linearization=False, quad_initializer=None, progress_fn=None, kwargs_init=None, **kwargs)

Gromov-Wasserstein solver [Peyré et al., 2016].

Parameters:
  • args (Any) – Positional arguments for WassersteinSolver.

  • warm_start (Optional[bool]) – Whether to initialize (low-rank) Sinkhorn calls using values from the previous iteration. If None, warm starts are used for low-rank Sinkhorn but not for standard Sinkhorn.

  • unscale_last_linearization (bool) – Whether to remove any scaling from the cost matrices of the last linearization stored in geom. In practice, even if the GW coupling matrices were computed with cost matrices re-scaled for numerical stability, the last linearization stored in the geometry is then recomputed with the original, unscaled cost values.

  • quad_initializer (Union[Literal['random', 'rank2', 'k-means', 'generalized-k-means'], BaseQuadraticInitializer, None]) – Quadratic initializer. If the solver is entropic, QuadraticInitializer is always used. Otherwise, the quadratic initializer wraps the low-rank Sinkhorn initializers. If None, the low-rank initializer is selected in a problem-specific manner: if both geom_xx and geom_yy are PointCloud or LRCGeometry, KMeansInitializer is used; otherwise, RandomInitializer is used.

  • progress_fn (Optional[Callable[[Tuple[ndarray, ndarray, ndarray, GWState]], None]]) – Callback function called during the Gromov-Wasserstein iterations, so that the user can display the error at each iteration, e.g., using a progress bar. See default_progress_fn() for a basic implementation.

  • kwargs_init (Optional[Mapping[str, Any]]) – Keyword arguments used when creating the initializer.

  • kwargs (Any) – Keyword arguments for WassersteinSolver.
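
To illustrate the progress_fn parameter, here is a minimal sketch of a callback with the documented shape: one 4-tuple argument whose last entry is the solver state. The names make_progress_logger, iteration, inner and total are hypothetical, and reading the first three entries as iteration counters is an assumption; default_progress_fn() remains the reference for the actual layout.

```python
from typing import Any, Callable, List, Tuple

Status = Tuple[Any, Any, Any, Any]


def make_progress_logger(log: List[str]) -> Callable[[Status], None]:
  """Build a callback that appends one line per call to `log`.

  Hypothetical helper for illustration; not part of OTT.
  """

  def progress_fn(status: Status) -> None:
    # Assumption: the first three entries are scalar counters; the last
    # entry is the solver state (a GWState in the real solver), unused here.
    iteration, inner, total, _state = status
    log.append(f"{int(iteration)}/{int(total)} (inner={int(inner)})")

  return progress_fn


# Stand-alone demonstration with plain integers instead of arrays.
log: List[str] = []
callback = make_progress_logger(log)
callback((3, 10, 20, None))
```

A callback built this way could then be supplied through the progress_fn keyword argument when constructing the solver.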

Methods

create_initializer(prob)

Create a quadratic, possibly low-rank, initializer.

init_state(prob, init, rng)

Initialize the state of the Gromov-Wasserstein iterations.

output_from_state(state)

Create an output from a loop state.
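
The selection rule documented for quad_initializer, which create_initializer(prob) is described as applying, can be sketched in plain Python. This is only a paraphrase of the rule as stated in the parameter description, not the library's code: the stub classes stand in for the real geometry types, pick_initializer is a hypothetical name, and initializers are represented by their names as strings.

```python
from typing import Any, Optional


# Placeholder types standing in for the real geometry classes (assumption).
class PointCloudStub:
  pass


class LRCGeometryStub:
  pass


class DenseGeometryStub:
  pass


def pick_initializer(
    is_low_rank: bool,
    quad_initializer: Optional[str],
    geom_xx: Any,
    geom_yy: Any,
) -> str:
  """Return the name of the initializer the documented rule would select."""
  if not is_low_rank:
    # Entropic solver: QuadraticInitializer is always used.
    return "QuadraticInitializer"
  if quad_initializer is not None:
    # Low-rank solver with an explicit choice: keep it.
    return quad_initializer
  # Low-rank solver, no explicit choice: problem-specific default.
  low_rank_friendly = (PointCloudStub, LRCGeometryStub)
  if isinstance(geom_xx, low_rank_friendly) and isinstance(
      geom_yy, low_rank_friendly
  ):
    return "KMeansInitializer"
  return "RandomInitializer"


default_choice = pick_initializer(
    True, None, PointCloudStub(), LRCGeometryStub()
)
```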

Attributes

is_low_rank

Whether the solver is low-rank.

warm_start

Whether to initialize (low-rank) Sinkhorn using previous solutions.
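
The way the warm_start default is described (None means warm starts for low-rank Sinkhorn only) amounts to a small resolution rule. The sketch below restates that rule with a hypothetical function name, resolve_warm_start; it is an assumption about how the two attributes relate, based on the parameter description rather than on the library source.

```python
from typing import Optional


def resolve_warm_start(warm_start: Optional[bool], is_low_rank: bool) -> bool:
  """Resolve the effective warm-start flag from the user-supplied value."""
  if warm_start is None:
    # Default: warm starts only when the solver is low-rank.
    return is_low_rank
  # An explicit True/False always takes precedence.
  return warm_start


entropic_default = resolve_warm_start(None, is_low_rank=False)
low_rank_default = resolve_warm_start(None, is_low_rank=True)
```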