# Source code for ott.solvers.quadratic.gromov_wasserstein

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
    Any,
    Callable,
    Dict,
    Literal,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import geometry
from ott.initializers.quadratic import initializers as quad_initializers
from ott.math import fixed_point_loop
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers import was_solver
from ott.solvers.linear import sinkhorn, sinkhorn_lr

__all__ = ["GromovWasserstein", "GWOutput"]

# Result type of the inner linear solver: a full-rank Sinkhorn output or a
# low-rank Sinkhorn output.
LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]

# Single positional argument handed to the progress callback; the outer loop
# builds it as ``(iteration, inner_iterations, max_iterations, state)``.
_ProgressStatus = Tuple[np.ndarray, np.ndarray, np.ndarray, "GWState"]
# Signature of a user-supplied progress callback (expected to return None).
ProgressCallbackFn = Callable[[_ProgressStatus], None]


class GWOutput(NamedTuple):
  """Result returned by the Gromov-Wasserstein solver.

  Args:
    costs: Regularized GW cost recorded at each outer iteration of the solver.
    linear_convergence: For each outer iteration, whether the inner linear
      (Sinkhorn) solve converged.
    converged: Whether the outer GW iterations converged.
    errors: For each outer iteration, the vector of errors reported by the
      inner Sinkhorn solve.
    linear_state: Solution of the local linearization of GW.
    geom: Geometry underlying the local linearization.
    old_transport_mass: Total transport mass from the previous iteration.
  """

  costs: Optional[jnp.ndarray] = None
  linear_convergence: Optional[jnp.ndarray] = None
  converged: bool = False
  errors: Optional[jnp.ndarray] = None
  linear_state: Optional[LinearOutput] = None
  geom: Optional[geometry.Geometry] = None
  # Intermediate values; used to rescale the linear solution's transport.
  old_transport_mass: float = 1.0

  def set(self, **kwargs: Any) -> "GWOutput":
    """Return a copy of this output with the given fields replaced."""
    updated = self._replace(**kwargs)
    return updated

  @property
  def matrix(self) -> jnp.ndarray:
    """Transport matrix, rescaled by ``_rescale_factor``."""
    scale = self._rescale_factor
    plan = self.linear_state.matrix
    return scale * plan

  def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    """Apply the (rescaled) transport to ``inputs``; use ``axis=1`` for its transpose."""
    scale = self._rescale_factor
    transported = self.linear_state.apply(inputs, axis=axis)
    return scale * transported

  @property
  def reg_gw_cost(self) -> float:
    """Regularized OT cost of the current linearization."""
    lin = self.linear_state
    return lin.reg_ot_cost

  @property
  def _rescale_factor(self) -> float:
    # sqrt(previous mass / current mass of the linear solution).
    ratio = self.old_transport_mass / self.linear_state.transport_mass
    return jnp.sqrt(ratio)

  @property
  def primal_cost(self) -> float:
    """Transport cost of the current linear OT solution evaluated on ``geom``."""
    lin = self.linear_state
    return lin.transport_cost_at_geom(other_geom=self.geom)

  @property
  def n_iters(self) -> int:
    """Number of outer iterations whose error row is not the ``-1`` sentinel.

    Returns ``-1`` when errors were not stored.
    """
    errs = self.errors
    return -1 if errs is None else jnp.sum(errs[:, 0] != -1)
class GWState(NamedTuple):
  """Loop state carried through the Gromov-Wasserstein outer iterations.

  Attributes:
    costs: Regularized GW cost written at each outer iteration (``-1`` for
      iterations not yet run).
    linear_convergence: Convergence flag of the inner linear solve, one entry
      per outer iteration.
    linear_state: Most recent solution of the local linearization of GW.
    linear_pb: Most recent local linearization of the quadratic GW problem.
    old_transport_mass: Intermediate value holding the mass of the transport
      matrix, used when building the next linearization.
    errors: Per outer iteration, the vector of errors of the inner Sinkhorn
      solve; ``None`` when errors are not stored.
  """

  costs: jnp.ndarray
  linear_convergence: jnp.ndarray
  linear_state: LinearOutput
  linear_pb: linear_problem.LinearProblem
  old_transport_mass: float
  errors: Optional[jnp.ndarray] = None

  def set(self, **kwargs: Any) -> "GWState":
    """Return a copy of this state with the given fields replaced."""
    updated = self._replace(**kwargs)
    return updated

  def update(
      self,
      iteration: int,
      linear_sol: LinearOutput,
      linear_pb: linear_problem.LinearProblem,
      store_errors: bool,
      old_transport_mass: float
  ) -> "GWState":
    """Record the outcome of outer iteration ``iteration``.

    Args:
      iteration: Index of the outer iteration being recorded.
      linear_sol: Solution of the linearized problem at this iteration.
      linear_pb: The linearized problem that was solved.
      store_errors: Whether to write ``linear_sol.errors`` into ``errors``.
      old_transport_mass: Mass to carry to the next iteration.

    Returns:
      A new state; ``errors`` is ``None`` unless ``store_errors`` is set and
      an errors buffer already exists.
    """
    new_costs = self.costs.at[iteration].set(linear_sol.reg_ot_cost)

    keep_errors = store_errors and self.errors is not None
    new_errors = (
        self.errors.at[iteration, :].set(linear_sol.errors)
        if keep_errors else None
    )

    new_convergence = self.linear_convergence.at[iteration].set(
        linear_sol.converged
    )

    return self.set(
        costs=new_costs,
        linear_convergence=new_convergence,
        linear_state=linear_sol,
        linear_pb=linear_pb,
        errors=new_errors,
        old_transport_mass=old_transport_mass,
    )
