ott.neural.networks.velocity_field.mlp.MLP

class ott.neural.networks.velocity_field.mlp.MLP(dim, *, hidden_dims=(), cond_dim=0, act_fn=jax.nn.silu, time_enc_num_freqs=None, dropout_rate=0.0, rngs, **kwargs)

Velocity field parameterized by a multi-layer perceptron (MLP).
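A velocity field v(t, x) defines an ODE dx/dt = v(t, x) whose solution transports samples over time (for example from t = 0 to t = 1). The minimal sketch below illustrates that role with explicit Euler steps. It assumes a generic callable `velocity_fn(t, x)` and a toy field v(t, x) = -x; both are hypothetical stand-ins, and this is not the call signature of this class or OTT's solver.

```python
import jax.numpy as jnp


def euler_integrate(velocity_fn, x0, num_steps=100):
    """Integrate dx/dt = velocity_fn(t, x) from t=0 to t=1 with explicit Euler."""
    dt = 1.0 / num_steps
    x = x0
    for i in range(num_steps):
        t = i * dt
        # One Euler step: move along the velocity for a time increment dt.
        x = x + dt * velocity_fn(t, x)
    return x


# Toy velocity field v(t, x) = -x; its exact solution is x(1) = exp(-1) * x0.
x0 = jnp.array([1.0, 2.0])
x1 = euler_integrate(lambda t, x: -x, x0, num_steps=1000)
```

With 1000 steps the result is close to the exact value exp(-1) * x0; a trained MLP velocity field would take the place of the lambda.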

Parameters:
  • dim (int) – Dimensionality of the velocity field.

  • hidden_dims (Sequence[int]) – Hidden dimensions.

  • cond_dim (int) – Dimensionality of the condition vector.

  • act_fn (Callable[[Array], Array]) – Activation function.

  • dropout_rate (float) – Dropout rate.

  • rngs (Rngs) – Random number generator used for initialization.

  • kwargs (Any) – Keyword arguments passed to the Linear layers.

  • time_enc_num_freqs (Optional[int]) – Number of frequencies used in the time encoding.

  • args (Any)
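To show how the parameters above could fit together, here is a pure-JAX sketch of an MLP velocity field. The architecture is an assumption for illustration, not OTT's implementation: time t is mapped to cos/sin features at frequencies 2πk for k = 0, ..., num_freqs - 1, concatenated with x and the optional condition, passed through hidden Dense layers with `act_fn` (and optional dropout), and projected linearly to `dim` outputs. The helper names `cyclical_time_encoding`, `init_mlp`, and `mlp_velocity` are hypothetical.

```python
import jax
import jax.numpy as jnp


def cyclical_time_encoding(t, num_freqs):
    # Assumed encoding: cos/sin of t at frequencies 2*pi*k, k = 0..num_freqs-1.
    freqs = 2.0 * jnp.pi * jnp.arange(num_freqs)
    angles = t * freqs  # t: (batch, 1) -> angles: (batch, num_freqs)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


def init_mlp(key, in_dim, hidden_dims, out_dim):
    # One (weight, bias) pair per layer: in_dim -> hidden_dims... -> out_dim.
    sizes = [in_dim, *hidden_dims, out_dim]
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, d_in, d_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (d_in, d_out)) / jnp.sqrt(d_in)
        b = jnp.zeros((d_out,))
        params.append((w, b))
    return params


def mlp_velocity(params, t, x, cond=None, num_freqs=4, act_fn=jax.nn.silu,
                 dropout_rate=0.0, dropout_key=None):
    # Concatenate time features, state, and (optional) condition vector.
    parts = [cyclical_time_encoding(t, num_freqs), x]
    if cond is not None:
        parts.append(cond)
    h = jnp.concatenate(parts, axis=-1)
    for w, b in params[:-1]:
        h = act_fn(h @ w + b)
        # Inverted dropout, applied only when a key is supplied (training).
        if dropout_key is not None and dropout_rate > 0.0:
            dropout_key, sub = jax.random.split(dropout_key)
            keep = jax.random.bernoulli(sub, 1.0 - dropout_rate, h.shape)
            h = jnp.where(keep, h / (1.0 - dropout_rate), 0.0)
    w, b = params[-1]
    return h @ w + b  # linear output layer of size dim


dim, cond_dim, num_freqs = 2, 3, 4
params = init_mlp(jax.random.PRNGKey(0), 2 * num_freqs + dim + cond_dim, (64, 64), dim)
t = jnp.full((5, 1), 0.5)
x = jnp.ones((5, dim))
c = jnp.zeros((5, cond_dim))
v = mlp_velocity(params, t, x, c, num_freqs=num_freqs)
```

Passing `dropout_key=None` mirrors evaluation mode (no dropout), while supplying a key mirrors training mode with a nonzero `dropout_rate`.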

Return type:

Any

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use the standalone function flax.nnx.iter_children() instead.

iter_modules()

Warning: this method is deprecated; use the standalone function flax.nnx.iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.
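The perturb method listed above relies on a simple identity: adding a zero-valued tensor to an intermediate activation leaves the forward pass unchanged, and the gradient of the loss with respect to that tensor equals the gradient with respect to the activation. The pure-JAX sketch below demonstrates the idea; it is not Flax's implementation, and the function `two_layer` and its weights are hypothetical.

```python
import jax
import jax.numpy as jnp


def two_layer(x, w1, w2, perturbation):
    h = jnp.tanh(x @ w1)
    h = h + perturbation  # zero-valued: forward pass is unchanged
    return jnp.sum((h @ w2) ** 2)


x = jnp.array([[1.0, -1.0]])
w1 = jnp.array([[0.5, 0.0], [0.0, 0.5]])
w2 = jnp.array([[1.0], [2.0]])
zeros = jnp.zeros((1, 2))

# Gradient w.r.t. the perturbation == gradient w.r.t. the intermediate h.
grad_h = jax.grad(two_layer, argnums=3)(x, w1, w2, zeros)
```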