# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
Any,
Callable,
Dict,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import jax
import jax.numpy as jnp
import numpy as np
from ott.geometry import geometry
from ott.initializers.quadratic import initializers as quad_initializers
from ott.math import fixed_point_loop
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers import was_solver
from ott.solvers.linear import sinkhorn, sinkhorn_lr
__all__ = ["GromovWasserstein", "GWOutput"]

# Result type of whichever linear solver (full-rank or low-rank Sinkhorn)
# is used to solve the local linearizations of GW.
LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]

# Status tuple handed to progress callbacks by the outer GW loop, laid out as
# (iteration, inner_iterations, max_iterations, GWState).
_ProgressStatus = Tuple[np.ndarray, np.ndarray, np.ndarray, "GWState"]
ProgressCallbackFn = Callable[[_ProgressStatus], None]
class GWOutput(NamedTuple):
  """Holds the output of the Gromov-Wasserstein solver.

  Attributes:
    costs: Holds the sequence of regularized GW costs seen through the outer
      loop of the solver.
    linear_convergence: Holds the sequence of bool convergence flags of the
      inner Sinkhorn iterations.
    converged: Convergence flag for the outer GW iterations.
    errors: Holds sequence of vectors of errors of the Sinkhorn algorithm
      at each iteration.
    linear_state: State used to solve and store solutions to the local
      linearization of GW.
    geom: The geometry underlying the local linearization.
    old_transport_mass: Holds total mass of transport at previous iteration.
  """

  costs: Optional[jnp.ndarray] = None
  linear_convergence: Optional[jnp.ndarray] = None
  converged: bool = False
  errors: Optional[jnp.ndarray] = None
  linear_state: Optional[LinearOutput] = None
  geom: Optional[geometry.Geometry] = None
  # Intermediate values.
  old_transport_mass: float = 1.0

  def set(self, **kwargs: Any) -> "GWOutput":
    """Return a copy of self, possibly with overwrites."""
    return self._replace(**kwargs)

  @property
  def matrix(self) -> jnp.ndarray:
    """Transport matrix of the linear solution, multiplied by the rescaling."""
    return self._rescale_factor * self.linear_state.matrix

  def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    """Apply the transport to an array; axis=1 for its transpose."""
    return self._rescale_factor * self.linear_state.apply(inputs, axis=axis)

  @property
  def reg_gw_cost(self) -> float:
    """Regularized optimal transport cost of the linearization."""
    return self.linear_state.reg_ot_cost

  @property
  def _rescale_factor(self) -> float:
    # Square root of the ratio between the transport mass stored from the
    # previous outer iteration and the mass of the current linear solution;
    # used to rescale `matrix` and `apply`.
    return jnp.sqrt(self.old_transport_mass / self.linear_state.transport_mass)

  @property
  def primal_cost(self) -> float:
    """Return transport cost of current linear OT solution at geometry."""
    return self.linear_state.transport_cost_at_geom(other_geom=self.geom)

  @property
  def n_iters(self) -> int:
    """Number of outer iterations whose inner errors were recorded.

    Counts the rows of ``errors`` whose first entry differs from the ``-1``
    fill value. Returns ``-1`` when errors were not stored
    (``errors is None``).
    """
    if self.errors is None:
      return -1
    return jnp.sum(self.errors[:, 0] != -1)
class GWState(NamedTuple):
  """Loop state threaded through the outer Gromov-Wasserstein iterations.

  Attributes:
    costs: Holds the sequence of regularized GW costs seen through the outer
      loop of the solver.
    linear_convergence: Holds the sequence of bool convergence flags of the
      inner Sinkhorn iterations.
    linear_state: State used to solve and store solutions to the local
      linearization of GW.
    linear_pb: Local linearization of the quadratic GW problem.
    old_transport_mass: Intermediary value of the mass of the transport matrix.
    errors: Holds sequence of vectors of errors of the Sinkhorn algorithm
      at each iteration.
  """

  costs: jnp.ndarray
  linear_convergence: jnp.ndarray
  linear_state: LinearOutput
  linear_pb: linear_problem.LinearProblem
  old_transport_mass: float
  errors: Optional[jnp.ndarray] = None

  def set(self, **kwargs: Any) -> "GWState":
    """Return a copy of self, possibly with overwrites."""
    return self._replace(**kwargs)

  def update(
      self,
      iteration: int,
      linear_sol: LinearOutput,
      linear_pb: linear_problem.LinearProblem,
      store_errors: bool,
      old_transport_mass: float,
  ) -> "GWState":
    """Record the outcome of one outer iteration in a new state.

    Args:
      iteration: Index of the outer iteration being recorded.
      linear_sol: Solution of the linear problem at this iteration.
      linear_pb: Linearization that produced ``linear_sol``.
      store_errors: Whether the inner solver errors should be kept.
      old_transport_mass: Transport mass to carry into the next iteration.

    Returns:
      A copy of the state with the cost, convergence flag and (optionally)
      errors written at ``iteration``, and the linear state/problem replaced.
      ``errors`` becomes ``None`` unless ``store_errors`` is set and the state
      already holds an errors buffer.
    """
    keep_errors = store_errors and self.errors is not None
    new_errors = (
        self.errors.at[iteration, :].set(linear_sol.errors)
        if keep_errors else None
    )
    return self.set(
        costs=self.costs.at[iteration].set(linear_sol.reg_ot_cost),
        linear_convergence=self.linear_convergence.at[iteration].set(
            linear_sol.converged
        ),
        errors=new_errors,
        linear_state=linear_sol,
        linear_pb=linear_pb,
        old_transport_mass=old_transport_mass,
    )
