ott.neural.networks.velocity_field.unet.UNet
- class ott.neural.networks.velocity_field.unet.UNet(*, shape, model_channels, num_res_blocks, attention_resolutions, out_channels=None, dropout=0.0, channel_mult=(1, 2, 4, 8), time_embed_dim=None, conv_resample=True, num_heads=1, num_heads_upsample=None, resblock_updown=False, num_classes=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, attn_implementation=None, rngs)
UNet model with attention and timestep embedding.
- Parameters:
shape (Tuple[int, int, int]) – Input shape [height, width, channels].
model_channels (int) – Number of model channels.
num_res_blocks (int) – Number of residual blocks.
attention_resolutions (Tuple[int, ...]) – Resolutions at which to add self-attention.
out_channels (Optional[int]) – Number of output channels. If None, use input channels.
dropout (float) – Dropout rate.
channel_mult (Tuple[float, ...]) – Multiplier for model_channels for each resolution.
time_embed_dim (Union[int, float, None]) – Dimensionality of the time embedding. If None, use 4 * model_channels. If float, use int(time_embed_dim * model_channels).
conv_resample (bool) – If False, don’t use convolution for upsampling and use average pooling for downsampling instead of using convolution.
num_heads (int) – Number of attention heads for the input and middle blocks.
num_heads_upsample (Optional[int]) – Number of attention heads for the output blocks. If None, use num_heads.
resblock_updown (bool) – Whether to use residual blocks for up/downsampling.
param_dtype (dtype) – Data type for parameters.
attn_implementation (Optional[Literal['xla', 'cudnn']]) – Attention implementation for dot_product_attention().
rngs (Rngs) – Random number generator for initialization.
args (Any)
kwargs (Any)
- Return type:
Any
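The defaulting rules for time_embed_dim and the per-resolution scaling implied by channel_mult can be illustrated with a small sketch. The helper names resolve_time_embed_dim and channels_per_resolution are hypothetical and are not part of ott. The first follows the rule documented above; the second assumes each resolution's width is int(mult * model_channels), which may differ from how the library rounds internally.

```python
from typing import Tuple, Union


def resolve_time_embed_dim(
    model_channels: int,
    time_embed_dim: Union[int, float, None] = None,
) -> int:
    """Hypothetical helper mirroring the documented ``time_embed_dim`` rule."""
    if time_embed_dim is None:
        # Documented default: four times the base channel count.
        return 4 * model_channels
    if isinstance(time_embed_dim, float):
        # Documented float rule: treat the value as a multiplier.
        return int(time_embed_dim * model_channels)
    # An integer is taken as the embedding size directly.
    return time_embed_dim


def channels_per_resolution(
    model_channels: int,
    channel_mult: Tuple[float, ...] = (1, 2, 4, 8),
) -> Tuple[int, ...]:
    """Hypothetical helper; assumes width = int(mult * model_channels)."""
    return tuple(int(mult * model_channels) for mult in channel_mult)


default_dim = resolve_time_embed_dim(64)        # None -> 4 * 64
scaled_dim = resolve_time_embed_dim(64, 2.5)    # float -> int(2.5 * 64)
fixed_dim = resolve_time_embed_dim(64, 128)     # int -> used as given
widths = channels_per_resolution(64)            # default channel_mult
```

With model_channels=64, the three calls yield 256, 160, and 128, and the default channel_mult gives widths of 64, 128, 256, and 512 under the stated assumption.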
Methods
eval(**attributes) – Sets the Module to evaluation mode.
Warning: this method is deprecated; use iter_children() instead.
Warning: this method is deprecated; use iter_modules() instead.
perturb(name, value[, variable_type]) – Extract gradients of intermediate values during training.
sow(variable_type, name, value[, reduce_fn, ...]) – Store intermediate values during module execution for later extraction.
train(**attributes) – Sets the Module to training mode.