Source code for ott.geometry.graph

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Literal, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu

from ott.geometry import geometry
from ott.math import fixed_point_loop
from ott.math import utils as mu

__all__ = ["Graph"]


@jtu.register_pytree_node_class
class Graph(geometry.Geometry):
  r"""Graph distance approximation using heat kernel :cite:`heitz:21,crane:13`.

  Approximates the heat kernel for large ``n_steps``, which for small ``t``
  approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`.

  Args:
    laplacian: Symmetric graph Laplacian. The check for symmetry is **NOT**
      performed. See also :meth:`from_graph`.
    t: Constant used when approximating the geodesic exponential kernel.
    n_steps: Maximum number of steps used to approximate the heat kernel.
    numerical_scheme: Numerical scheme used to solve the heat diffusion.
    tol: Relative tolerance with respect to the Hilbert metric, see
      :cite:`peyre:19`, Remark 4.12. Used when iteratively updating scalings.
      If negative, this option is ignored and only ``n_steps`` is used.
    kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
  """

  def __init__(
      self,
      laplacian: jnp.ndarray,
      t: float = 1e-3,
      n_steps: int = 100,
      numerical_scheme: Literal["backward_euler",
                                "crank_nicolson"] = "backward_euler",
      tol: float = -1.0,
      **kwargs: Any
  ):
    super().__init__(epsilon=1.0, **kwargs)
    self.laplacian = laplacian
    self.t = t
    self.n_steps = n_steps
    self.numerical_scheme = numerical_scheme
    self.tol = tol
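
  # NOTE: explanatory note, not part of the original source. With step
  # constant ``c = self._scale``, both schemes below discretize the heat
  # diffusion ``du/dt = -L @ u`` over ``n_steps`` implicit steps:
  #   backward_euler:   b_{k+1} = (I + c * L)^{-1} @ b_k
  #   crank_nicolson:   b_{k+1} = (I + c * L)^{-1} @ (I - c * L) @ b_k
  # so applying the kernel amounts to repeatedly solving linear systems with
  # the same matrix ``M = I + c * L``; see `apply_kernel` below.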

  @classmethod
  def from_graph(
      cls,
      G: jnp.ndarray,
      t: Optional[float] = 1e-3,
      directed: bool = False,
      normalize: bool = False,
      **kwargs: Any
  ) -> "Graph":
    r"""Construct :class:`~ott.geometry.graph.Graph` from an adjacency matrix.

    Args:
      G: Adjacency matrix.
      t: Constant used when approximating the geodesic exponential kernel.
        If `None`, use the squared mean edge weight
        :math:`\left(\frac{1}{|E|} \sum_{(u, v) \in E} weight(u, v)\right)^2`
        :cite:`crane:13`. In this case, ``G`` must be specified
        and the edge weights are all assumed to be positive.
      directed: Whether the graph is directed. If yes, it will be made
        undirected as :math:`G + G^T`. This parameter is ignored when directly
        passing the Laplacian, which is assumed to be symmetric.
      normalize: Whether to normalize the Laplacian as
        :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L
        \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the
        non-normalized Laplacian and :math:`D` is the degree matrix.
      kwargs: Keyword arguments for :class:`~ott.geometry.graph.Graph`.

    Returns:
      The graph geometry.
    """
    assert G.shape[0] == G.shape[1], G.shape

    if directed:
      G = G + G.T

    degree = jnp.sum(G, axis=1)
    laplacian = jnp.diag(degree) - G

    if normalize:
      inv_sqrt_deg = jnp.diag(
          jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
      )
      laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg

    if t is None:
      t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2

    return cls(laplacian, t=t, **kwargs)
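
  # NOTE: worked example, not part of the original source. For the undirected
  # path graph 0 - 1 - 2 with unit weights, `from_graph` computes
  #   G = [[0, 1, 0],            degree = [1, 2, 1],
  #        [1, 0, 1],
  #        [0, 1, 0]]
  #   laplacian = jnp.diag(degree) - G
  #             = [[ 1, -1,  0],
  #                [-1,  2, -1],
  #                [ 0, -1,  1]]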

  def apply_kernel(
      self,
      vec: jnp.ndarray,
      eps: Optional[float] = None,
      axis: int = 0,
  ) -> jnp.ndarray:
    r"""Apply :attr:`kernel_matrix` on a positive vector.

    Args:
      vec: Vector to which the kernel is applied.
      eps: passed for consistency, not used yet.
      axis: passed for consistency, not used yet.

    Returns:
      Kernel applied to ``vec``.
    """

    def conf_fn(
        iteration: int, consts: Tuple[jnp.ndarray, Optional[jnp.ndarray]],
        old_new: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> bool:
      del iteration, consts

      x_old, x_new = old_new
      x_old, x_new = mu.safe_log(x_old), mu.safe_log(x_new)
      # center
      x_old, x_new = x_old - jnp.nanmax(x_old), x_new - jnp.nanmax(x_new)
      # Hilbert metric, see Remark 4.12 in `Computational Optimal Transport`
      f = x_new - x_old
      return (jnp.nanmax(f) - jnp.nanmin(f)) > self.tol

    def body_fn(
        iteration: int, consts: Tuple[jnp.ndarray, Optional[jnp.ndarray]],
        old_new: Tuple[jnp.ndarray, jnp.ndarray], compute_errors: bool
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
      del iteration, compute_errors

      L, scaled_lap = consts
      _, b = old_new

      if self.numerical_scheme == "crank_nicolson":
        # below is the preferred way of specifying the update (albeit with
        # more FLOPS), as CSR/CSC/COO matrices don't currently support adding
        # a diagonal matrix:
        # b' = (2 * I - M) @ b = (2 * I - (I + c * L)) @ b = (I - c * L) @ b
        b = b - scaled_lap @ b

      # solve `M @ x = b` using the lower Cholesky factor of `M = I + c * L`
      return b, jsp.linalg.cho_solve((L, True), b)

    # `eps` cannot be used, since changing it would require a re-solve;
    # `axis` can be ignored, since the matrix is symmetric
    del eps, axis

    force_scan = self.tol < 0.0
    fixpoint_fn = (
        fixed_point_loop.fixpoint_iter
        if force_scan else fixed_point_loop.fixpoint_iter_backprop
    )

    state = (jnp.full_like(vec, jnp.nan), vec)
    L = jsp.linalg.cholesky(self._M, lower=True)
    if self.numerical_scheme == "crank_nicolson":
      constants = L, self._scaled_laplacian
    else:
      constants = L, None

    return fixpoint_fn(
        cond_fn=(lambda *_, **__: True) if force_scan else conf_fn,
        body_fn=body_fn,
        min_iterations=self.n_steps if force_scan else 1,
        max_iterations=self.n_steps,
        inner_iterations=1,
        constants=constants,
        state=state,
    )[1]
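
  # NOTE: explanatory note, not part of the original source. `apply_kernel`
  # Cholesky-factors ``M = I + c * L`` once, outside the fixed-point loop, so
  # each of the ``n_steps`` iterations costs only two triangular solves (via
  # `cho_solve`) rather than a full linear solve. When ``tol >= 0``, the loop
  # may stop early, once the Hilbert-metric change between consecutive
  # iterates falls below ``tol`` (see `conf_fn` above).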

  @property
  def kernel_matrix(self) -> jnp.ndarray:  # noqa: D102
    n, _ = self.shape
    kernel = self.apply_kernel(jnp.eye(n))
    # symmetrize the kernel if needed; numerical imprecision
    # happens when `numerical_scheme='backward_euler'` and `t` is small
    return jax.lax.cond(
        jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8),
        lambda x: x,
        lambda x: (x + x.T) / 2.0,
        kernel,
    )

  @property
  def cost_matrix(self) -> jnp.ndarray:  # noqa: D102
    return -self.t * mu.safe_log(self.kernel_matrix)

  @property
  def _scale(self) -> float:
    """Constant used to scale the Laplacian."""
    if self.numerical_scheme == "backward_euler":
      return self.t / (4.0 * self.n_steps)
    if self.numerical_scheme == "crank_nicolson":
      return self.t / (2.0 * self.n_steps)
    raise NotImplementedError(
        f"Numerical scheme `{self.numerical_scheme}` is not implemented."
    )

  @property
  def _scaled_laplacian(self) -> jnp.ndarray:
    """Laplacian scaled by a constant, depending on the numerical scheme."""
    return self._scale * self.laplacian

  @property
  def _M(self) -> jnp.ndarray:
    n, _ = self.shape
    return self._scaled_laplacian + jnp.eye(n)

  @property
  def shape(self) -> Tuple[int, int]:  # noqa: D102
    return self.laplacian.shape

  @property
  def is_symmetric(self) -> bool:  # noqa: D102
    return True

  @property
  def dtype(self) -> jnp.dtype:  # noqa: D102
    return self.laplacian.dtype

  def transport_from_potentials(
      self, f: jnp.ndarray, g: jnp.ndarray
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def apply_transport_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      vec: jnp.ndarray,
      axis: int = 0
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def marginal_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      axis: int = 0,
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return [self.laplacian, self.t], {
        "n_steps": self.n_steps,
        "numerical_scheme": self.numerical_scheme,
        "tol": self.tol,
    }

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "Graph":
    return cls(*children, **aux_data)
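

# ----------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: it only exercises the
# public API defined above (`Graph.from_graph`, `kernel_matrix`, `cost_matrix`)
# on a small, hypothetical path graph.
if __name__ == "__main__":
  # Adjacency matrix of the undirected path graph 0 - 1 - 2 - 3.
  adjacency = jnp.array([
      [0.0, 1.0, 0.0, 0.0],
      [1.0, 0.0, 1.0, 0.0],
      [0.0, 1.0, 0.0, 1.0],
      [0.0, 0.0, 1.0, 0.0],
  ])
  geom = Graph.from_graph(adjacency, t=1e-2, normalize=True)
  # Entries of the cost matrix grow with the graph distance between nodes,
  # approximating squared geodesic distances for small `t`.
  print("kernel matrix:\n", geom.kernel_matrix)
  print("cost matrix:\n", geom.cost_matrix)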