ott.neural.networks.velocity_field.unet.UNet

class ott.neural.networks.velocity_field.unet.UNet(*, shape, model_channels, num_res_blocks, attention_resolutions, out_channels=None, dropout=0.0, channel_mult=(1, 2, 4, 8), time_embed_dim=None, conv_resample=True, num_heads=1, num_heads_upsample=None, resblock_updown=False, num_classes=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, attn_implementation=None, rngs)

UNet model with attention and timestep embedding.

Parameters:
  • shape (Tuple[int, int, int]) – Input shape [height, width, channels].

  • model_channels (int) – Number of model channels.

  • num_res_blocks (int) – Number of residual blocks.

  • attention_resolutions (Tuple[int, ...]) – Resolutions at which to add self-attention.

  • out_channels (Optional[int]) – Number of output channels. If None, use input channels.

  • dropout (float) – Dropout rate.

  • channel_mult (Tuple[float, ...]) – Multiplier for model_channels for each resolution.

  • time_embed_dim (Union[int, float, None]) – Dimensionality of the time embedding. If None, use 4 * model_channels. If float, use int(time_embed_dim * model_channels).

  • conv_resample (bool) – Whether to use convolutions when resampling. If False, upsampling uses no convolution and downsampling uses average pooling instead of a convolution.

  • num_heads (int) – Number of attention heads for the input and middle blocks.

  • num_heads_upsample (Optional[int]) – Number of attention heads for the output blocks. If None, use num_heads.

  • resblock_updown (bool) – Whether to use residual blocks for up/downsampling.

  • num_classes (Optional[int]) – Number of classes.

  • dtype (Optional[dtype]) – Data type for computation.

  • param_dtype (dtype) – Data type for parameters.

  • attn_implementation (Optional[Literal['xla', 'cudnn']]) – Attention implementation for dot_product_attention().

  • rngs (Rngs) – Random number generator for initialization.

  • args (Any)

  • kwargs (Any)
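
The time_embed_dim rule described above can be written out as a small pure-Python sketch. The helper names resolve_time_embed_dim and channels_per_level are hypothetical and are not part of the OTT API. The first follows the documented rule for time_embed_dim. The second assumes that each entry of channel_mult scales model_channels as int(mult * model_channels), which the parameter description suggests but does not state explicitly.

```python
from typing import Optional, Tuple, Union


def resolve_time_embed_dim(
    model_channels: int,
    time_embed_dim: Optional[Union[int, float]] = None,
) -> int:
    """Hypothetical helper mirroring the documented `time_embed_dim` rule."""
    if time_embed_dim is None:
        # Documented default: 4 * model_channels.
        return 4 * model_channels
    if isinstance(time_embed_dim, float):
        # Documented float case: a multiple of model_channels.
        return int(time_embed_dim * model_channels)
    # An int is taken as the embedding dimensionality itself.
    return time_embed_dim


def channels_per_level(
    model_channels: int,
    channel_mult: Tuple[float, ...] = (1, 2, 4, 8),
) -> Tuple[int, ...]:
    """Hypothetical helper. Assumption: width per level is int(mult * model_channels)."""
    return tuple(int(mult * model_channels) for mult in channel_mult)


print(resolve_time_embed_dim(64))        # default rule
print(resolve_time_embed_dim(64, 2.5))   # float multiplier
print(channels_per_level(64))            # default channel_mult
```

With model_channels=64, leaving time_embed_dim as None gives 256, and passing 2.5 gives 160.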

Return type:

Any
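
A minimal construction sketch, using only the keyword arguments listed in the signature above. The concrete values are illustrative assumptions rather than recommended settings. In particular, the value chosen for attention_resolutions has not been checked against how the implementation interprets it. The rngs argument is assumed to be a flax.nnx.Rngs instance, and the helper name build_unet is hypothetical.

```python
# Keyword arguments taken from the documented constructor signature.
# The values are illustrative assumptions, not recommended settings.
UNET_KWARGS = dict(
    shape=(32, 32, 3),           # [height, width, channels]
    model_channels=64,
    num_res_blocks=2,
    attention_resolutions=(2,),  # assumed value; its interpretation is not verified here
    channel_mult=(1, 2, 4, 8),
)


def build_unet(seed: int = 0):
    """Hypothetical helper that constructs the UNet.

    Imports are local so the configuration above can be inspected
    without the OTT neural dependencies installed.
    """
    from flax import nnx  # assumption: `rngs` expects flax.nnx.Rngs
    from ott.neural.networks.velocity_field import unet

    return unet.UNet(**UNET_KWARGS, rngs=nnx.Rngs(seed))
```

Calling build_unet() requires OTT-JAX together with its neural-network dependencies. The returned module can then be switched between modes with the train() and eval() methods listed under Methods.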

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use iter_children() instead.

iter_modules()

Warning: this method is deprecated; use iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.