# Source code for ott.neural.networks.velocity_field.unet

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modified from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/unet.py."""
import abc
import functools
import math
from typing import Any, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from flax import nnx

__all__ = ["UNet"]


def timestep_embedding(
    timesteps: jax.Array, dim: int, max_period: int = 10000
) -> jax.Array:
  """Sinusoidal embedding of timesteps.

  With ``half = dim // 2``, the frequencies are
  ``max_period ** (-i / half)`` for ``i = 0, ..., half - 1``.

  Args:
    timesteps: Array of shape ``[batch,]``; cast to ``float32``.
    dim: Size of the embedding. If odd, a trailing zero column is appended.
    max_period: Base controlling the range of frequencies.

  Returns:
    ``float32`` array of shape ``[batch, dim]``: the cosines of all
    frequencies, followed by the sines, followed (for odd ``dim``) by zeros.
  """
  half = dim // 2
  log_max_period = math.log(max_period)
  exponents = jnp.arange(half, dtype=jnp.float32)
  freqs = jnp.exp(-log_max_period * exponents / half)

  angles = timesteps[:, None].astype(jnp.float32) * freqs[None]
  pieces = [jnp.cos(angles), jnp.sin(angles)]
  if dim % 2:
    # Pad odd sizes with one zero column so the output has exactly `dim`.
    pieces.append(jnp.zeros_like(angles[:, :1]))
  return jnp.concatenate(pieces, axis=-1)


class GroupNorm32(nnx.GroupNorm):
  """Group normalization that receives its input in ``float32``.

  The input is upcast to ``float32`` before being handed to
  :class:`~flax.nnx.GroupNorm`, and the result is cast back to the input's
  original dtype.
  """

  def __call__(
      self, x: jax.Array, *, mask: Optional[jax.Array] = None
  ) -> jax.Array:
    input_dtype = x.dtype
    upcast = x.astype(jnp.float32)
    normalized = super().__call__(upcast, mask=mask)
    return normalized.astype(input_dtype)


def conv_nd(
    dims: int,
    in_channels: Union[int, Tuple[int, ...]],
    out_channels: Union[int, Tuple[int, ...]],
    kernel_size: Union[int, Tuple[int, ...]],
    strides: Union[int, Tuple[int, ...]] = 1,
    *,
    dtype: Optional[jnp.dtype] = None,
    param_dtype: jnp.dtype = jnp.float32,
    padding: Union[int, Tuple[int, ...]] = 0,
    zero_init: bool = False,
    rngs: nnx.Rngs,
    **kwargs: Any,
) -> nnx.Conv:
  """Build an :class:`~flax.nnx.Conv` with ``dims`` spatial dimensions.

  Args:
    dims: Number of spatial dimensions. Scalar ``kernel_size``, ``strides``
      and ``padding`` values are repeated this many times.
    in_channels: Number of input features.
    out_channels: Number of output features.
    kernel_size: Kernel size, per dimension or as a single int.
    strides: Strides, per dimension or as a single int.
    dtype: Computation dtype.
    param_dtype: Parameter dtype.
    padding: Padding, per dimension or as a single int.
    zero_init: If :obj:`True`, kernel and bias are initialized to zero,
      overriding any ``kernel_init``/``bias_init`` given in ``kwargs``.
    rngs: Random number generators for initialization.
    kwargs: Extra keyword arguments forwarded to :class:`~flax.nnx.Conv`.

  Returns:
    The convolution module.
  """

  def per_dim(value):
    # Broadcast a scalar int to one entry per spatial dimension.
    return (value,) * dims if isinstance(value, int) else value

  if zero_init:
    zeros = nnx.initializers.constant(value=0.0)
    kwargs.update(kernel_init=zeros, bias_init=zeros)

  return nnx.Conv(
      in_features=in_channels,
      out_features=out_channels,
      kernel_size=per_dim(kernel_size),
      strides=per_dim(strides),
      padding=per_dim(padding),
      dtype=dtype,
      param_dtype=param_dtype,
      rngs=rngs,
      **kwargs,
  )


def normalization(
    channels: int,
    *,
    dtype: Optional[jnp.dtype] = None,
    param_dtype: jnp.dtype = jnp.float32,
    rngs: nnx.Rngs,
) -> nnx.GroupNorm:
  """Create the normalization layer shared by all UNet blocks.

  The layer is a :class:`GroupNorm32` with 32 groups, ``epsilon=1e-5`` and
  ``use_fast_variance=False``.

  Args:
    channels: Number of feature channels to normalize.
    dtype: Computation dtype.
    param_dtype: Parameter dtype.
    rngs: Random number generators for initialization.

  Returns:
    The normalization module.
  """
  return GroupNorm32(
      num_features=channels,
      num_groups=32,
      epsilon=1e-5,
      use_fast_variance=False,
      dtype=dtype,
      param_dtype=param_dtype,
      rngs=rngs,
  )


class TimestepBlock(nnx.Module):
  """Marker interface for layers that also consume the time embedding.

  :class:`TimestepEmbedSequential` checks for this type: instances are called
  as ``layer(x, emb, rngs=rngs)``, while any other layer receives only ``x``.
  """

  @abc.abstractmethod
  def __call__(
      self,
      x: jax.Array,
      emb: jax.Array,
      *,
      rngs: Optional[nnx.Rngs] = None,
  ) -> jax.Array:
    """Apply the layer to ``x`` conditioned on the embedding ``emb``."""


class TimestepEmbedSequential(nnx.Module):
  """Sequential container that forwards the time embedding where needed.

  Layers are applied in order. :class:`TimestepBlock` layers are called with
  ``(x, emb, rngs=rngs)``; every other layer is called with ``x`` alone.

  Args:
    layers: Modules or plain callables (e.g. activations) to chain.
  """

  def __init__(self, *layers: nnx.Module):
    super().__init__()
    # Use `nnx.List` when this Flax version provides it, else a plain list.
    container = getattr(nnx, "List", None)
    self.layers = list(layers) if container is None else container(layers)

  def __call__(
      self,
      x: jax.Array,
      emb: Optional[jax.Array] = None,
      *,
      rngs: Optional[nnx.Rngs] = None,
  ) -> jax.Array:
    out = x
    for layer in self.layers:
      needs_emb = isinstance(layer, TimestepBlock)
      out = layer(out, emb, rngs=rngs) if needs_emb else layer(out)
    return out


class Upsample(nnx.Module):
  """Double the spatial size with nearest-neighbour resizing.

  Args:
    channels: Number of input channels.
    use_conv: Whether to apply a 3x3 convolution after resizing.
    out_channels: Output channels of the convolution. If :obj:`None`, use
      ``channels``. Only the convolution uses it; without it the channel
      count is unchanged.
    dtype: Computation dtype of the convolution.
    param_dtype: Parameter dtype of the convolution.
    rngs: Random number generators for initialization.
  """

  def __init__(
      self,
      channels: int,
      use_conv: bool,
      *,
      out_channels: Optional[int] = None,
      dtype: Optional[jnp.dtype] = None,
      param_dtype: jnp.dtype = jnp.float32,
      rngs: nnx.Rngs,
  ):
    super().__init__()
    self.channels = channels
    self.out_channels = out_channels or channels
    self.conv = conv_nd(
        2,
        channels,
        self.out_channels,
        3,
        padding=1,
        dtype=dtype,
        param_dtype=param_dtype,
        rngs=rngs,
    ) if use_conv else None

  def __call__(self, x: jax.Array) -> jax.Array:
    """Upsample ``x`` of shape ``[batch, height, width, channels]`` by 2."""
    assert x.shape[-1] == self.channels
    batch, height, width, chans = x.shape
    target = (batch, height * 2, width * 2, chans)
    resized = jax.image.resize(x, target, method="nearest")
    return resized if self.conv is None else self.conv(resized)


class Downsample(nnx.Module):
  """Halve the spatial size of a channels-last feature map.

  Args:
    channels: Number of input channels.
    use_conv: If :obj:`True`, use a stride-2 3x3 convolution; otherwise use
      2x2 average pooling with stride 2.
    out_channels: Output channels of the convolution. If :obj:`None`, use
      ``channels``. Pooling leaves the channel count unchanged.
    dtype: Computation dtype of the convolution.
    param_dtype: Parameter dtype of the convolution.
    rngs: Random number generators for initialization.
  """

  def __init__(
      self,
      channels: int,
      use_conv: bool,
      *,
      out_channels: Optional[int] = None,
      dtype: Optional[jnp.dtype] = None,
      param_dtype: jnp.dtype = jnp.float32,
      rngs: nnx.Rngs,
  ):
    super().__init__()
    self.channels = channels
    self.out_channels = out_channels or channels
    if not use_conv:
      # Parameter-free pooling.
      self.op = functools.partial(
          nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)
      )
    else:
      self.op = conv_nd(
          2,
          channels,
          self.out_channels,
          3,
          strides=2,
          padding=1,
          dtype=dtype,
          param_dtype=param_dtype,
          rngs=rngs,
      )

  def __call__(self, x: jax.Array) -> jax.Array:
    num_channels = x.shape[-1]
    assert num_channels == self.channels
    return self.op(x)


class ResBlock(TimestepBlock):
  """Residual block conditioned on a time embedding.

  Computes ``skip(x) + out(in(x) + proj(emb))``, where:

  - ``in`` is norm -> SiLU -> 3x3 conv;
  - ``proj`` is SiLU -> linear on the embedding;
  - ``out`` is norm -> SiLU -> dropout -> zero-initialized 3x3 conv.

  With ``up``/``down``, both the residual branch (between its activation and
  its conv) and the skip path are resampled by a factor of 2.

  Args:
    channels: Number of input channels.
    emb_channels: Size of the time embedding.
    dropout: Dropout rate applied before the output convolution.
    out_channels: Number of output channels. If :obj:`None`, use ``channels``.
    use_conv: If the channel count changes, project the skip connection with
      a 3x3 convolution instead of a 1x1 one.
    up: Upsample spatially by 2 (nearest neighbour).
    down: Downsample spatially by 2 (average pooling). Ignored if ``up``.
    dtype: Computation dtype.
    param_dtype: Parameter dtype.
    rngs: Random number generators for initialization.
  """

  def __init__(
      self,
      channels: int,
      emb_channels: int,
      dropout: float,
      *,
      out_channels: Optional[int] = None,
      use_conv: bool = False,
      up: bool = False,
      down: bool = False,
      dtype: Optional[jnp.dtype] = None,
      param_dtype: jnp.dtype = jnp.float32,
      rngs: nnx.Rngs,
  ):
    super().__init__()
    self.channels = channels
    self.emb_channels = emb_channels
    self.dropout = dropout
    self.out_channels = out_channels or channels
    self.use_conv = use_conv
    self.updown = up or down

    # Shared keyword arguments. Submodules are built in the same order as
    # before, so any draws from `rngs` happen in an unchanged sequence.
    common = dict(dtype=dtype, param_dtype=param_dtype, rngs=rngs)
    width = self.out_channels

    self.in_norm = normalization(channels, **common)
    self.in_act = nnx.silu
    self.in_conv = conv_nd(2, channels, width, 3, padding=1, **common)

    if up or down:
      resampler = Upsample if up else Downsample
      self.h_upd = resampler(channels, use_conv=False, **common)
      self.x_upd = resampler(channels, use_conv=False, **common)
    else:
      self.h_upd = lambda x: x
      self.x_upd = lambda x: x

    self.emb_act = nnx.silu
    self.emb_layers = nnx.Linear(emb_channels, width, **common)

    self.out_norm = normalization(width, **common)
    self.out_act = nnx.silu
    self.out_dropout = nnx.Dropout(rate=dropout)
    self.out_conv = conv_nd(
        2, width, width, 3, padding=1, zero_init=True, **common
    )

    if width == channels:
      self.skip_connection = lambda x: x
    else:
      # Project the skip path to `width` channels: 3x3 conv or 1x1 conv.
      kernel, pad = (3, 1) if use_conv else (1, 0)
      self.skip_connection = conv_nd(
          2, channels, width, kernel, padding=pad, **common
      )

  def __call__(
      self,
      x: jax.Array,
      emb: jax.Array,
      *,
      rngs: Optional[nnx.Rngs] = None,
  ) -> jax.Array:
    """Apply the block.

    Args:
      x: Features of shape ``[batch, height, width, channels]``.
      emb: Time embedding of shape ``[batch, emb_channels]``.
      rngs: Random number generators for dropout.

    Returns:
      Features with ``out_channels`` channels, spatially resampled if the
      block was built with ``up`` or ``down``.
    """
    hidden = self.in_act(self.in_norm(x))
    if self.updown:
      # Resample the residual branch before its conv and the skip path
      # identically, so both keep matching spatial sizes.
      hidden = self.h_upd(hidden)
      x = self.x_upd(x)
    hidden = self.in_conv(hidden)

    projected = self.emb_layers(self.emb_act(emb).astype(hidden.dtype))
    # Broadcast the per-sample embedding over both spatial axes.
    hidden = hidden + projected[:, None, None, :]

    hidden = self.out_act(self.out_norm(hidden))
    hidden = self.out_dropout(hidden, rngs=rngs)
    hidden = self.out_conv(hidden)
    return self.skip_connection(x) + hidden


class QKVAttention(nnx.Module):
  """Multi-head self-attention on a fused query/key/value tensor.

  The fused last axis is first split into three equal parts (queries, keys,
  values), and each part is then reshaped into ``n_heads`` contiguous heads
  of size ``head_dim``.

  Args:
    n_heads: Number of attention heads.
    attn_implementation: Backend passed to
      :func:`jax.nn.dot_product_attention`. With ``'cudnn'``, the inputs are
      cast to ``bfloat16`` for the attention call.
  """

  def __init__(
      self,
      n_heads: int,
      attn_implementation: Optional[Literal["xla", "cudnn"]] = None,
  ):
    super().__init__()
    self.n_heads = n_heads
    self.attn_implementation = attn_implementation

  def __call__(self, qkv: jax.Array) -> jax.Array:
    """Attend over the ``length`` axis.

    Args:
      qkv: Array of shape ``[batch, length, 3 * n_heads * head_dim]``.

    Returns:
      Array of shape ``[batch, length, n_heads * head_dim]`` with the same
      dtype as ``qkv``.
    """
    bs, length, width = qkv.shape
    head_dim, rest = divmod(width, 3 * self.n_heads)
    assert rest == 0, rest
    # `jax.nn.dot_product_attention` multiplies the q.k logits by `scale`
    # exactly once, so pass the standard 1/sqrt(head_dim) temperature.
    # A factor of 1/sqrt(sqrt(head_dim)) is only correct when it is applied
    # to q and k separately (their product then yields 1/sqrt(head_dim)).
    scale = 1.0 / math.sqrt(head_dim)

    # Split into q/k/v first, then into heads.
    q, k, v = jnp.split(qkv, 3, axis=-1)
    q = q.reshape(bs, length, self.n_heads, head_dim)
    k = k.reshape(bs, length, self.n_heads, head_dim)
    v = v.reshape(bs, length, self.n_heads, head_dim)

    # NOTE(review): presumably cast because the cuDNN kernel only accepts
    # half-precision inputs.
    dtype = jnp.bfloat16 if self.attn_implementation == "cudnn" else q.dtype
    a = jax.nn.dot_product_attention(
        q.astype(dtype),
        k.astype(dtype),
        v.astype(dtype),
        scale=scale,
        implementation=self.attn_implementation,
    ).astype(q.dtype)
    return a.reshape(bs, length, self.n_heads * head_dim)


class AttentionBlock(nnx.Module):
  """Residual multi-head self-attention over all spatial positions.

  The spatial axes are flattened into one sequence axis. The block computes
  normalization -> fused QKV projection -> attention -> zero-initialized
  output projection, and adds the result back to the input.

  Args:
    channels: Number of feature channels.
    num_heads: Number of attention heads.
    attn_implementation: Backend for :func:`jax.nn.dot_product_attention`.
    dtype: Computation dtype.
    param_dtype: Parameter dtype.
    rngs: Random number generators for initialization.
  """

  def __init__(
      self,
      channels: int,
      *,
      num_heads: int = 1,
      attn_implementation: Optional[Literal["xla", "cudnn"]] = None,
      dtype: Optional[jnp.dtype] = None,
      param_dtype: jnp.dtype = jnp.float32,
      rngs: nnx.Rngs,
  ):
    super().__init__()
    self.channels = channels
    self.num_heads = num_heads
    layer_kwargs = dict(dtype=dtype, param_dtype=param_dtype, rngs=rngs)

    self.norm = normalization(channels, **layer_kwargs)
    # A single pointwise 1D conv produces queries, keys and values together.
    self.qkv = conv_nd(1, channels, 3 * channels, 1, **layer_kwargs)
    # split qkv before split heads
    self.attention = QKVAttention(
        num_heads, attn_implementation=attn_implementation
    )
    self.proj_out = conv_nd(
        1, channels, channels, 1, zero_init=True, **layer_kwargs
    )

  def __call__(self, x: jax.Array) -> jax.Array:
    """Apply attention to ``x`` of shape ``[batch, *spatial, channels]``."""
    batch, *spatial, chans = x.shape
    tokens = x.reshape(batch, -1, chans)  # [batch, prod(spatial), channels]
    attended = self.attention(self.qkv(self.norm(tokens)))
    out = tokens + self.proj_out(attended)
    return out.reshape(batch, *spatial, chans)


class UNet(nnx.Module):
  """UNet model with attention and timestep embedding.

  Args:
    shape: Input shape ``[height, width, channels]``.
    model_channels: Number of model channels.
    num_res_blocks: Number of residual blocks.
    attention_resolutions: Resolutions at which to add self-attention. Each
      entry ``res`` is converted to the downsampling factor ``height // res``.
    out_channels: Number of output channels. If :obj:`None`, use input
      channels.
    dropout: Dropout rate.
    channel_mult: Multiplier for ``model_channels`` for each resolution.
    time_embed_dim: Dimensionality of the time embedding.
      If :obj:`None`, use ``4 * model_channels``.
      If :class:`float`, use ``int(time_embed_dim * model_channels)``.
    conv_resample: If :obj:`False`, don't use convolution for upsampling and
      use average pooling for downsampling instead of using convolution.
    num_heads: Number of attention heads for the input and middle blocks.
    num_heads_upsample: Number of attention heads for the output blocks.
      If :obj:`None`, use ``num_heads``.
    resblock_updown: Whether to use residual blocks for up/downsampling.
    num_classes: Number of classes. If not :obj:`None`, a learned class
      embedding is added to the time embedding.
    dtype: Data type for computation.
    param_dtype: Data type for parameters.
    attn_implementation: Attention implementation for
      :func:`~jax.nn.dot_product_attention`.
    rngs: Random number generator for initialization.
  """

  def __init__(
      self,
      *,
      shape: Tuple[int, int, int],
      model_channels: int,
      num_res_blocks: int,
      attention_resolutions: Tuple[int, ...],
      out_channels: Optional[int] = None,
      dropout: float = 0.0,
      channel_mult: Tuple[float, ...] = (1, 2, 4, 8),
      time_embed_dim: Optional[Union[int, float]] = None,
      conv_resample: bool = True,
      num_heads: int = 1,
      num_heads_upsample: Optional[int] = None,
      resblock_updown: bool = False,
      num_classes: Optional[int] = None,
      dtype: Optional[jnp.dtype] = None,
      param_dtype: jnp.dtype = jnp.float32,
      attn_implementation: Optional[Literal["xla", "cudnn"]] = None,
      rngs: nnx.Rngs,
  ):
    super().__init__()
    # Only the height is used to map `attention_resolutions` to the
    # downsampling factors `ds` tracked below.
    image_size, _, in_channels = shape
    out_channels = out_channels or in_channels
    attention_resolutions = tuple(
        image_size // res for res in attention_resolutions
    )
    num_heads_upsample = num_heads_upsample or num_heads

    self.dtype = dtype
    self.in_channels = in_channels
    self.model_channels = model_channels
    self.num_res_blocks = num_res_blocks
    self.attention_resolutions = attention_resolutions
    self.dropout = dropout
    self.channel_mult = channel_mult
    self.conv_resample = conv_resample
    self.num_heads = num_heads
    self.num_heads_upsample = num_heads_upsample

    # NOTE(review): layers presumably draw keys from `rngs` as they are
    # built, so the construction order below likely determines the initial
    # parameters; keep it stable.

    # Time embedding
    if time_embed_dim is None:
      time_embed_dim = model_channels * 4
    elif isinstance(time_embed_dim, float):
      time_embed_dim = int(time_embed_dim * model_channels)
    assert isinstance(time_embed_dim, int), time_embed_dim
    self.time_embed = TimestepEmbedSequential(
        nnx.Linear(
            model_channels,
            time_embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        ),
        nnx.silu,
        nnx.Linear(
            time_embed_dim,
            time_embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        ),
    )

    # Condition embedding; honours `param_dtype` like every other layer.
    if num_classes is not None:
      self.label_emb = nnx.Embed(
          num_embeddings=num_classes,
          features=time_embed_dim,
          param_dtype=param_dtype,
          rngs=rngs,
      )
    else:
      self.label_emb = None

    # Input blocks
    self.input_blocks = nnx.List() if hasattr(nnx, "List") else []
    ch = int(channel_mult[0] * model_channels)
    # Output width of every input block, popped (last in, first out) by the
    # output blocks to size their skip-connection inputs. The stem conv below
    # emits `ch` channels, which differs from `model_channels` whenever
    # `channel_mult[0] != 1`.
    input_block_chans = [ch]
    ds = 1

    # First input block (just convolution)
    self.input_blocks.append(
        TimestepEmbedSequential(
            conv_nd(
                2,
                in_channels,
                ch,
                3,
                padding=1,
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
            )
        )
    )

    # Rest of input blocks
    for level, mult in enumerate(channel_mult):
      for _ in range(num_res_blocks):
        layers = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                out_channels=int(mult * model_channels),
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
            )
        ]
        ch = int(mult * model_channels)
        if ds in attention_resolutions:
          layers.append(
              AttentionBlock(
                  ch,
                  num_heads=num_heads,
                  attn_implementation=attn_implementation,
                  dtype=dtype,
                  param_dtype=param_dtype,
                  rngs=rngs,
              )
          )
        self.input_blocks.append(TimestepEmbedSequential(*layers))
        input_block_chans.append(ch)
      # Downsample if not last level
      if level != len(channel_mult) - 1:
        out_ch = ch
        self.input_blocks.append(
            TimestepEmbedSequential(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    out_channels=out_ch,
                    down=True,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    rngs=rngs,
                ) if resblock_updown else Downsample(
                    ch,
                    conv_resample,
                    out_channels=out_ch,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    rngs=rngs,
                )
            )
        )
        ch = out_ch
        input_block_chans.append(ch)
        ds *= 2

    self.middle_block = TimestepEmbedSequential(
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        ),
        AttentionBlock(
            ch,
            num_heads=num_heads,
            attn_implementation=attn_implementation,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        ),
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        ),
    )

    # Output blocks
    self.output_blocks = nnx.List() if hasattr(nnx, "List") else []
    for level, mult in list(enumerate(channel_mult))[::-1]:
      for i in range(num_res_blocks + 1):
        # Width of the skip activation concatenated onto `h` in `__call__`.
        ich = input_block_chans.pop()
        layers = [
            ResBlock(
                ch + ich,
                time_embed_dim,
                dropout,
                out_channels=int(model_channels * mult),
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
            )
        ]
        ch = int(model_channels * mult)
        if ds in attention_resolutions:
          layers.append(
              AttentionBlock(
                  ch,
                  num_heads=num_heads_upsample,
                  attn_implementation=attn_implementation,
                  dtype=dtype,
                  param_dtype=param_dtype,
                  rngs=rngs,
              )
          )
        # Upsample at the end of every level except the highest-resolution one.
        if level and i == num_res_blocks:
          out_ch = ch
          layers.append(
              ResBlock(
                  ch,
                  time_embed_dim,
                  dropout,
                  out_channels=out_ch,
                  up=True,
                  dtype=dtype,
                  param_dtype=param_dtype,
                  rngs=rngs,
              ) if resblock_updown else Upsample(
                  ch,
                  conv_resample,
                  out_channels=out_ch,
                  dtype=dtype,
                  param_dtype=param_dtype,
                  rngs=rngs,
              )
          )
          ds //= 2
        self.output_blocks.append(TimestepEmbedSequential(*layers))

    self.out = TimestepEmbedSequential(
        normalization(ch, dtype=dtype, param_dtype=param_dtype, rngs=rngs),
        nnx.silu,
        conv_nd(
            2,
            ch,
            out_channels,
            3,
            padding=1,
            zero_init=True,
            # dtype=dtype,  # don't cast computations to (bf)float16
            param_dtype=param_dtype,
            rngs=rngs,
        ),
    )

  def __call__(
      self,
      t: jax.Array,
      x: jax.Array,
      cond: Optional[jax.Array] = None,
      *,
      rngs: Optional[nnx.Rngs] = None,
  ) -> jax.Array:
    """Compute the velocity.

    Args:
      t: Time array of shape ``[batch,]``.
      x: Image of shape ``[batch, height, width, channels]``.
      cond: Class condition array of shape ``[batch,]``. Required when the
        model was built with ``num_classes``.
      rngs: Random number generator for dropout.

    Returns:
      The velocity array of shape ``[batch, height, width, out_channels]``.
    """
    emb = self.time_embed(timestep_embedding(t, self.model_channels), rngs=rngs)
    # TODO(michalk8): generalize for different types of conditions
    if self.label_emb is not None:
      assert cond is not None, "Please provide a condition."
      # emb is cast to `self.dtype` inside each submodule
      emb = emb + self.label_emb(cond)

    h = x.astype(self.dtype)
    # Activations of every input block, consumed in reverse as skips.
    hs = []
    for module in self.input_blocks:
      h = module(h, emb, rngs=rngs)
      hs.append(h)
    h = self.middle_block(h, emb, rngs=rngs)
    for module in self.output_blocks:
      h = jnp.concatenate([h, hs.pop()], axis=-1)
      h = module(h, emb, rngs=rngs)
    h = h.astype(x.dtype)  # output's compute dtype
    return self.out(h, rngs=rngs)