ott.neural.networks.potentials.MLP

class ott.neural.networks.potentials.MLP(dim_hidden, *, input_dim, act_fn=jax.nn.elu, rngs)

A simple multilayer perceptron (MLP) implemented as a Flax NNX module.

Parameters:
  • dim_hidden (Sequence[int]) – Sizes of the hidden layers, with the output dimension given as the last element.

  • input_dim (int) – Dimensionality of the input.

  • act_fn (Callable[[Array], Array]) – Activation function.

  • rngs (Rngs) – NNX random number generators.

  • args (Any)

  • kwargs (Any)

Return type:

Any
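To make the parameter semantics concrete, here is a minimal NumPy sketch of a plain MLP configured the same way: `input_dim` sets the input width, each entry of `dim_hidden` adds one dense layer, and the last entry is the output dimension. It is not the library's implementation. The helper names `init_mlp`, `mlp_apply` and `elu`, the initialisation scale, and the choice to apply `act_fn` after every layer except the last are all assumptions made for illustration.

```python
import numpy as np

def elu(x):
    """ELU with alpha = 1 (the documented default activation is an ELU)."""
    # expm1 on min(x, 0) avoids overflow warnings for large positive x.
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))

def init_mlp(dim_hidden, input_dim, seed=0):
    """Hypothetical init: (W, b) pairs for input_dim -> dim_hidden[0] -> ... -> dim_hidden[-1]."""
    rng = np.random.default_rng(seed)
    sizes = [input_dim, *dim_hidden]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        w = rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, d_out))
        b = np.zeros(d_out)
        params.append((w, b))
    return params

def mlp_apply(params, x, act_fn=elu):
    """Dense layers with act_fn between them; the final layer is left linear (assumption)."""
    for w, b in params[:-1]:
        x = act_fn(x @ w + b)
    w, b = params[-1]
    return x @ w + b

# Two hidden layers of width 64 and a scalar output, for 2-D inputs.
params = init_mlp(dim_hidden=[64, 64, 1], input_dim=2)
x = np.ones((5, 2))  # batch of 5 two-dimensional inputs
y = mlp_apply(params, x)
print(y.shape)  # (5, 1)
```

With the documented class, the equivalent configuration would be constructed as `MLP([64, 64, 1], input_dim=2, rngs=nnx.Rngs(0))`, where `nnx` is `flax.nnx`. Note that `dim_hidden` is positional, while `input_dim` and `rngs` are keyword-only.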

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use the function flax.nnx.iter_children() instead.

iter_modules()

Warning: this method is deprecated; use the function flax.nnx.iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.
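The perturb entry relies on a simple identity. A zero-valued variable p is added to an intermediate h, so the forward value is unchanged. Because d f(h + p)/dp at p = 0 equals f'(h), differentiating the loss with respect to p yields the gradient of the intermediate itself. The NumPy sketch below checks this identity with central finite differences on a toy computation. It does not call Flax. The names `W_OUT`, `head`, `forward_with_perturbation` and `grad_wrt_perturbation` are hypothetical.

```python
import numpy as np

# Hypothetical downstream computation applied to an intermediate activation h.
W_OUT = np.array([0.5, -1.0, 2.0])

def head(h):
    """Toy scalar loss computed from an intermediate activation h."""
    return float(np.sum(W_OUT * np.tanh(h)))

def forward_with_perturbation(h, p):
    """Mimic the perturbation trick: add p (zero in practice) to the intermediate."""
    return head(h + p)

def grad_wrt_perturbation(h, eps=1e-6):
    """Central finite differences of the loss w.r.t. p, evaluated at p = 0."""
    p0 = np.zeros_like(h)
    g = np.zeros_like(h)
    for i in range(h.size):
        e = np.zeros_like(h)
        e[i] = eps
        g[i] = (forward_with_perturbation(h, p0 + e)
                - forward_with_perturbation(h, p0 - e)) / (2 * eps)
    return g

h = np.array([0.1, -0.3, 0.7])
g_perturb = grad_wrt_perturbation(h)
# Analytic gradient of the loss w.r.t. h itself: W_OUT * (1 - tanh(h)**2).
g_exact = W_OUT * (1.0 - np.tanh(h) ** 2)
print(np.allclose(g_perturb, g_exact, atol=1e-6))  # True
```

In an actual NNX model, an autodiff transform would compute the derivative with respect to the perturbation variable. The finite differences here only make the identity visible without extra dependencies.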