from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import numpy as np
from scipy.special import ive

from ott import utils
from ott.geometry import geometry
from ott.math import utils as mu

# Only the geometry class is exported; the helper functions defined further
# down in this module are not listed as public API.
__all__ = ["Geodesic"]

# Graph / Laplacian matrices are either dense JAX arrays or sparse BCOO
# matrices.
Array_g = Union[jnp.ndarray, jesp.BCOO]

@jax.tree_util.register_pytree_node_class
class Geodesic(geometry.Geometry):
    r"""Graph distance approximation using heat kernel :cite:`huguet:2023`.

    .. note::
        This constructor is not meant to be called by the user,
        please use the :meth:`from_graph` method instead.

    Approximates the heat kernel using
    `Chebyshev polynomials <https://en.wikipedia.org/wiki/Chebyshev_polynomials>`_
    of the first kind of max order ``order``, which for small ``t``
    approximates the geodesic exponential kernel
    :math:`e^{\frac{-d(x, y)^2}{4t}}`; the factor :math:`4t` is the one used
    by :attr:`cost_matrix` to recover squared distances from the kernel.

    Args:
      scaled_laplacian: The Laplacian scaled by the largest eigenvalue.
      eigval: The largest eigenvalue of the Laplacian.
      chebyshev_coeffs: Coefficients of the Chebyshev polynomials.
      t: Time parameter for the heat kernel.
      kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
    """

    def __init__(
        self,
        scaled_laplacian: Array_g,
        eigval: jnp.ndarray,
        chebyshev_coeffs: jnp.ndarray,
        t: float = 1e-3,
        **kwargs: Any
    ):
        # The parent geometry is always initialized with a unit epsilon.
        super().__init__(epsilon=1.0, **kwargs)
        self.scaled_laplacian = scaled_laplacian
        self.eigval = eigval
        self.chebyshev_coeffs = chebyshev_coeffs
        self.t = t

    @classmethod
    def from_graph(
        cls,
        G: Array_g,
        t: Optional[float] = 1e-3,
        eigval: Optional[jnp.ndarray] = None,
        order: int = 100,
        directed: bool = False,
        normalize: bool = False,
        rng: Optional[jax.Array] = None,
        **kwargs: Any
    ) -> "Geodesic":
        r"""Construct a Geodesic geometry from an adjacency matrix.

        Args:
          G: Square adjacency matrix of the graph, either a dense array or a
            :class:`~jax.experimental.sparse.BCOO` matrix.
          t: Time parameter for approximating the geodesic exponential
            kernel. If :obj:`None`, it is set to the square of the ratio
            between the sum of all entries of ``G`` and the number of its
            strictly positive entries, i.e. the squared mean edge weight
            when the edge weights are positive :cite:`crane:13`. The value
            is computed after the optional symmetrization of ``G``.
            NOTE(review): this branch calls :func:`jax.numpy.sum` and
            ``G > 0.0`` directly on ``G``; confirm that both are supported
            when ``G`` is a sparse matrix.
          eigval: Largest eigenvalue of the Laplacian. If :obj:`None`, it's
            computed using
            :func:`jax.experimental.sparse.linalg.lobpcg_standard`. When it
            exceeds ``2``, the Laplacian is rescaled by ``2 / eigval`` and
            the stored eigenvalue becomes ``2``.
          order: Max order of Chebyshev polynomials.
          directed: Whether the graph is directed. If :obj:`True`, it's made
            undirected as :math:`G + G^T`.
          normalize: Whether to normalize the Laplacian as
            :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L
            \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the
            non-normalized Laplacian and :math:`D` is the degree matrix.
          rng: Random key used when computing the largest eigenvalue.
          kwargs: Keyword arguments for
            :class:`~ott.geometry.geodesic.Geodesic`.

        Returns:
          The Geodesic geometry.

        Raises:
          AssertionError: If ``G`` is not a square matrix.
        """
        assert G.shape[0] == G.shape[1], G.shape
        rng = utils.default_prng_key(rng)

        if directed:
            G = G + G.T
        if t is None:
            t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2

        # Pick the Laplacian builder matching the storage format of ``G``.
        build_laplacian = (
            compute_sparse_laplacian
            if isinstance(G, jesp.BCOO) else compute_dense_laplacian
        )
        laplacian = build_laplacian(G, normalize)

        if eigval is None:
            eigval = compute_largest_eigenvalue(laplacian, rng)

        largest = eigval

        def _rescale_to_two(lap):
            # Shrink the spectrum so that its largest eigenvalue becomes 2.
            return 2.0 * lap / largest, 2.0

        def _keep_scale(lap):
            return lap, largest

        scaled_laplacian, eigval = jax.lax.cond(
            largest > 2.0, _rescale_to_two, _keep_scale, laplacian
        )

        # compute the coeffs of the Chebyshev pols approx using Bessel funcs
        chebyshev_coeffs = compute_chebychev_coeff_all(
            0.5 * eigval, t, order, laplacian.dtype
        )

        return cls(
            scaled_laplacian=scaled_laplacian,
            eigval=eigval,
            chebyshev_coeffs=chebyshev_coeffs,
            t=t,
            **kwargs
        )

    def apply_kernel(
        self,
        scaling: jnp.ndarray,
        eps: Optional[float] = None,
        axis: int = 0,
    ) -> jnp.ndarray:
        r"""Apply :attr:`kernel_matrix` on positive scaling vector.

        Args:
          scaling: Scaling to apply the kernel to.
          eps: passed for consistency, not used yet.
          axis: passed for consistency, not used yet.

        Returns:
          Kernel applied to scaling.
        """
        # The Chebyshev recurrence is centered on half the largest eigenvalue.
        half_eigval = 0.5 * self.eigval
        return expm_multiply(
            self.scaled_laplacian, scaling, self.chebyshev_coeffs, half_eigval
        )

    @property
    def kernel_matrix(self) -> jnp.ndarray:
        """Dense kernel, obtained by applying the kernel to the identity.

        The result is replaced by the average of itself and its transpose
        whenever it is not symmetric up to a tolerance of ``1e-8``.
        """
        n, _ = self.shape
        kernel = self.apply_kernel(jnp.eye(n))
        already_symmetric = jnp.allclose(
            kernel, kernel.T, atol=1e-8, rtol=1e-8
        )

        def _symmetrize(mat):
            return (mat + mat.T) / 2.0

        return jax.lax.cond(
            already_symmetric, lambda mat: mat, _symmetrize, kernel
        )

    @property
    def cost_matrix(self) -> jnp.ndarray:
        """Cost matrix, computed as ``-4 * t * log(kernel_matrix)``."""
        # Calculate the cost matrix using the formula (5) from the main
        # reference
        log_kernel = mu.safe_log(self.kernel_matrix)
        return -4.0 * self.t * log_kernel

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the geometry, i.e. the shape of the scaled Laplacian."""
        return self.scaled_laplacian.shape

    @property
    def is_symmetric(self) -> bool:
        """Always :obj:`True` for this geometry."""
        return True

    @property
    def dtype(self) -> jnp.dtype:
        """Data type of the scaled Laplacian."""
        return self.scaled_laplacian.dtype

    def transport_from_potentials(
        self, f: jnp.ndarray, g: jnp.ndarray
    ) -> jnp.ndarray:
        """Not implemented."""
        raise ValueError("Not implemented.")

    def apply_transport_from_potentials(
        self,
        f: jnp.ndarray,
        g: jnp.ndarray,
        vec: jnp.ndarray,
        axis: int = 0
    ) -> jnp.ndarray:
        """Not implemented."""
        raise ValueError("Not implemented.")

    def marginal_from_potentials(
        self,
        f: jnp.ndarray,
        g: jnp.ndarray,
        axis: int = 0,
    ) -> jnp.ndarray:
        """Not implemented."""
        raise ValueError("Not implemented.")

    def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
        """Flatten into ``(children, aux_data)`` for JAX pytree handling.

        Only the four constructor arguments stored on the instance are
        returned as children and the auxiliary data is empty; keyword
        arguments that were forwarded to the parent class in ``__init__``
        are not part of the flattened representation.
        """
        children = [
            self.scaled_laplacian,
            self.eigval,
            self.chebyshev_coeffs,
            self.t,
        ]
        return children, {}

    @classmethod
    def tree_unflatten(
        cls, aux_data: Dict[str, Any], children: Sequence[Any]
    ) -> "Geodesic":
        """Rebuild the geometry from the output of :meth:`tree_flatten`."""
        return cls(*children, **aux_data)

def normalize_laplacian(laplacian: Array_g, degree: jnp.ndarray) -> Array_g:
    """Scale a Laplacian on both sides by the inverse square root degrees.

    Entry ``(i, j)`` of the result is
    ``laplacian[i, j] / sqrt(degree[i] * degree[j])``; rows and columns
    whose degree is not strictly positive are multiplied by zero instead.

    Args:
      laplacian: Square Laplacian matrix; the sparse Laplacian builder in
        this module also passes a BCOO matrix here.
      degree: Vector of node degrees, one entry per row of ``laplacian``.

    Returns:
      The symmetrically rescaled Laplacian.
    """
    has_positive_degree = degree > 0.0
    # ``where`` masks the infinite value produced by ``1 / sqrt(0)``.
    inv_sqrt_deg = jnp.where(has_positive_degree, 1.0 / jnp.sqrt(degree), 0.0)
    row_scale = inv_sqrt_deg[:, None]
    col_scale = inv_sqrt_deg[None, :]
    return row_scale * laplacian * col_scale

def compute_dense_laplacian(
    G: jnp.ndarray, normalize: bool = False
) -> jnp.ndarray:
    """Build the Laplacian ``D - G`` of a dense adjacency matrix.

    Args:
      G: Dense square adjacency matrix.
      normalize: Whether to rescale the Laplacian with
        ``normalize_laplacian`` using the row sums of ``G`` as degrees.

    Returns:
      The dense (optionally normalized) Laplacian.
    """
    # Node degrees are the row sums of the adjacency matrix.
    degree = jnp.sum(G, axis=1)
    laplacian = jnp.diag(degree) - G
    if not normalize:
        return laplacian
    return normalize_laplacian(laplacian, degree)

def compute_sparse_laplacian(
    G: jesp.BCOO, normalize: bool = False
) -> jesp.BCOO:
    """Build the Laplacian ``D - G`` of a sparse adjacency matrix.

    Args:
      G: Square adjacency matrix in BCOO format.
      normalize: Whether to rescale the Laplacian with
        ``normalize_laplacian`` using the row sums of ``G`` as degrees.

    Returns:
      The sparse (optionally normalized) Laplacian.
    """
    num_nodes, _ = G.shape
    # making sure allocated indices has same dtype
    # on different devices int32 vs int64 can cause issues
    diag_ixs = jnp.arange(num_nodes, dtype=G.indices.dtype)
    data_degree = G.sum(1).todense()

    # Sparse diagonal degree matrix: entry ``(i, i)`` holds the i-th row sum.
    diag_indices = jnp.stack((diag_ixs, diag_ixs), axis=-1)
    degree = jesp.BCOO((data_degree, diag_indices), shape=(num_nodes, num_nodes))

    laplacian = degree - G
    if not normalize:
        return laplacian
    return normalize_laplacian(laplacian, data_degree)

def compute_largest_eigenvalue(
    laplacian_matrix: jnp.ndarray,
    rng: jax.Array,
) -> float:
    """Compute the largest eigenvalue of a symmetric Laplacian matrix.

    Large matrices are handled with LOBPCG using one random search
    direction. LOBPCG requires the matrix dimension to be strictly larger
    than five times the number of search directions, so matrices with at
    most 5 rows are handled with a dense symmetric eigenvalue solver
    instead of raising.

    Args:
      laplacian_matrix: Square Laplacian, dense or in BCOO format; it is
        assumed to be symmetric.
      rng: Random key used to draw the initial LOBPCG search direction. It
        is not used for matrices with at most 5 rows.

    Returns:
      The largest eigenvalue, as a scalar array.
    """
    n = laplacian_matrix.shape[0]

    if n <= 5:
        # Too small for LOBPCG with one search direction; the dense solver
        # is cheap at this size.
        dense = laplacian_matrix
        if isinstance(dense, jesp.BCOO):
            dense = dense.todense()
        # ``eigvalsh`` returns the eigenvalues in ascending order.
        return jnp.linalg.eigvalsh(dense)[-1]

    # Generate random initial directions for eigenvalue computation
    initial_dirs = jax.random.normal(rng, (n, 1))

    # ``sparsify`` lets the same matrix-vector product work for both dense
    # and BCOO Laplacians.
    lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v)

    eigvals, _, _ = jesp.linalg.lobpcg_standard(
        lapl_vector_product,
        initial_dirs,
    )
    return eigvals[0]

def expm_multiply(
    L: Array_g, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
) -> jnp.ndarray:
    """Apply a Chebyshev expansion of the shifted operator to ``X``.

    With ``A = L / eigval - I``, the result is
    ``0.5 * coeff[0] * T_0(A) X + sum_{k >= 1} coeff[k] * T_k(A) X``, where
    ``T_k`` are Chebyshev polynomials of the first kind evaluated through
    the three-term recurrence ``T_{k+1} = 2 A T_k - T_{k-1}``.

    Args:
      L: Square operator, dense or in BCOO format.
      X: Array the expansion is applied to; ``L @ X`` must be defined.
      coeff: One-dimensional array with the expansion coefficients.
      eigval: Shift/scale of the operator. The geometry in this module
        passes half of the largest eigenvalue of ``L``.

    Returns:
      The expansion applied to ``X``, with the same shape as ``X``.
    """
    T0 = X
    Y = 0.5 * coeff[0] * T0
    if coeff.shape[0] < 2:
        # Order-0 expansion: indexing ``coeff[1]`` would be out of bounds
        # and silently clamped to ``coeff[0]``, adding a spurious term.
        return Y

    def body(carry, c):
        T_prev, T_cur, acc = carry
        # The scalar multiplies the product rather than the operator, so
        # ``L`` itself is not rescaled at every step of the scan.
        T_next = (2.0 / eigval) * (L @ T_cur) - 2.0 * T_cur - T_prev
        return (T_cur, T_next, acc + c * T_next), None

    T1 = (1.0 / eigval) * (L @ X) - T0
    Y = Y + coeff[1] * T1

    (_, _, Y), _ = jax.lax.scan(body, (T0, T1, Y), coeff[2:])
    return Y

def compute_chebychev_coeff_all(
    eigval: float, tau: float, K: int, dtype: np.dtype
) -> jnp.ndarray:
    """Jax wrapper to compute the K+1 Chebychev coefficients.

    The coefficients are ``2 * ive(k, -tau * eigval)`` for ``k = 0, ..., K``,
    where ``ive`` is SciPy's exponentially scaled modified Bessel function
    of the first kind. They are evaluated on the host through
    :func:`jax.pure_callback`.

    Args:
      eigval: Eigenvalue scale multiplied with ``tau``.
      tau: Time parameter.
      K: Maximum order; ``K + 1`` coefficients are returned.
      dtype: Data type of the returned coefficients.

    Returns:
      Array of shape ``(K + 1,)`` and type ``dtype``.
    """

    def _host_coeffs(eigval, tau, K):
        # Runs outside of JAX tracing, on the values handed over by the
        # callback machinery.
        orders = np.arange(0, K + 1)
        return (2.0 * ive(orders, -tau * eigval)).astype(dtype)

    # The callback output shape and dtype must be declared up front.
    out_spec = jax.ShapeDtypeStruct(shape=(K + 1,), dtype=dtype)
    return jax.pure_callback(_host_coeffs, out_spec, eigval, tau, K)