@jax.tree_util.register_pytree_node_class
class GromovWasserstein(was_solver.WassersteinSolver):
  """Entropic Gromov-Wasserstein solver :cite:`peyre:16`.

  .. seealso::
    Low-rank Gromov-Wasserstein :cite:`scetbon:23` is implemented in
    :class:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`.

  Args:
    linear_solver: Linear OT solver.
    epsilon: Entropic regularization.
    relative_epsilon: Whether to use relative epsilon in the linearized
      geometry.
    initializer: Quadratic initializer. If :obj:`None`, use
      :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer`.
    warm_start: Whether to initialize Sinkhorn calls with the values
      from the previous iteration.
    progress_fn: callback function which gets called during the
      Gromov-Wasserstein iterations, so the user can display the error at each
      iteration, e.g., using a progress bar.
      See :func:`~ott.utils.default_progress_fn` for a basic implementation.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.was_solver.WassersteinSolver`.
  """

  def __init__(
      self,
      linear_solver: sinkhorn.Sinkhorn,
      epsilon: float = 1.0,
      relative_epsilon: Optional[Literal["mean", "std"]] = None,
      initializer: Optional[quad_initializers.BaseQuadraticInitializer] = None,
      warm_start: bool = False,
      progress_fn: Optional[ProgressCallbackFn] = None,
      **kwargs: Any
  ):
    super().__init__(linear_solver, **kwargs)
    self.epsilon = epsilon
    self.relative_epsilon = relative_epsilon
    if initializer is None:
      initializer = quad_initializers.QuadraticInitializer()
    self.initializer = initializer
    self.warm_start = warm_start
    self.progress_fn = progress_fn

  def __call__(
      self,
      prob: quadratic_problem.QuadraticProblem,
      init: Optional[linear_problem.LinearProblem] = None,
      **kwargs: Any,
  ) -> GWOutput:
    """Run the Gromov-Wasserstein solver.

    Args:
      prob: Quadratic OT problem. If it is low-rank convertible, it is
        converted with ``to_low_rank()`` before solving.
      init: Initial linearization of the quadratic problem.
        If :obj:`None`, use the initializer.
      kwargs: Keyword arguments for the initializer.

    Returns:
      The Gromov-Wasserstein output.
    """
    if prob._is_low_rank_convertible:
      prob = prob.to_low_rank()

    if init is None:
      init = self.initializer(
          prob,
          epsilon=self.epsilon,
          relative_epsilon=self.relative_epsilon,
          **kwargs,
      )

    out = iterations(self, prob, init)

    # Linearize once more at the final linear solution; the geometry of this
    # linearization is what gets returned as `geom`.
    # TODO(lpapaxanthoos): remove stop_gradient when using backprop
    linearization = prob.update_linearization(
        jax.lax.stop_gradient(out.linear_state),
        epsilon=self.epsilon,
        old_transport_mass=jax.lax.stop_gradient(out.old_transport_mass),
        relative_epsilon=self.relative_epsilon,
    )
    # NOTE(review): the two boolean flags of `set_cost` are passed through
    # unchanged; their exact meaning is defined by the linear output class.
    linear_state = out.linear_state.set_cost(linearization, True, True)

    # `costs` is pre-filled with -1 in `init_state`, so this counts the outer
    # iterations that actually ran.
    iteration = jnp.sum(out.costs != -1)
    # Converged iff the loop stopped before `max_iterations` and every executed
    # linear solve converged. Entries of iterations that did not run stay NaN
    # (see `init_state`) and are ignored by `nanmean`.
    converged = jnp.logical_and(
        iteration < self.max_iterations,
        jnp.nanmean(out.linear_convergence) == 1.0
    )
    return out.set(
        linear_state=linear_state, geom=linearization.geom, converged=converged
    )

  def init_state(
      self,
      prob: quadratic_problem.QuadraticProblem,
      init: linear_problem.LinearProblem,
  ) -> GWState:
    """Initialize the state of the Gromov-Wasserstein iterations.

    Args:
      prob: Quadratic OT problem.
      init: Initial linearization of the quadratic problem.

    Returns:
      The initial Gromov-Wasserstein state. ``costs`` is filled with -1,
      ``linear_convergence`` with NaN, and ``errors`` with -1 when
      ``store_inner_errors`` is set (otherwise ``None``).
    """
    linear_state = self.linear_solver(init)
    num_iter = self.max_iterations
    transport_mass = prob.init_transport_mass()
    if self.store_inner_errors:
      # One row per outer iteration, one column per outer iteration of the
      # linear solver.
      errors = -jnp.ones((num_iter, self.linear_solver.outer_iterations))
    else:
      errors = None

    return GWState(
        costs=-jnp.ones((num_iter,)),
        linear_convergence=jnp.full((num_iter,), fill_value=jnp.nan),
        linear_state=linear_state,
        linear_pb=init,
        old_transport_mass=transport_mass,
        errors=errors,
    )

  def output_from_state(
      self,
      state: GWState,
  ) -> GWOutput:
    """Create an output from a loop state.

    Args:
      state: A GWState.

    Returns:
      A GWOutput whose ``geom`` is the geometry of the state's linear problem.
    """
    return GWOutput(
        costs=state.costs,
        linear_convergence=state.linear_convergence,
        errors=state.errors,
        linear_state=state.linear_state,
        geom=state.linear_pb.geom,
        old_transport_mass=state.old_transport_mass
    )

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    children, aux_data = super().tree_flatten()
    # Solver hyper-parameters are static metadata of the pytree.
    aux_data["epsilon"] = self.epsilon
    aux_data["relative_epsilon"] = self.relative_epsilon
    aux_data["initializer"] = self.initializer
    aux_data["warm_start"] = self.warm_start
    aux_data["progress_fn"] = self.progress_fn
    return children, aux_data

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "GromovWasserstein":
    # Assumes the base class flattens its children as
    # [linear_solver, threshold] -- TODO confirm against WassersteinSolver.
    linear_solver, threshold = children
    return cls(linear_solver, threshold=threshold, **aux_data)
