# Source code for ott.neural.methods.flows.dynamics

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc

import jax
import jax.numpy as jnp

# Public API of this module: the names exported by a star-import. All four
# entries are flow classes defined below.
__all__ = [
    "BaseFlow",
    "StraightFlow",
    "ConstantNoiseFlow",
    "BrownianBridge",
]


class BaseFlow(abc.ABC):
    """Abstract interface shared by all flows.

    A flow describes a probability path between a source sample :math:`x_0`
    and a target sample :math:`x_1` through three quantities that subclasses
    must provide: the mean of the path, its standard deviation, and the
    conditional vector field.

    Args:
      sigma: Noise used for computing time-dependent noise schedule.
    """

    def __init__(self, sigma: float):
        self.sigma = sigma

    @abc.abstractmethod
    def compute_mu_t(
        self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute the mean of the probability path.

        The mean is taken for the path between :math:`x_0` and :math:`x_1`
        and evaluated at time :math:`t`.

        Args:
          t: Time :math:`t` of shape ``[batch, 1]``.
          src: Sample from the source distribution of shape ``[batch, ...]``.
          tgt: Sample from the target distribution of shape ``[batch, ...]``.

        Returns:
          Mean of the probability path at time :math:`t`.
        """

    @abc.abstractmethod
    def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray:
        """Compute the standard deviation of the probability path.

        Args:
          t: Time :math:`t` of shape ``[batch, 1]``.

        Returns:
          Standard deviation of the probability path at time :math:`t`.
        """

    @abc.abstractmethod
    def compute_ut(
        self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
    ) -> jnp.ndarray:
        """Evaluate the conditional vector field.

        The field is the one defined between :math:`x_0` and :math:`x_1`,
        evaluated at time :math:`t`.

        Args:
          t: Time :math:`t` of shape ``[batch, 1]``.
          src: Sample from the source distribution of shape ``[batch, ...]``.
          tgt: Sample from the target distribution of shape ``[batch, ...]``.

        Returns:
          Conditional vector field evaluated at time :math:`t`.
        """

    def compute_xt(
        self,
        rng: jax.Array,
        t: jnp.ndarray,
        src: jnp.ndarray,
        tgt: jnp.ndarray,
    ) -> jnp.ndarray:
        """Draw a sample from the probability path.

        The sample is the path mean perturbed by standard Gaussian noise
        scaled with the path standard deviation, both evaluated at time
        :math:`t`.

        Args:
          rng: Random number generator.
          t: Time :math:`t` of shape ``[batch, 1]``.
          src: Sample from the source distribution of shape ``[batch, ...]``.
          tgt: Sample from the target distribution of shape ``[batch, ...]``.

        Returns:
          Samples from the probability path between :math:`x_0` and
          :math:`x_1` at time :math:`t`.
        """
        mean = self.compute_mu_t(t, src, tgt)
        std = self.compute_sigma_t(t)
        # One standard-normal draw per entry of ``src``.
        eps = jax.random.normal(rng, shape=src.shape)
        return mean + std * eps
class StraightFlow(BaseFlow, abc.ABC):
    """Base class for flows whose paths are straight lines.

    The mean of the path interpolates linearly between the source and the
    target sample, and the conditional vector field is their difference.
    The noise schedule is left to subclasses.

    Args:
      sigma: Noise used for computing time-dependent noise schedule.
    """

    def compute_mu_t(
        self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
    ) -> jnp.ndarray:
        """Linearly interpolate between ``src`` and ``tgt`` at time ``t``.

        Args:
          t: Time :math:`t` of shape ``[batch, 1]``.
          src: Sample from the source distribution of shape ``[batch, ...]``.
          tgt: Sample from the target distribution of shape ``[batch, ...]``.

        Returns:
          The point ``(1 - t) * src + t * tgt``.
        """
        src_weight = 1.0 - t
        return src_weight * src + t * tgt

    def compute_ut(
        self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
    ) -> jnp.ndarray:
        """Return the displacement ``tgt - src``.

        Args:
          t: Time :math:`t`; unused by this implementation.
          src: Sample from the source distribution of shape ``[batch, ...]``.
          tgt: Sample from the target distribution of shape ``[batch, ...]``.

        Returns:
          The difference ``tgt - src``.
        """
        del t  # accepted for interface compatibility only
        displacement = tgt - src
        return displacement
class ConstantNoiseFlow(StraightFlow):
    r"""Straight-path flow whose noise level is the constant :math:`\sigma`.

    Args:
      sigma: Constant noise used for computing time-independent noise
        schedule.
    """

    def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Compute noise of the flow at time :math:`t`.

        Args:
          t: Time :math:`t` of shape ``[batch, 1]``.

        Returns:
          Array shaped like ``t`` filled with the constant, time-independent
          standard deviation :math:`\sigma`.
        """
        noise_level = self.sigma
        return jnp.full_like(t, noise_level)
class BrownianBridge(StraightFlow):
    r"""Brownian Bridge.

    Sampler for sampling noise implicitly defined by a Schrödinger Bridge
    problem with parameter :math:`\sigma` such that
    :math:`\sigma_t = \sigma \cdot \sqrt{t \cdot (1 - t)}` :cite:`tong:23`.

    Args:
      sigma: Noise used for computing time-dependent noise schedule.
    """

    def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Compute noise of the flow at time :math:`t`.

        Args:
          t: Time :math:`t` of shape ``[batch, 1]``.

        Returns:
          Standard deviation of the probability path at time :math:`t`,
          i.e. :math:`\sigma \cdot \sqrt{t \cdot (1 - t)}`.
        """
        # The schedule is zero at both t = 0 and t = 1 and largest at t = 0.5.
        return self.sigma * jnp.sqrt(t * (1.0 - t))