ott.neural.networks.velocity_field.mlp.MLP
- class ott.neural.networks.velocity_field.mlp.MLP(dim, *, hidden_dims=(), cond_dim=0, act_fn=jax.nn.silu, time_enc_num_freqs=None, dropout_rate=0.0, rngs, **kwargs)
MLP velocity field.
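To illustrate what a velocity field of this kind computes, here is a minimal NumPy sketch: it maps a time t, a point x of size `dim` and an optional condition of size `cond_dim` to a velocity of size `dim`, passing the concatenated inputs through hidden layers of sizes `hidden_dims` with a SiLU activation. This is a hypothetical illustration, not OTT's implementation. The helper names (`silu`, `cyclical_time_encoding`, `init_mlp`, `velocity`), the sin/cos time features used when `time_enc_num_freqs` is set, the use of the raw t otherwise, and the plain concatenation of inputs are all assumptions; dropout and `**kwargs` are omitted.

```python
import numpy as np


def silu(z):
    # SiLU (swish) activation, z * sigmoid(z); mirrors the default act_fn
    return z / (1.0 + np.exp(-z))


def cyclical_time_encoding(t, num_freqs):
    # ASSUMED encoding: cos/sin features of t at frequencies 2*pi*k, k = 1..num_freqs
    freqs = 2.0 * np.pi * np.arange(1, num_freqs + 1)
    return np.concatenate([np.cos(freqs * t), np.sin(freqs * t)])


def init_mlp(dim, hidden_dims, cond_dim, time_dim, seed=0):
    # Hypothetical initializer: [x, time features, condition] -> hidden_dims -> dim
    rng = np.random.default_rng(seed)
    sizes = [dim + time_dim + cond_dim, *hidden_dims, dim]
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        w = rng.normal(scale=1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
        params.append((w, np.zeros(fan_out)))
    return params


def velocity(params, t, x, cond=None, num_freqs=None, act_fn=silu):
    # Time features: the raw scalar t (ASSUMED) or its sinusoidal encoding
    if num_freqs is None:
        t_feat = np.array([t])
    else:
        t_feat = cyclical_time_encoding(t, num_freqs)
    parts = [x, t_feat] if cond is None else [x, t_feat, cond]
    h = np.concatenate(parts)
    for w, b in params[:-1]:
        h = act_fn(h @ w + b)  # hidden layers with nonlinearity
    w, b = params[-1]
    return h @ w + b  # linear output layer, same size as x


# Usage: 2-D data, 4 time frequencies (time_dim = 2 * 4), two hidden layers
dim, num_freqs = 2, 4
params = init_mlp(dim, hidden_dims=(16, 16), cond_dim=0, time_dim=2 * num_freqs)
v = velocity(params, 0.5, np.zeros(dim), num_freqs=num_freqs)
print(v.shape)  # (2,)
```

The key shape constraint is that the output layer has `dim` units, so the returned velocity lives in the same space as x.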
- Parameters:
- Return type: Any
Methods
- eval(**attributes): Sets the Module to evaluation mode.
- iter_children(): Warning: this method is deprecated; use iter_children() instead.
- iter_modules(): Warning: this method is deprecated; use iter_modules() instead.
- perturb(name, value[, variable_type]): Extract gradients of intermediate values during training.
- sow(variable_type, name, value[, reduce_fn, ...]): Store intermediate values during module execution for later extraction.
- train(**attributes): Sets the Module to training mode.
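For this class, train() and eval() matter mainly through `dropout_rate`: dropout should be active during training and disabled at evaluation time. In Flax NNX this switch is typically made by setting attributes such as `deterministic` on submodules. The toy class below only mimics that behavior in NumPy; `ToyDropout` is a hypothetical stand-in, not the Flax `Dropout` layer, and the inverted-dropout rescaling is an assumption about how dropout is applied.

```python
import numpy as np


class ToyDropout:
    """Hypothetical dropout layer with a train/eval switch."""

    def __init__(self, rate, seed=0):
        self.rate = rate
        self.deterministic = False  # start in training mode
        self.rng = np.random.default_rng(seed)

    def train(self):
        self.deterministic = False  # enable random dropping

    def eval(self):
        self.deterministic = True  # dropout becomes the identity

    def __call__(self, h):
        if self.deterministic or self.rate == 0.0:
            return h
        keep = self.rng.random(h.shape) >= self.rate
        # Inverted dropout: rescale kept units so the expected value is unchanged
        return np.where(keep, h / (1.0 - self.rate), 0.0)


layer = ToyDropout(rate=0.5)
h = np.ones(8)
layer.eval()
print(layer(h))  # unchanged: all ones
layer.train()
out = layer(h)  # each entry is either 0.0 or 2.0
```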