def iterations(
    solver: GromovWasserstein,
    prob: quadratic_problem.QuadraticProblem,
    init: linear_problem.LinearProblem,
) -> GWOutput:
  """Jittable Gromov-Wasserstein outer loop.

  Args:
    solver: The Gromov-Wasserstein solver; also passed to the loop as its
      constants.
    prob: Quadratic OT problem.
    init: Initial linearization of the quadratic problem.

  Returns:
    The output built by ``solver.output_from_state`` from the final state.
  """

  def cond_fn(
      iteration: int, solver: GromovWasserstein, state: GWState
  ) -> bool:
    return solver._continue(state, iteration)

  def body_fn(
      iteration: int, solver: GromovWasserstein, state: GWState,
      compute_error: bool
  ) -> GWState:
    del compute_error  # always assumed true for the outer loop of GW

    lin_state = state.linear_state
    # Optionally reuse the previous potentials to warm-start the linear solver.
    warm_start_init = (lin_state.f, lin_state.g) if solver.warm_start else None

    linear_pb = prob.update_linearization(
        lin_state,
        solver.epsilon,
        state.old_transport_mass,
        relative_epsilon=solver.relative_epsilon,
    )
    out = solver.linear_solver(linear_pb, init=warm_start_init)

    # Mass of the previous linear solution, carried to the next iteration
    # without gradient.
    old_transport_mass = jax.lax.stop_gradient(
        state.linear_state.transport_mass
    )
    new_state = state.update(
        iteration, out, linear_pb, solver.store_inner_errors, old_transport_mass
    )

    # Inner iterations is currently fixed to 1.
    inner_iterations = 1
    if solver.progress_fn is not None:
      # Report the updated state so the callback sees the cost/errors that
      # were just written at index `iteration`. The pre-update state would
      # still hold the -1 fill values there.
      jax.debug.callback(
          solver.progress_fn,
          (iteration, inner_iterations, solver.max_iterations, new_state)
      )

    return new_state

  state = fixed_point_loop.fixpoint_iter(
      cond_fn=cond_fn,
      body_fn=body_fn,
      min_iterations=solver.min_iterations,
      max_iterations=solver.max_iterations,
      inner_iterations=1,
      constants=solver,
      state=solver.init_state(prob, init)
  )

  return solver.output_from_state(state)