# Source code for ott.tools.gaussian_mixture.gaussian

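# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.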
import math
from typing import Optional, Union

import jax
import jax.numpy as jnp

from ott.tools.gaussian_mixture import scale_tril

# Public API of this module.
__all__ = ["Gaussian"]

# Natural logarithm of 2*pi (math.tau == 2.0 * math.pi exactly); it appears in
# the normalizing constant of the Gaussian log-density as -0.5 * d * LOG2PI.
LOG2PI = math.log(math.tau)

@jax.tree_util.register_pytree_node_class
class Gaussian:
    """Normal distribution with a location and a lower-triangular scale.

    Instances are registered as JAX pytrees whose children are ``loc`` and
    ``scale`` (see ``tree_flatten``), so they can be passed through JAX
    transformations.
    """

    def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL):
        self._loc = loc
        self._scale = scale

    @classmethod
    def from_samples(
        cls,
        points: jnp.ndarray,
        weights: Optional[jnp.ndarray] = None
    ) -> "Gaussian":
        """Construct a Gaussian from weighted samples.

        Unbiased, weighted covariance formula from `GSL
        <https://www.gnu.org/software/gsl/doc/html/statistics.html#weighted-samples>`_.

        Args:
          points: [n x d] array of samples.
          weights: [n] array of non-negative weights. They are normalized to
            sum to one before use. If ``None``, uniform weights ``1 / n`` are
            used.

        Returns:
          Gaussian whose location is the weighted mean of ``points`` and whose
          covariance is the unbiased weighted covariance. The covariance is
          non-finite when all of the weight sits on a single sample
          (e.g. ``n == 1``), because the bias correction divides by zero.
        """
        n = points.shape[0]
        if weights is None:
            weights = jnp.ones(n) / n
        else:
            # The weighted mean and the ``1 - sum(w ** 2)`` correction below are
            # only valid for weights that sum to one.
            weights = jnp.asarray(weights)
            weights = weights / jnp.sum(weights)

        mean = weights.dot(points)
        centered_x = points - mean
        scaled_centered_x = centered_x * weights.reshape(-1, 1)
        # With normalized weights, GSL's factor V1 / (V1**2 - V2) reduces to
        # 1 / (1 - sum(w ** 2)).
        cov = scaled_centered_x.T.dot(centered_x) / (1 - weights.dot(weights))
        return cls.from_mean_and_cov(mean=mean, cov=cov)

    @classmethod
    def from_random(
        cls,
        rng: jax.Array,
        n_dimensions: int,
        stdev_mean: float = 0.1,
        stdev_cov: float = 0.1,
        ridge: Union[float, jnp.ndarray] = 0,
    ) -> "Gaussian":
        """Construct a random Gaussian.

        Args:
          rng: jax.random key.
          n_dimensions: desired covariance dimensions.
          stdev_mean: standard deviation of the zero-mean normal draw used for
            the location (before ``ridge`` is added).
          stdev_cov: standard deviation forwarded as ``stdev`` to
            ``ScaleTriL.from_random`` to draw the scale.
          ridge: offset added to the location.

        Returns:
          A random Gaussian.
        """
        # The key is split into three (the first is unused) so that draws stay
        # identical to the original implementation for a given ``rng``.
        _, rng_loc, rng_scale = jax.random.split(rng, num=3)
        loc = jax.random.normal(rng_loc, shape=(n_dimensions,)) * stdev_mean + ridge
        scale = scale_tril.ScaleTriL.from_random(
            rng_scale, n_dimensions=n_dimensions, stdev=stdev_cov
        )
        return cls(loc=loc, scale=scale)

    @classmethod
    def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> "Gaussian":
        """Construct a Gaussian from a mean and covariance."""
        scale = scale_tril.ScaleTriL.from_covariance(cov)
        return cls(loc=mean, scale=scale)

    @property
    def loc(self) -> jnp.ndarray:
        """Mean of the Gaussian."""
        return self._loc

    @property
    def scale(self) -> scale_tril.ScaleTriL:
        """Scale of the Gaussian."""
        return self._scale

    @property
    def n_dimensions(self) -> int:
        """Dimensionality of the Gaussian (last axis of ``loc``)."""
        return self.loc.shape[-1]

    def covariance(self) -> jnp.ndarray:
        """Covariance of the Gaussian."""
        return self.scale.covariance()

    def to_z(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""Transform :math:`x` to :math:`z = \frac{x - loc}{scale}`."""
        return self.scale.centered_to_z(x_centered=x - self.loc)

    def from_z(self, z: jnp.ndarray) -> jnp.ndarray:
        r"""Transform :math:`z` to :math:`x = loc + scale \cdot z`."""
        return self.scale.z_to_centered(z=z) + self.loc

    def log_prob(
        self,
        x: jnp.ndarray,  # (?, d)
    ) -> jnp.ndarray:  # (?,)
        r"""Log density of the (full-covariance) Gaussian at ``x``.

        Computes :math:`-\frac{1}{2}(d \log 2\pi + \log|\Sigma| + \|z\|^2)`
        with :math:`z` given by :meth:`to_z`.

        Args:
          x: points with the dimension on the last axis, e.g. ``[n, d]``.

        Returns:
          Log densities; shape ``[n]`` for ``x`` of shape ``[n, d]``.
        """
        d = x.shape[-1]
        z = self.to_z(x)
        log_det = self.scale.log_det_covariance()
        # ``log_det`` is presumably a scalar; ``[None]`` lifts it to shape (1,)
        # so it broadcasts against the per-point squared norms.
        return -0.5 * (d * LOG2PI + log_det[None] + jnp.sum(z ** 2, axis=-1))

    def sample(self, rng: jax.Array, size: int) -> jnp.ndarray:
        """Generate samples from the distribution.

        Args:
          rng: jax.random key.
          size: number of samples to draw.

        Returns:
          Array of shape ``[size, d]``.
        """
        std_samples_t = jax.random.normal(rng, shape=(self.n_dimensions, size))
        # Correlate the standard normal draws with the Cholesky factor, then
        # move the sample axis first: (d, size) -> (size, d).
        correlated = jnp.swapaxes(
            jnp.matmul(self.scale.cholesky(), std_samples_t),
            axis1=-2,
            axis2=-1
        )
        return self.loc[None] + correlated

    def w2_dist(self, other: "Gaussian") -> jnp.ndarray:
        r"""Wasserstein distance :math:`W_2^2` to another Gaussian.

        .. math::

          W_2^2 = ||\mu_0-\mu_1||^2 +
          \text{trace} ( (\Lambda_0^\frac{1}{2} - \Lambda_1^\frac{1}{2})^2 )

        The covariance term is delegated to ``ScaleTriL.w2_dist``.

        Args:
          other: other Gaussian.

        Returns:
          The :math:`W_2^2` distance between self and other.
        """
        delta_mean = jnp.sum((self.loc - other.loc) ** 2, axis=-1)
        delta_sigma = self.scale.w2_dist(other.scale)
        return delta_mean + delta_sigma

    def f_potential(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray:
        r"""Optimal potential for W2 distance between Gaussians, on ``points``.

        Evaluates :math:`f(x) = \frac{1}{2}\|x\|^2
        - \frac{1}{2}(x - \mu)^\top A (x - \mu) - \langle x, \mu_{dest}\rangle`,
        where :math:`A` is ``self.scale.gaussian_map(dest_scale=dest.scale)``.

        Args:
          dest: Gaussian object.
          points: samples, ``[n, d]``.

        Returns:
          Dual potential f, one value per point.
        """
        scale_matrix = self.scale.gaussian_map(dest_scale=dest.scale)
        centered_x = points - self.loc
        # (d, d) @ (d, n) -> (d, n), transposed back to (n, d).
        scaled_x = (scale_matrix @ centered_x.T).T

        # Row-wise inner products, one per point.
        sq_norm = jnp.sum(points * points, axis=-1)
        quad_form = jnp.sum(centered_x * scaled_x, axis=-1)
        return 0.5 * sq_norm - 0.5 * quad_form - points.dot(dest.loc)

    def transport(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray:
        """Transport points according to map between two Gaussian measures.

        Args:
          dest: Gaussian object.
          points: samples.

        Returns:
          Transported samples.
        """
        return self.scale.transport(
            dest_scale=dest.scale, points=points - self.loc[None]
        ) + dest.loc[None]

    def tree_flatten(self):  # noqa: D102
        children = (self.loc, self.scale)
        aux_data = {}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):  # noqa: D102
        return cls(*children, **aux_data)

    def __hash__(self):
        # Array leaves are not hashable, so hash only the leaf shapes. Objects
        # that compare equal (see ``__eq__``) always have identical shapes,
        # which keeps hashing consistent with equality.
        leaves, _ = jax.tree_util.tree_flatten(self)
        return hash(tuple(jnp.shape(leaf) for leaf in leaves))

    def __eq__(self, other):
        if not isinstance(other, Gaussian):
            return NotImplemented
        leaves, treedef = jax.tree_util.tree_flatten(self)
        other_leaves, other_treedef = jax.tree_util.tree_flatten(other)
        if treedef != other_treedef:
            return False
        # ``is`` short-circuits identical leaves (also valid for tracers);
        # otherwise compare values, which is False on a shape mismatch.
        return all(
            a is b or bool(jnp.array_equal(a, b))
            for a, b in zip(leaves, other_leaves)
        )