Source code for ott.geometry.geodesic

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from scipy.special import ive

from ott import utils
from ott.geometry import geometry
from ott.math import utils as mu

__all__ = ["Geodesic"]

Array_g = Union[jnp.ndarray, jesp.BCOO]


[docs] @jtu.register_pytree_node_class class Geodesic(geometry.Geometry): r"""Graph distance approximation using heat kernel :cite:`huguet:2023`. .. note:: This constructor is not meant to be called by the user, please use the :meth:`from_graph` method instead. Approximates the heat kernel using `Chebyshev polynomials <https://en.wikipedia.org/wiki/Chebyshev_polynomials>`_ of the first kind of max order ``order``, which for small ``t`` approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. Args: scaled_laplacian: The Laplacian scaled by the largest eigenvalue. eigval: The largest eigenvalue of the Laplacian. chebyshev_coeffs: Coefficients of the Chebyshev polynomials. t: Time parameter for the heat kernel. kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. """ def __init__( self, scaled_laplacian: Array_g, eigval: jnp.ndarray, chebyshev_coeffs: jnp.ndarray, t: float = 1e-3, **kwargs: Any ): super().__init__(epsilon=1.0, **kwargs) self.scaled_laplacian = scaled_laplacian self.eigval = eigval self.chebyshev_coeffs = chebyshev_coeffs self.t = t
[docs] @classmethod def from_graph( cls, G: Array_g, t: Optional[float] = 1e-3, eigval: Optional[jnp.ndarray] = None, order: int = 100, directed: bool = False, normalize: bool = False, rng: Optional[jax.Array] = None, **kwargs: Any ) -> "Geodesic": r"""Construct a Geodesic geometry from an adjacency matrix. Args: G: Adjacency matrix. t: Time parameter for approximating the geodesic exponential kernel. If `None`, it defaults to :math:`\frac{1}{|E|} \sum_{(u, v) \in E} \text{weight}(u, v)` :cite:`crane:13`. In this case, the ``graph`` must be specified and the edge weights are assumed to be positive. eigval: Largest eigenvalue of the Laplacian. If :obj:`None`, it's computed using :func:`jax.experimental.sparse.linalg.lobpcg_standard`. order: Max order of Chebyshev polynomials. directed: Whether the ``graph`` is directed. If :obj:`True`, it's made undirected as :math:`G + G^T`. This parameter is ignored when passing the Laplacian directly, assumed to be symmetric. normalize: Whether to normalize the Laplacian as :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the non-normalized Laplacian and :math:`D` is the degree matrix. rng: Random key used when computing the largest eigenvalue. kwargs: Keyword arguments for :class:`~ott.geometry.geodesic.Geodesic`. Returns: The Geodesic geometry. """ assert G.shape[0] == G.shape[1], G.shape rng = utils.default_prng_key(rng) if directed: G = G + G.T if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2 if isinstance(G, jesp.BCOO): laplacian = compute_sparse_laplacian(G, normalize) else: laplacian = compute_dense_laplacian(G, normalize) if eigval is None: eigval = compute_largest_eigenvalue(laplacian, rng) scaled_laplacian, eigval = jax.lax.cond((eigval > 2.0), lambda l: (2.0 * l / eigval, 2.0), lambda l: (l, eigval), laplacian) # compute the coeffs of the Chebyshev pols approx using Bessel funcs chebyshev_coeffs = compute_chebychev_coeff_all( 0.5 * eigval, t, order, laplacian.dtype ) return cls( scaled_laplacian=scaled_laplacian, eigval=eigval, chebyshev_coeffs=chebyshev_coeffs, t=t, **kwargs )
[docs] def apply_kernel( self, vec: jnp.ndarray, eps: Optional[float] = None, axis: int = 0, ) -> jnp.ndarray: r"""Apply :attr:`kernel_matrix` on a positive vector. Args: vec: Vector to which the kernel is applied. eps: passed for consistency, not used yet. axis: passed for consistency, not used yet. Returns: Kernel applied to ``scaling``. """ return expm_multiply( self.scaled_laplacian, vec, self.chebyshev_coeffs, 0.5 * self.eigval )
@property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) return jax.lax.cond( jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x, lambda x: (x + x.T) / 2.0, kernel ) @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 # Calculate the cost matrix using the formula (5) from the main reference return -4.0 * self.t * mu.safe_log(self.kernel_matrix) @property def shape(self) -> Tuple[int, int]: # noqa: D102 return self.scaled_laplacian.shape @property def is_symmetric(self) -> bool: # noqa: D102 return True @property def dtype(self) -> jnp.dtype: # noqa: D102 return self.scaled_laplacian.dtype
[docs] def transport_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray ) -> jnp.ndarray: """Not implemented.""" raise ValueError("Not implemented.")
[docs] def apply_transport_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int = 0 ) -> jnp.ndarray: """Not implemented.""" raise ValueError("Not implemented.")
[docs] def marginal_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0, ) -> jnp.ndarray: """Not implemented.""" raise ValueError("Not implemented.")
def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [ self.scaled_laplacian, self.eigval, self.chebyshev_coeffs, self.t, ], {} @classmethod def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "Geodesic": return cls(*children, **aux_data)
def normalize_laplacian(laplacian: Array_g, degree: jnp.ndarray) -> Array_g: inv_sqrt_deg = jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) return inv_sqrt_deg[:, None] * laplacian * inv_sqrt_deg[None, :] def compute_dense_laplacian( G: jnp.ndarray, normalize: bool = False ) -> jnp.ndarray: degree = jnp.sum(G, axis=1) laplacian = jnp.diag(degree) - G if normalize: laplacian = normalize_laplacian(laplacian, degree) return laplacian def compute_sparse_laplacian( G: jesp.BCOO, normalize: bool = False ) -> jesp.BCOO: n, _ = G.shape # making sure allocated indices has same dtype # on different devices int32 vs int64 can cause issues indices_dtype = G.indices.dtype data_degree, ixs = G.sum(1).todense(), jnp.arange(n, dtype=indices_dtype) degree = jesp.BCOO( (data_degree, jnp.c_[ixs, ixs]), shape=(n, n), ) laplacian = degree - G if normalize: laplacian = normalize_laplacian(laplacian, data_degree) return laplacian def compute_largest_eigenvalue( laplacian_matrix: jnp.ndarray, rng: jax.Array, ) -> float: # Compute the largest eigenvalue of the Laplacian matrix. n = laplacian_matrix.shape[0] # Generate random initial directions for eigenvalue computation initial_dirs = jax.random.normal(rng, (n, 1)) # Create a sparse matrix-vector product function using sparsify # This function multiplies the sparse laplacian_matrix with a vector lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v) # Compute eigenvalues using the sparse matrix-vector product eigvals, _, _ = jesp.linalg.lobpcg_standard( lapl_vector_product, initial_dirs, ) return eigvals[0] def expm_multiply( L: Array_g, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float ) -> jnp.ndarray: def body(carry, c): T0, T1, Y = carry T2 = (2.0 / eigval) * L @ T1 - 2.0 * T1 - T0 Y = Y + c * T2 return (T1, T2, Y), None T0 = X Y = 0.5 * coeff[0] * T0 T1 = (1.0 / eigval) * L @ X - T0 Y = Y + coeff[1] * T1 initial_state = (T0, T1, Y) (_, _, Y), _ = jax.lax.scan(body, initial_state, coeff[2:]) return Y def compute_chebychev_coeff_all( eigval: float, tau: float, K: int, dtype: np.dtype ) -> jnp.ndarray: """Jax wrapper to compute the K+1 Chebychev coefficients.""" result_shape_dtype = jax.ShapeDtypeStruct( shape=(K + 1,), dtype=dtype, ) chebychev_coeff = lambda eigval, tau, K: ( 2.0 * ive(np.arange(0, K + 1), -tau * eigval) ).astype(dtype) return jax.pure_callback(chebychev_coeff, result_shape_dtype, eigval, tau, K)