Source code for ott.neural.networks.velocity_field

# Copyright OTT-JAX
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Sequence

import jax
import jax.numpy as jnp

import optax
from flax import linen as nn
from flax.training import train_state

from ott.neural.networks.layers import time_encoder

__all__ = ["VelocityField"]

class VelocityField(nn.Module):
  r"""Neural vector field.

  This class learns a map :math:`v: \mathbb{R}\times \mathbb{R}^d
  \rightarrow \mathbb{R}^d` solving the ODE :math:`\frac{dx}{dt} = v(t, x)`.
  Given a source distribution at time :math:`t_0`, the velocity field can be
  used to transport the source distribution given at :math:`t_0` to
  a target distribution given at :math:`t_1` by integrating :math:`v(t, x)`
  from :math:`t=t_0` to :math:`t=t_1`.

  Args:
    hidden_dims: Dimensionality of the embedding of the data.
    output_dims: Dimensionality of the embedding of the output. The last
      entry is the width of the final (non-activated) layer.
    condition_dims: Dimensionality of the embedding of the condition.
      If :obj:`None`, the velocity field has no conditions.
    time_dims: Dimensionality of the time embedding.
      If :obj:`None`, ``hidden_dims`` is used.
    time_encoder: Time encoder for the velocity field.
    act_fn: Activation function.
  """
  hidden_dims: Sequence[int]
  output_dims: Sequence[int]
  condition_dims: Optional[Sequence[int]] = None
  time_dims: Optional[Sequence[int]] = None
  time_encoder: Callable[[jnp.ndarray],
                         jnp.ndarray] = time_encoder.cyclical_time_encoder
  act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu

  @nn.compact
  def __call__(
      self,
      t: jnp.ndarray,
      x: jnp.ndarray,
      condition: Optional[jnp.ndarray] = None,
  ) -> jnp.ndarray:
    """Forward pass through the neural vector field.

    Time, data and (optionally) the condition are each embedded by their own
    stack of dense layers, concatenated along the last axis and passed
    through the output layers.

    Args:
      t: Time of shape ``[batch, 1]``.
      x: Data of shape ``[batch, ...]``.
      condition: Conditioning vector of shape ``[batch, ...]``. Only used
        when ``condition_dims`` is not :obj:`None`; otherwise it is ignored.

    Returns:
      Output of the neural vector field of shape
      ``[batch, output_dims[-1]]``.

    Raises:
      AssertionError: If ``condition_dims`` is set but no ``condition``
        was passed.
    """
    time_dims = self.hidden_dims if self.time_dims is None else self.time_dims

    t = self.time_encoder(t)
    for time_dim in time_dims:
      t = self.act_fn(nn.Dense(time_dim)(t))

    for hidden_dim in self.hidden_dims:
      x = self.act_fn(nn.Dense(hidden_dim)(x))

    if self.condition_dims is not None:
      assert condition is not None, "No condition was passed."
      for cond_dim in self.condition_dims:
        condition = self.act_fn(nn.Dense(cond_dim)(condition))
      feats = jnp.concatenate([t, x, condition], axis=-1)
    else:
      feats = jnp.concatenate([t, x], axis=-1)

    for output_dim in self.output_dims[:-1]:
      feats = self.act_fn(nn.Dense(output_dim)(feats))

    # no activation function for the final layer
    return nn.Dense(self.output_dims[-1])(feats)

  def create_train_state(
      self,
      rng: jax.Array,
      optimizer: optax.GradientTransformation,
      input_dim: int,
      condition_dim: Optional[int] = None,
  ) -> train_state.TrainState:
    """Create the training state.

    The parameters are initialized by running the module once on dummy
    inputs of batch size 1.

    Args:
      rng: Random number generator.
      optimizer: Optimizer.
      input_dim: Dimensionality of the velocity field.
      condition_dim: Dimensionality of the condition of the velocity field.
        Required (and must be positive) when ``condition_dims`` is not
        :obj:`None`; ignored otherwise.

    Returns:
      The training state.

    Raises:
      AssertionError: If ``condition_dims`` is set and ``condition_dim``
        is :obj:`None` or not positive.
    """
    t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim))
    if self.condition_dims is None:
      cond = None
    else:
      # Check for ``None`` first: comparing ``None > 0`` would otherwise
      # surface as an unrelated ``TypeError``.
      assert condition_dim is not None, "No condition dimension was passed."
      assert condition_dim > 0, "Condition dimension must be positive."
      cond = jnp.ones((1, condition_dim))

    params = self.init(rng, t, x, cond)["params"]
    return train_state.TrainState.create(
        apply_fn=self.apply, params=params, tx=optimizer
    )