ott.neural.networks.potentials.PotentialMLP

class ott.neural.networks.potentials.PotentialMLP(dim_hidden, *, input_dim, is_potential=True, act_fn=jax.nn.leaky_relu, rngs)

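Potential MLP implemented with Flax NNX: a multilayer perceptron that models either a scalar potential or the gradient of that potential.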

Parameters:
  • dim_hidden (Sequence[int]) – Sizes of the hidden layers. The output dimension of the last layer is set automatically: 1 if is_potential is True, otherwise input_dim.

  • input_dim (int) – Dimensionality of the input.

  • is_potential (bool) – Model the potential if True, otherwise model the gradient of the potential.

  • act_fn (Callable[[Array], Array]) – Activation function.

  • rngs (Rngs) – NNX random number generators.

  • args (Any)

  • kwargs (Any)

Return type:

Any
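To illustrate how dim_hidden, input_dim and is_potential determine the layer sizes, here is a minimal pure-JAX sketch. It is not OTT's implementation: `init_mlp` and `apply_mlp` are hypothetical helpers, and the plain dense layers with a linear output layer (and the squeeze of the scalar output) are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, dim_hidden, input_dim, is_potential=True):
    """Hypothetical helper: build (weight, bias) pairs for an MLP sketch.

    The last layer has 1 output if is_potential is True, else input_dim
    outputs, following the rule stated for `dim_hidden`.
    """
    out_dim = 1 if is_potential else input_dim
    sizes = [input_dim, *dim_hidden, out_dim]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)
        b = jnp.zeros((d_out,))
        params.append((w, b))
    return params


def apply_mlp(params, x, is_potential=True, act_fn=jax.nn.leaky_relu):
    """Forward pass: act_fn after each hidden layer, linear last layer (assumed)."""
    for w, b in params[:-1]:
        x = act_fn(x @ w + b)
    w, b = params[-1]
    out = x @ w + b
    # Scalar potential: drop the trailing size-1 axis; gradient: keep (..., input_dim).
    return out.squeeze(-1) if is_potential else out


key = jax.random.PRNGKey(0)
x = jnp.ones((5, 3))  # batch of 5 points in R^3
pot = init_mlp(key, [64, 64], input_dim=3, is_potential=True)
grad_net = init_mlp(key, [64, 64], input_dim=3, is_potential=False)
print(apply_mlp(pot, x).shape)                           # (5,)
print(apply_mlp(grad_net, x, is_potential=False).shape)  # (5, 3)
```

The only difference between the two modes in this sketch is the width of the final layer, which is why the output size does not need to be part of dim_hidden.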

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use iter_children() instead.

iter_modules()

Warning: this method is deprecated; use iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

potential_gradient_fn()

Return a callable giving the gradient of the potential.

potential_value_fn([other_potential_value_fn])

Return a callable giving the potential value.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.
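The relationship behind potential_gradient_fn and potential_value_fn can be sketched with two hypothetical helpers, `make_gradient_fn` and `make_value_fn`, which are not part of OTT. When the network outputs a scalar potential f, its gradient follows from `jax.grad`; when the network outputs the gradient T(x) = ∇f(x) directly, the value can be recovered from the Fenchel–Young equality f(x) = <x, T(x)> - f*(T(x)), assuming `other_potential_value_fn` evaluates the convex conjugate f*. This is one plausible design under that assumption, not a description of how PotentialMLP computes these values.

```python
import jax
import jax.numpy as jnp


def make_gradient_fn(net_fn, is_potential):
    """Hypothetical helper: return x -> grad f(x) for a single point x.

    A scalar potential is differentiated with jax.grad; a network that
    already outputs the gradient is returned unchanged.
    """
    if is_potential:
        return jax.grad(net_fn)
    return net_fn


def make_value_fn(net_fn, is_potential, other_potential_value_fn=None):
    """Hypothetical helper: return x -> f(x) for a single point x.

    For a gradient network T = grad f, use Fenchel-Young:
    f(x) = <x, T(x)> - f*(T(x)), assuming the other potential is f*.
    """
    if is_potential:
        return net_fn
    if other_potential_value_fn is None:
        raise ValueError("A gradient-parameterized potential needs the other potential's value.")

    def value_fn(x):
        t = net_fn(x)
        return jnp.dot(x, t) - other_potential_value_fn(t)

    return value_fn


# Quadratic example: f(x) = 0.5 x^T A x, grad f(x) = A x, f*(y) = 0.5 y^T A^{-1} y.
A = jnp.diag(jnp.array([2.0, 0.5]))
A_inv = jnp.diag(jnp.array([0.5, 2.0]))
f = lambda x: 0.5 * x @ A @ x
T = lambda x: A @ x
f_star = lambda y: 0.5 * y @ A_inv @ y

x = jnp.array([1.0, -2.0])
print(make_gradient_fn(f, is_potential=True)(x))          # [ 2. -1.]
print(make_value_fn(T, is_potential=False, other_potential_value_fn=f_star)(x))  # 2.0, equal to f(x)
```

The helpers act on one point at a time; wrapping them in `jax.vmap` applies them to a batch.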

Attributes

is_potential

True if the module defines a scalar potential value.