ott.neural.networks.potentials.PotentialMLP
- class ott.neural.networks.potentials.PotentialMLP(dim_hidden, *, input_dim, is_potential=True, act_fn=<PjitFunction of <function leaky_relu>>, rngs)
Potential MLP (NNX).
- Parameters:
dim_hidden (Sequence[int]) – Sequence specifying the sizes of the hidden dimensions. The output dimension of the last layer is automatically set to 1 if is_potential is True, or to the dimension of the input otherwise.
input_dim (int) – Dimensionality of the input.
is_potential (bool) – Model the potential if True, otherwise model the gradient of the potential.
rngs (Rngs) – NNX random number generators.
args (Any)
kwargs (Any)
- Return type:
Any
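The rule for the output dimension described under dim_hidden can be sketched in plain Python. The helper name layer_sizes is hypothetical and is not part of the library; it only mirrors the documented behaviour (last layer of size 1 when is_potential is True, of size input_dim otherwise) and makes no claim about how the class builds its layers internally.

```python
from typing import List, Sequence


def layer_sizes(
    dim_hidden: Sequence[int], *, input_dim: int, is_potential: bool = True
) -> List[int]:
    """Hypothetical helper: list the layer widths implied by the documented rule."""
    # A potential is scalar-valued; a gradient has the same dimension as the input.
    out_dim = 1 if is_potential else input_dim
    return [input_dim, *dim_hidden, out_dim]


potential_sizes = layer_sizes([64, 64], input_dim=2)
gradient_sizes = layer_sizes([64, 64], input_dim=2, is_potential=False)
```

With two hidden layers of width 64 and a 2-dimensional input, the potential variant ends in a single output unit, while the gradient variant ends in 2 units.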
Methods
eval(**attributes) – Sets the Module to evaluation mode.
Warning: this method is deprecated; use iter_children() instead.
Warning: this method is deprecated; use iter_modules() instead.
perturb(name, value[, variable_type]) – Extract gradients of intermediate values during training.
Return a callable giving the gradient of the potential.
potential_value_fn([other_potential_value_fn]) – Return a callable giving the potential value.
sow(variable_type, name, value[, reduce_fn, ...]) – Store intermediate values during module execution for later extraction.
train(**attributes) – Sets the Module to training mode.
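The relationship between a callable giving the potential value and a callable giving its gradient can be illustrated with plain JAX. This is a minimal sketch under stated assumptions, not the library's implementation: potential is a hypothetical stand-in for a trained scalar network, and make_gradient_fn is a hypothetical helper that differentiates it per point.

```python
import jax
import jax.numpy as jnp


def potential(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical scalar potential f(x) = 0.5 * ||x||^2 for a single point x."""
    return 0.5 * jnp.sum(x ** 2)


def make_gradient_fn(potential_fn):
    """Hypothetical helper: gradient of a scalar potential, batched over rows."""
    # jax.grad differentiates the scalar output with respect to one point;
    # jax.vmap applies that gradient to every row of an (n, d) batch.
    return jax.vmap(jax.grad(potential_fn))


grad_fn = make_gradient_fn(potential)
x = jnp.array([[1.0, 2.0], [3.0, -1.0]])
g = grad_fn(x)  # same shape as x
```

For this quadratic potential the gradient at each point equals the point itself, which gives a simple sanity check. When the network models the gradient directly (is_potential set to False), no differentiation step of this kind is needed to obtain the gradient.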
Attributes
True if the module defines a scalar potential value.