# Source code for ott.solvers.quadratic.gromov_wasserstein

```python
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
Any,
Callable,
Dict,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)

import jax
import jax.numpy as jnp
import numpy as np

from ott import utils
from ott.geometry import geometry
from ott.initializers.quadratic import initializers as quad_initializers
from ott.math import fixed_point_loop
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers import was_solver
from ott.solvers.linear import sinkhorn, sinkhorn_lr

# Names exported by ``from <module> import *``.
__all__ = ["GromovWasserstein", "GWOutput"]

# Output type of the inner linear OT solver: either a Sinkhorn output or a
# low-rank Sinkhorn output.
LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]

# Signature of the progress callback. It is invoked with a single 4-tuple,
# ``(iteration, inner_iterations, max_iterations, state)``, and returns nothing.
ProgressCallbackFn_t = Callable[
[Tuple[np.ndarray, np.ndarray, np.ndarray, "GWState"]], None]

[docs]
class GWOutput(NamedTuple):
  """Result container of the Gromov-Wasserstein solver.

  Args:
    costs: Regularized GW costs recorded along the outer loop of the solver.
    linear_convergence: Convergence flags of the inner Sinkhorn runs, one
      entry per outer iteration.
    converged: Whether the outer GW iterations converged.
    errors: Error vectors of the Sinkhorn algorithm, one row per outer
      iteration.
    linear_state: Solution of the local linearization of GW.
    geom: Geometry underlying the local linearization.
    old_transport_mass: Total mass of the transport at the previous iteration.
  """

  costs: Optional[jnp.ndarray] = None
  linear_convergence: Optional[jnp.ndarray] = None
  converged: bool = False
  errors: Optional[jnp.ndarray] = None
  linear_state: Optional[LinearOutput] = None
  geom: Optional[geometry.Geometry] = None
  # Intermediate values.
  old_transport_mass: float = 1.0

  def set(self, **kwargs: Any) -> "GWOutput":
    """Return a new output in which the given fields are overwritten."""
    return self._replace(**kwargs)

  @property
  def _rescale_factor(self) -> float:
    # Square root of the ratio between the previous transport mass and the
    # mass of the current linear solution.
    mass_ratio = self.old_transport_mass / self.linear_state.transport_mass
    return jnp.sqrt(mass_ratio)

  @property
  def matrix(self) -> jnp.ndarray:
    """Transport matrix of the linear solution, rescaled."""
    scale = self._rescale_factor
    return scale * self.linear_state.matrix

  def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    """Apply the rescaled transport to ``inputs``; axis=1 for its transpose."""
    scale = self._rescale_factor
    return scale * self.linear_state.apply(inputs, axis=axis)

  @property
  def reg_gw_cost(self) -> float:
    """Regularized OT cost of the linearization."""
    return self.linear_state.reg_ot_cost

  @property
  def primal_cost(self) -> float:
    """Transport cost of the current linear OT solution at :attr:`geom`."""
    return self.linear_state.transport_cost_at_geom(other_geom=self.geom)

  @property
  def n_iters(self) -> int:
    """Number of rows of :attr:`errors` whose first entry is not ``-1``.

    Returns ``-1`` when no errors were stored.
    """
    if self.errors is not None:
      return jnp.sum(self.errors[:, 0] != -1)
    return -1

class GWState(NamedTuple):
  """Loop state carried through the Gromov-Wasserstein iterations.

  Attributes:
    costs: Regularized GW costs recorded along the outer loop of the solver.
    linear_convergence: Convergence flags of the inner Sinkhorn runs, one
      entry per outer iteration.
    linear_state: Solution of the current local linearization of GW.
    linear_pb: Local linearization of the quadratic GW problem.
    old_transport_mass: Intermediary value of the mass of the transport
      matrix.
    rngs: Random keys, indexed by outer iteration, handed to low-rank
      initializers.
    errors: Error vectors of the Sinkhorn algorithm, one row per outer
      iteration.
  """

  costs: jnp.ndarray
  linear_convergence: jnp.ndarray
  linear_state: LinearOutput
  linear_pb: linear_problem.LinearProblem
  old_transport_mass: float
  rngs: Optional[jax.Array] = None
  errors: Optional[jnp.ndarray] = None

  def set(self, **kwargs: Any) -> "GWState":
    """Return a new state in which the given fields are overwritten."""
    return self._replace(**kwargs)

  def update(
      self, iteration: int, linear_sol: LinearOutput,
      linear_pb: linear_problem.LinearProblem, store_errors: bool,
      old_transport_mass: float
  ) -> "GWState":
    """Record the outcome of one outer iteration and return the new state.

    Args:
      iteration: Index of the outer iteration being recorded.
      linear_sol: Output of the linear solver for this iteration.
      linear_pb: Linearized problem that ``linear_sol`` solves.
      store_errors: Whether the inner errors should be written to
        :attr:`errors`.
      old_transport_mass: Transport mass to carry to the next iteration.

    Returns:
      The updated state.
    """
    new_costs = self.costs.at[iteration].set(linear_sol.reg_ot_cost)

    # Errors are kept only if requested and a buffer exists; otherwise the
    # field is reset to ``None``.
    keep_errors = store_errors and self.errors is not None
    new_errors = (
        self.errors.at[iteration, :].set(linear_sol.errors)
        if keep_errors else None
    )

    new_flags = self.linear_convergence.at[iteration].set(
        linear_sol.converged
    )

    return self._replace(
        costs=new_costs,
        linear_convergence=new_flags,
        linear_state=linear_sol,
        linear_pb=linear_pb,
        old_transport_mass=old_transport_mass,
        errors=new_errors,
    )

[docs]
@jax.tree_util.register_pytree_node_class
class GromovWasserstein(was_solver.WassersteinSolver):
  """Gromov-Wasserstein solver :cite:`peyre:16`.

  .. seealso::
    Low-rank Gromov-Wasserstein :cite:`scetbon:23` is implemented in
    :class:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`.

  Args:
    args: Positional arguments forwarded to
      :class:`~ott.solvers.was_solver.WassersteinSolver`.
    warm_start: Whether each Sinkhorn call is initialized with the values of
      the previous iteration. If :obj:`None`, the value of ``is_low_rank`` is
      used, i.e. warm starts are not used for standard Sinkhorn.
    relative_epsilon: Whether to use relative epsilon in the linearized
      geometry.
    quad_initializer: Quadratic initializer. If the solver is entropic,
      :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer`
      is always used.
    progress_fn: Callback invoked during the Gromov-Wasserstein iterations,
      so the user can display the error at each iteration, e.g., using a
      progress bar. See :func:`~ott.utils.default_progress_fn` for a basic
      implementation.
    kwargs_init: Keyword arguments used when creating the initializer.
    kwargs: Keyword arguments forwarded to
      :class:`~ott.solvers.was_solver.WassersteinSolver`.
  """

  def __init__(
      self,
      *args: Any,
      warm_start: Optional[bool] = None,
      relative_epsilon: Optional[bool] = None,
      quad_initializer: Optional[
          Union[Literal["random", "rank2", "k-means", "generalized-k-means"],
                quad_initializers.BaseQuadraticInitializer]] = None,
      progress_fn: Optional[ProgressCallbackFn_t] = None,
      kwargs_init: Optional[Mapping[str, Any]] = None,
      **kwargs: Any
  ):
    super().__init__(*args, **kwargs)
    assert not self.is_low_rank, (
        "For low-rank GW, use "
        "`ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`."
    )
    if kwargs_init is None:
      kwargs_init = {}
    self._warm_start = warm_start
    self.relative_epsilon = relative_epsilon
    self.quad_initializer = quad_initializer
    self.progress_fn = progress_fn
    self.kwargs_init = kwargs_init

  def __call__(
      self,
      prob: quadratic_problem.QuadraticProblem,
      init: Optional[linear_problem.LinearProblem] = None,
      rng: Optional[jax.Array] = None,
      **kwargs: Any,
  ) -> GWOutput:
    """Run the Gromov-Wasserstein solver.

    Args:
      prob: Quadratic OT problem.
      init: Initial linearization of the quadratic problem. If `None`, it is
        computed using the initializer.
      rng: Random number key.
      kwargs: Keyword arguments used when calling the initializer.

    Returns:
      The Gromov-Wasserstein output.
    """
    # One key for the initializer, one for the outer loop.
    init_rng, loop_rng = jax.random.split(utils.default_prng_key(rng), 2)

    if prob._is_low_rank_convertible:
      prob = prob.to_low_rank()

    if init is None:
      make_init = self.create_initializer(prob)
      init = make_init(
          prob,
          epsilon=self.epsilon,
          rng=init_rng,
          relative_epsilon=self.relative_epsilon,
          **kwargs
      )

    out = iterations(self, prob, init, loop_rng)

    # Re-linearize the problem around the final linear solution.
    # TODO(lpapaxanthoos): remove stop_gradient when using backprop
    final_lin_state = jax.lax.stop_gradient(out.linear_state)
    if self.is_low_rank:
      linearization = prob.update_lr_linearization(
          final_lin_state,
          relative_epsilon=self.relative_epsilon,
      )
    else:
      linearization = prob.update_linearization(
          final_lin_state,
          epsilon=self.epsilon,
          old_transport_mass=jax.lax.stop_gradient(out.old_transport_mass),
          relative_epsilon=self.relative_epsilon,
      )

    recosted_state = out.linear_state.set_cost(linearization, True, True)

    # Cost slots still holding ``-1`` were never written by the outer loop.
    n_done = jnp.sum(out.costs != -1)
    stopped_early = n_done < self.max_iterations
    inner_all_converged = jnp.all(out.linear_convergence)

    return out.set(
        linear_state=recosted_state,
        geom=linearization.geom,
        converged=jnp.logical_and(stopped_early, inner_all_converged),
    )

  def init_state(
      self,
      prob: quadratic_problem.QuadraticProblem,
      init: linear_problem.LinearProblem,
      rng: jax.Array,
  ) -> GWState:
    """Initialize the state of the Gromov-Wasserstein iterations.

    Args:
      prob: Quadratic OT problem.
      init: Initial linearization of the quadratic problem.
      rng: Random key; it is split into one key per outer iteration and
        stored in the state for low-rank initializers.

    Returns:
      The initial Gromov-Wasserstein state.
    """
    first_solution = self.linear_ot_solver(init)
    n_outer = self.max_iterations
    initial_mass = prob.init_transport_mass()

    # ``-1`` marks slots that have not been written yet.
    inner_errors = (
        -jnp.ones((n_outer, self.linear_ot_solver.outer_iterations))
        if self.store_inner_errors else None
    )

    return GWState(
        costs=-jnp.ones((n_outer,)),
        linear_convergence=-jnp.ones((n_outer,)),
        linear_state=first_solution,
        linear_pb=init,
        old_transport_mass=initial_mass,
        rngs=jax.random.split(rng, n_outer),
        errors=inner_errors,
    )

  def output_from_state(
      self,
      state: GWState,
  ) -> GWOutput:
    """Create an output from a loop state.

    Arguments:
      state: A GWState.

    Returns:
      A GWOutput built from the fields of ``state``; its ``geom`` is the
      geometry of the state's linear problem.
    """
    linearized_geom = state.linear_pb.geom
    return GWOutput(
        costs=state.costs,
        linear_convergence=state.linear_convergence,
        errors=state.errors,
        linear_state=state.linear_state,
        geom=linearized_geom,
        old_transport_mass=state.old_transport_mass,
    )

  def create_initializer(
      self, prob: quadratic_problem.QuadraticProblem
  ) -> quad_initializers.BaseQuadraticInitializer:
    """Create quadratic, possibly low-rank initializer.

    Args:
      prob: Quadratic OT problem used to determine the initializer.

    Returns:
      The user-supplied initializer instance if there is one, otherwise a
      default quadratic initializer built with ``kwargs_init``.
    """
    user_choice = self.quad_initializer
    if not isinstance(
        user_choice, quad_initializers.BaseQuadraticInitializer
    ):
      # no other options implemented, use the default
      return quad_initializers.QuadraticInitializer(**self.kwargs_init)
    return user_choice

  @property
  def warm_start(self) -> bool:
    """Whether to initialize Sinkhorn using previous solutions."""
    if self._warm_start is None:
      return self.is_low_rank
    return self._warm_start

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    """Flatten the solver, adding this class's options to the aux data."""
    children, aux_data = super().tree_flatten()
    aux_data.update({
        "warm_start": self._warm_start,
        "progress_fn": self.progress_fn,
        "relative_epsilon": self.relative_epsilon,
        "quad_initializer": self.quad_initializer,
        "kwargs_init": self.kwargs_init,
    })
    return children, aux_data

def iterations(
    solver: GromovWasserstein,
    prob: quadratic_problem.QuadraticProblem,
    init: linear_problem.LinearProblem,
    rng: jax.Array,
) -> GWOutput:
  """Jittable Gromov-Wasserstein outer loop.

  Each outer iteration re-linearizes ``prob`` around the current linear
  solution, solves the resulting linear problem with
  ``solver.linear_ot_solver`` and records the result in the loop state.

  Args:
    solver: Gromov-Wasserstein solver providing the linear solver, the
      stopping rule and the iteration limits.
    prob: Quadratic OT problem.
    init: Initial linearization of the quadratic problem.
    rng: Random key used to initialize the loop state.

  Returns:
    The solver output built from the final loop state.
  """
  # Inner iterations is currently fixed to 1.
  inner_iterations = 1

  def cond_fn(
      iteration: int, solver: GromovWasserstein, state: GWState
  ) -> bool:
    return solver._continue(state, iteration)

  def body_fn(
      iteration: int, solver: GromovWasserstein, state: GWState,
      compute_error: bool
  ) -> GWState:
    del compute_error  # always assumed true for the outer loop of GW

    lin_state = state.linear_state
    # Local names deliberately differ from the enclosing ``init``/``rng``.
    if solver.is_low_rank:
      step_rng = state.rngs[iteration]
      warm_init = (lin_state.q, lin_state.r,
                   lin_state.g) if solver.warm_start else (None, None, None)
      linear_pb = prob.update_lr_linearization(
          lin_state, relative_epsilon=solver.relative_epsilon
      )
      out = solver.linear_ot_solver(linear_pb, init=warm_init, rng=step_rng)
    else:
      warm_init = (lin_state.f,
                   lin_state.g) if solver.warm_start else (None, None)
      linear_pb = prob.update_linearization(
          lin_state,
          solver.epsilon,
          state.old_transport_mass,
          relative_epsilon=solver.relative_epsilon,
      )
      out = solver.linear_ot_solver(linear_pb, init=warm_init)

    # Mass of the solution this iteration started from, not of ``out``.
    old_transport_mass = jax.lax.stop_gradient(lin_state.transport_mass)
    new_state = state.update(
        iteration, out, linear_pb, solver.store_inner_errors, old_transport_mass
    )

    if solver.progress_fn is not None:
      # Report the state that already contains this iteration's cost and
      # errors; the incoming ``state`` still holds the -1 placeholders at
      # index ``iteration``.
      jax.debug.callback(
          solver.progress_fn,
          (iteration, inner_iterations, solver.max_iterations, new_state)
      )

    return new_state

  state = fixed_point_loop.fixpoint_iter(
      cond_fn=cond_fn,
      body_fn=body_fn,
      min_iterations=solver.min_iterations,
      max_iterations=solver.max_iterations,
      inner_iterations=inner_iterations,
      constants=solver,
      state=solver.init_state(prob, init, rng=rng)
  )

  return solver.output_from_state(state)
```