# Copyright OTT-JAX
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union

import jax
import jax.numpy as jnp

from ott.tools.gaussian_mixture import scale_tril

__all__ = ["Gaussian"]

# Constant term log(2 * pi) of the Gaussian log-density; math.tau == 2 * pi.
LOG2PI = math.log(math.tau)

@jax.tree_util.register_pytree_node_class
class Gaussian:
    """Normal distribution parameterized by a location and a scale.

    Covariance-related quantities (covariance, log-determinant, whitening,
    W2 distance, transport map) are delegated to the ``scale_tril.ScaleTriL``
    object stored in :attr:`scale`. Instances are JAX pytrees whose children
    are ``(loc, scale)``.
    """

    def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL):
        self._loc = loc
        self._scale = scale

    @classmethod
    def from_samples(
        cls, points: jnp.ndarray, weights: Optional[jnp.ndarray] = None
    ) -> "Gaussian":
        """Construct a Gaussian from weighted samples.

        Unbiased, weighted covariance formula from `GSL
        <https://www.gnu.org/software/gsl/doc/html/statistics.html#weighted-samples>`_.
        Weights are normalized to sum to one before use. The weighted mean
        ``sum_i w_i x_i`` and the bias correction ``1 / (1 - sum_i w_i**2)``
        are only valid for normalized weights. With a single effective sample,
        the correction divides by zero.

        Args:
          points: [n x d] array of samples.
          weights: [n] array of non-negative weights. Defaults to uniform
            weights ``1 / n``. They need not sum to one.

        Returns:
          Gaussian with the weighted sample mean and unbiased covariance.
        """
        n = points.shape[0]
        if weights is None:
            weights = jnp.ones(n) / n
        else:
            # The formulas below assume sum(weights) == 1.
            weights = weights / jnp.sum(weights)

        mean = weights @ points  # (d,)
        centered_x = points - mean  # (n, d)
        scaled_centered_x = centered_x * weights.reshape(-1, 1)
        # sum_i w_i (x_i - mean)(x_i - mean)^T with the GSL bias correction.
        cov = (scaled_centered_x.T @ centered_x) / (1.0 - weights @ weights)
        return cls.from_mean_and_cov(mean=mean, cov=cov)

    @classmethod
    def from_random(
        cls,
        rng: jax.Array,
        n_dimensions: int,
        stdev_mean: float = 0.1,
        stdev_cov: float = 0.1,
        ridge: Union[float, jnp.ndarray] = 0,
    ) -> "Gaussian":
        """Construct a random Gaussian.

        Args:
          rng: jax.random key.
          n_dimensions: desired covariance dimensions.
          stdev_mean: standard deviation of the entries of the location
            (whose mean is ``ridge``).
          stdev_cov: standard deviation forwarded as ``stdev`` to
            ``ScaleTriL.from_random``.
          ridge: offset added to the location.

        Returns:
          A random Gaussian.
        """
        # Keep the original 3-way split so results are unchanged for a key.
        _, subrng0, subrng1 = jax.random.split(rng, num=3)
        loc = jax.random.normal(subrng0, shape=(n_dimensions,)) * stdev_mean + ridge
        scale = scale_tril.ScaleTriL.from_random(
            subrng1, n_dimensions=n_dimensions, stdev=stdev_cov
        )
        return cls(loc=loc, scale=scale)

    @classmethod
    def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> "Gaussian":
        """Construct a Gaussian from a mean and covariance."""
        scale = scale_tril.ScaleTriL.from_covariance(cov)
        return cls(loc=mean, scale=scale)

    @property
    def loc(self) -> jnp.ndarray:
        """Mean of the Gaussian."""
        return self._loc

    @property
    def scale(self) -> scale_tril.ScaleTriL:
        """Scale of the Gaussian."""
        return self._scale

    @property
    def n_dimensions(self) -> int:
        """Dimensionality of the Gaussian (last axis of ``loc``)."""
        return self.loc.shape[-1]

    def covariance(self) -> jnp.ndarray:
        """Covariance of the Gaussian."""
        return self.scale.covariance()

    def to_z(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""Transform :math:`x` to :math:`z = \frac{x - loc}{scale}`."""
        return self.scale.centered_to_z(x_centered=x - self.loc)

    def from_z(self, z: jnp.ndarray) -> jnp.ndarray:
        r"""Transform :math:`z` to :math:`x = loc + scale \cdot z`."""
        return self.scale.z_to_centered(z=z) + self.loc

    def log_prob(
        self,
        x: jnp.ndarray,  # (?, d)
    ) -> jnp.ndarray:  # (?,)
        r"""Log-density of the Gaussian evaluated at ``x``.

        With :math:`z = \text{to\_z}(x)`, computes
        :math:`-\frac{1}{2}(d \log 2\pi + \log\det\Sigma + \|z\|^2)`, where
        :math:`\log\det\Sigma` comes from ``scale.log_det_covariance()``.

        Args:
          x: points of shape ``(?, d)``.

        Returns:
          Log-densities of shape ``(?,)`` (one value per point).
        """
        d = x.shape[-1]
        z = self.to_z(x)
        log_det = self.scale.log_det_covariance()
        return -0.5 * (d * LOG2PI + log_det[None] + jnp.sum(z ** 2, axis=-1))

    def sample(self, rng: jax.Array, size: int) -> jnp.ndarray:
        """Generate samples from the distribution.

        Draws standard-normal vectors ``z`` and returns ``loc + L z`` with
        ``L = scale.cholesky()``.

        Args:
          rng: jax.random key.
          size: number of samples.

        Returns:
          Samples of shape ``(size, d)`` (assuming ``cholesky()`` is d x d).
        """
        std_samples_t = jax.random.normal(rng, shape=(self.n_dimensions, size))
        return self.loc[None] + jnp.swapaxes(
            jnp.matmul(self.scale.cholesky(), std_samples_t), axis1=-2, axis2=-1
        )

    def w2_dist(self, other: "Gaussian") -> jnp.ndarray:
        r"""Wasserstein distance :math:`W_2^2` to another Gaussian.

        .. math::

          W_2^2 = ||\mu_0-\mu_1||^2 +
          \text{trace} ( (\Lambda_0^\frac{1}{2} - \Lambda_1^\frac{1}{2})^2 )

        The covariance term is delegated to ``ScaleTriL.w2_dist``.

        Args:
          other: other Gaussian.

        Returns:
          The :math:`W_2^2` distance between self and other.
        """
        delta_mean = jnp.sum((self.loc - other.loc) ** 2, axis=-1)
        delta_sigma = self.scale.w2_dist(other.scale)
        return delta_mean + delta_sigma

    def f_potential(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray:
        r"""Optimal potential for W2 distance between Gaussians.

        Evaluated on points. With :math:`A` = ``scale.gaussian_map(dest.scale)``,
        :math:`m_0` = ``self.loc`` and :math:`m_1` = ``dest.loc``, computes

        .. math::

          f(x) = \tfrac{1}{2}\|x\|^2 - \tfrac{1}{2}(x - m_0)^\top A (x - m_0)
                 - \langle x, m_1 \rangle.

        Args:
          dest: Gaussian object.
          points: samples of shape ``(n, d)``.

        Returns:
          Dual potential, f, of shape ``(n,)``.
        """
        scale_matrix = self.scale.gaussian_map(dest_scale=dest.scale)
        centered_x = points - self.loc
        scaled_x = scale_matrix @ centered_x.T  # column i is A (x_i - m_0)

        @jax.vmap
        def batch_inner_product(x, y):
            return jnp.dot(x, y)

        return (
            0.5 * batch_inner_product(points, points)
            - 0.5 * batch_inner_product(centered_x, scaled_x.T)
            - points @ dest.loc
        )

    def transport(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray:
        """Transport points according to map between two Gaussian measures.

        Args:
          dest: Gaussian object.
          points: samples.

        Returns:
          Transported samples.
        """
        return self.scale.transport(
            dest_scale=dest.scale, points=points - self.loc[None]
        ) + dest.loc[None]

    def tree_flatten(self):  # noqa: D102
        children = (self.loc, self.scale)
        aux_data = {}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):  # noqa: D102
        return cls(*children, **aux_data)

    def __hash__(self):
        # NOTE(review): tree_flatten returns (list_of_leaves, treedef); hashing a
        # tuple containing a list (and JAX array leaves) raises TypeError, so this
        # likely cannot succeed as written -- confirm intended semantics.
        return jax.tree_util.tree_flatten(self).__hash__()

    def __eq__(self, other):
        # NOTE(review): list equality compares array leaves with `==`; for
        # distinct multi-element arrays, bool() of the elementwise result raises.
        return jax.tree_util.tree_flatten(self) == jax.tree_util.tree_flatten(other)