Source code for ott.neural.networks.velocity_field.mlp

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Optional, Sequence

import jax
from jax import numpy as jnp

from flax import nnx

__all__ = ["MLP"]


[docs] class MLP(nnx.Module): """MLP velocity field. Args: dim: Dimensionality of the velocity field. hidden_dims: Hidden dimensions. cond_dim: Dimensionality of the condition vector. act_fn: Activation function. dropout_rate: Dropout rate. rngs: Random number generator used for initialization. kwargs: Keyword arguments for :class:`~flax.nnx.Linear`. """ def __init__( self, dim: int, *, hidden_dims: Sequence[int] = (), cond_dim: int = 0, act_fn: Callable[[jax.Array], jax.Array] = nnx.silu, time_enc_num_freqs: Optional[int] = None, dropout_rate: float = 0.0, rngs: nnx.Rngs, **kwargs: Any, ): super().__init__() if time_enc_num_freqs is None: time_enc_num_freqs = max(1, min(dim // 16, 64)) time_dim = 2 * time_enc_num_freqs hidden_dims = [dim + time_dim + cond_dim] + list(hidden_dims) block_fn = functools.partial( Block, act_fn=act_fn, dropout_rate=dropout_rate, **kwargs ) blocks = [ *( block_fn(in_dim, out_dim, rngs=rngs) for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]) ) ] final_block = block_fn( hidden_dims[-1], dim, act_fn=jax.nn.identity, rngs=rngs, ) self.blocks = nnx.Sequential(*blocks, final_block) self.time_enc_num_freqs = time_enc_num_freqs def __call__( self, t: jax.Array, x: jax.Array, cond: Optional[jax.Array] = None, *, rngs: Optional[nnx.Rngs] = None, ) -> jax.Array: """Compute the velocity. Args: t: Time array of shape ``[batch,]``. x: Input array of shape ``[batch, dim]``. cond: Condition array of shape ``[batch, cond_dim]``. rngs: Random number generator for dropout. Returns: The velocity array of shape ``[batch, dim]``. """ t_emb = _encode_time(t, self.time_enc_num_freqs) h = [t_emb, x] if cond is None else [t_emb, x, cond] h = jnp.concatenate(h, axis=-1) return self.blocks(h, rngs=rngs)
class Block(nnx.Module): def __init__( self, in_dim: int, out_dim: int, *, act_fn: Callable[[jax.Array], jax.Array], dropout_rate: float = 0.0, rngs: nnx.Rngs, **kwargs: Any ): super().__init__() self.lin = nnx.Linear(in_dim, out_dim, rngs=rngs, **kwargs) self.act_fn = act_fn self.dropout = nnx.Dropout(dropout_rate) def __call__( self, x: jax.Array, *, rngs: Optional[nnx.Rngs] = None ) -> jax.Array: return self.dropout(self.act_fn(self.lin(x)), rngs=rngs) def _encode_time(t: jax.Array, num_freqs: int) -> jax.Array: assert num_freqs > 0, "Number of frequencies must be positive." assert t.ndim == 1, t.shape freq = (2 * jnp.pi) * jnp.arange(1, num_freqs + 1) t = freq * t[:, None] return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1)