@jax.tree_util.register_pytree_node_class
class GromovWasserstein(was_solver.WassersteinSolver):
  """Entropic Gromov-Wasserstein solver :cite:`peyre:16`.

  .. seealso::
    Low-rank Gromov-Wasserstein :cite:`scetbon:23` is implemented in
    :class:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`.

  Args:
    linear_solver: Linear OT solver applied to each local linearization.
    epsilon: Entropic regularization.
    relative_epsilon: Whether to use relative epsilon in the linearized
      geometry.
    initializer: Quadratic initializer. If :obj:`None`, use
      :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer`.
    warm_start: Whether to initialize Sinkhorn calls with the values from the
      previous iteration.
    progress_fn: Callback invoked during the Gromov-Wasserstein iterations so
      the user can display the error at each iteration, e.g., using a
      progress bar. See :func:`~ott.utils.default_progress_fn` for a basic
      implementation.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.was_solver.WassersteinSolver`.
  """

  def __init__(
      self,
      linear_solver: sinkhorn.Sinkhorn,
      epsilon: float = 1.0,
      relative_epsilon: Optional[Literal["mean", "std"]] = None,
      initializer: Optional[quad_initializers.BaseQuadraticInitializer] = None,
      warm_start: bool = False,
      progress_fn: Optional[ProgressCallbackFn] = None,
      **kwargs: Any
  ):
    super().__init__(linear_solver, **kwargs)
    if initializer is None:
      initializer = quad_initializers.QuadraticInitializer()
    self.epsilon = epsilon
    self.relative_epsilon = relative_epsilon
    self.initializer = initializer
    self.warm_start = warm_start
    self.progress_fn = progress_fn

  def __call__(
      self,
      prob: quadratic_problem.QuadraticProblem,
      init: Optional[linear_problem.LinearProblem] = None,
      **kwargs: Any,
  ) -> GWOutput:
    """Run the Gromov-Wasserstein solver.

    If ``prob`` reports itself as low-rank convertible, it is first converted
    with ``prob.to_low_rank()``.

    Args:
      prob: Quadratic OT problem.
      init: Initial linearization of the quadratic problem. If :obj:`None`,
        use the initializer.
      kwargs: Keyword arguments for the initializer.

    Returns:
      The Gromov-Wasserstein output.
    """
    if prob._is_low_rank_convertible:
      prob = prob.to_low_rank()

    if init is None:
      init = self.initializer(
          prob,
          epsilon=self.epsilon,
          relative_epsilon=self.relative_epsilon,
          **kwargs,
      )

    out = iterations(self, prob, init)

    # TODO(lpapaxanthoos): remove stop_gradient when using backprop
    frozen_state = jax.lax.stop_gradient(out.linear_state)
    frozen_mass = jax.lax.stop_gradient(out.old_transport_mass)
    linearization = prob.update_linearization(
        frozen_state,
        epsilon=self.epsilon,
        old_transport_mass=frozen_mass,
        relative_epsilon=self.relative_epsilon,
    )
    final_linear_state = out.linear_state.set_cost(linearization, True, True)

    # Converged: stopped before exhausting the budget and every recorded inner
    # solve converged (NaN entries, i.e. iterations not run, are ignored).
    n_outer = jnp.sum(out.costs != -1)
    within_budget = n_outer < self.max_iterations
    inner_all_converged = jnp.nanmean(out.linear_convergence) == 1.0
    converged = jnp.logical_and(within_budget, inner_all_converged)

    return out.set(
        linear_state=final_linear_state,
        geom=linearization.geom,
        converged=converged,
    )

  def init_state(
      self,
      prob: quadratic_problem.QuadraticProblem,
      init: linear_problem.LinearProblem,
  ) -> GWState:
    """Build the state the Gromov-Wasserstein loop starts from.

    ``costs`` is filled with ``-1``, ``linear_convergence`` with NaN, and
    ``errors`` (only when ``store_inner_errors`` is set) with ``-1``.

    Args:
      prob: Quadratic OT problem.
      init: Initial linearization of the quadratic problem.

    Returns:
      The initial Gromov-Wasserstein state.
    """
    first_solution = self.linear_solver(init)
    n_outer = self.max_iterations
    start_mass = prob.init_transport_mass()

    error_buffer = (
        -jnp.ones((n_outer, self.linear_solver.outer_iterations))
        if self.store_inner_errors else None
    )

    return GWState(
        costs=-jnp.ones((n_outer,)),
        linear_convergence=jnp.full((n_outer,), fill_value=jnp.nan),
        linear_state=first_solution,
        linear_pb=init,
        old_transport_mass=start_mass,
        errors=error_buffer,
    )

  def output_from_state(
      self,
      state: GWState,
  ) -> GWOutput:
    """Convert a loop state into a :class:`GWOutput`.

    Arguments:
      state: A GWState.

    Returns:
      A GWOutput whose ``geom`` is the geometry of ``state.linear_pb``.
    """
    return GWOutput(
        costs=state.costs,
        linear_convergence=state.linear_convergence,
        errors=state.errors,
        linear_state=state.linear_state,
        geom=state.linear_pb.geom,
        old_transport_mass=state.old_transport_mass,
    )

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    """Flatten for JAX; GW hyperparameters are stored as static aux data."""
    children, aux_data = super().tree_flatten()
    aux_data.update(
        epsilon=self.epsilon,
        relative_epsilon=self.relative_epsilon,
        initializer=self.initializer,
        warm_start=self.warm_start,
        progress_fn=self.progress_fn,
    )
    return children, aux_data

  @classmethod
  def tree_unflatten(
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "GromovWasserstein":
    """Rebuild the solver from the parts produced by ``tree_flatten``."""
    solver, thr = children
    return cls(solver, threshold=thr, **aux_data)
def iterations(
    solver: GromovWasserstein,
    prob: quadratic_problem.QuadraticProblem,
    init: linear_problem.LinearProblem,
) -> GWOutput:
  """Jittable Gromov-Wasserstein outer loop.

  Each outer iteration re-linearizes ``prob`` around the current linear
  solution and solves that linear problem with ``solver.linear_solver``.

  Args:
    solver: The Gromov-Wasserstein solver holding the hyperparameters.
    prob: Quadratic OT problem.
    init: Initial linearization of the quadratic problem.

  Returns:
    The output built by ``solver.output_from_state`` from the final state.
  """

  def cond_fn(
      iteration: int, solver: GromovWasserstein, state: GWState
  ) -> bool:
    return solver._continue(state, iteration)

  def body_fn(
      iteration: int, solver: GromovWasserstein, state: GWState,
      compute_error: bool
  ) -> GWState:
    del compute_error  # always assumed true for the outer loop of GW

    lin_state = state.linear_state
    # Without warm start, pass the same ``(None, None)`` placeholder that
    # ``Sinkhorn.__call__`` uses as its default for the dual potentials,
    # rather than ``None``.
    if solver.warm_start:
      warm_init = (lin_state.f, lin_state.g)
    else:
      warm_init = (None, None)

    linear_pb = prob.update_linearization(
        lin_state,
        solver.epsilon,
        state.old_transport_mass,
        relative_epsilon=solver.relative_epsilon,
    )
    out = solver.linear_solver(linear_pb, init=warm_init)

    # Mass of the solution this linearization was built from; carried to the
    # next iteration without propagating gradients.
    old_transport_mass = jax.lax.stop_gradient(lin_state.transport_mass)
    new_state = state.update(
        iteration, out, linear_pb, solver.store_inner_errors, old_transport_mass
    )

    # Inner iterations is currently fixed to 1.
    inner_iterations = 1
    if solver.progress_fn is not None:
      # Report the updated state so the callback sees this iteration's
      # costs/errors (the pre-update state would always lag one step).
      jax.debug.callback(
          solver.progress_fn,
          (iteration, inner_iterations, solver.max_iterations, new_state)
      )

    return new_state

  state = fixed_point_loop.fixpoint_iter(
      cond_fn=cond_fn,
      body_fn=body_fn,
      min_iterations=solver.min_iterations,
      max_iterations=solver.max_iterations,
      inner_iterations=1,
      constants=solver,
      state=solver.init_state(prob, init),
  )

  return solver.output_from_state(state)