ott.neural.networks.potentials.MLP
- class ott.neural.networks.potentials.MLP(dim_hidden, *, input_dim, act_fn=jax.nn.elu, rngs)
A simple multilayer perceptron (MLP) implemented as a Flax NNX module.
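- Parameters:
  - dim_hidden: size(s) of the hidden layers.
  - input_dim: dimensionality of the input.
  - act_fn: activation function applied between layers; defaults to jax.nn.elu.
  - rngs: random-number streams (flax.nnx.Rngs) used to initialize the parameters.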
- Return type: Any
Methods
- eval(**attributes): Sets the Module to evaluation mode.
- iter_children(): Warning: this method is deprecated; use the function flax.nnx.iter_children() instead.
- iter_modules(): Warning: this method is deprecated; use the function flax.nnx.iter_modules() instead.
- perturb(name, value[, variable_type]): Extract gradients of intermediate values during training.
- sow(variable_type, name, value[, reduce_fn, ...]): Store intermediate values during module execution for later extraction.
- train(**attributes): Sets the Module to training mode.