ott.neural.networks.potentials.PotentialMLP

class ott.neural.networks.potentials.PotentialMLP(dim_hidden, is_potential=True, act_fn=<PjitFunction of <function leaky_relu>>, parent=<flax.linen.module._Sentinel object>, name=None)

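A multilayer perceptron (MLP) that models either a scalar potential or, when is_potential is False, the gradient of the potential.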

Parameters:
  • dim_hidden (Sequence[int]) – sequence specifying size of hidden dimensions. The output dimension of the last layer is automatically set to 1 if is_potential is True, or the dimension of the input otherwise.

  • is_potential (bool) – Model the potential if True, otherwise model the gradient of the potential.

  • act_fn (Callable[[Array], Array]) – Activation function.

  • parent (Union[Type[Module], Scope, Type[_Sentinel], None])

  • name (Optional[str])
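The output-size rule described for dim_hidden and is_potential can be sketched in plain Python. The helper layer_widths and its dim_input argument are hypothetical, introduced only for illustration, and are not part of the library. The sketch also assumes that the output layer is appended after the hidden layers listed in dim_hidden.

```python
from typing import List, Sequence


def layer_widths(
    dim_hidden: Sequence[int], dim_input: int, is_potential: bool = True
) -> List[int]:
  """Hypothetical helper: widths of the hidden layers followed by the output."""
  # The output is a scalar potential value, or a vector with the
  # same dimension as the input when modelling the gradient directly.
  dim_out = 1 if is_potential else dim_input
  return [*dim_hidden, dim_out]


# A 2-dimensional input with two hidden layers of width 64.
potential_widths = layer_widths([64, 64], dim_input=2, is_potential=True)
vector_widths = layer_widths([64, 64], dim_input=2, is_potential=False)
print(potential_widths)  # [64, 64, 1]
print(vector_widths)     # [64, 64, 2]
```

Under these assumptions, switching is_potential changes only the width of the final layer; the hidden layers are the same in both modes.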

Methods

act_fn([negative_slope])

Leaky rectified linear unit activation function.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

create_train_state(rng, optimizer, input, ...)

Create initial training state.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Adds a zero-value variable ('perturbation') to the intermediate value.

potential_gradient_fn(params)

Return a function that gives the gradient of the potential if is_potential is True, or the vector output of the network otherwise.

potential_value_fn(params[, ...])

Return a function giving the value of the potential.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.
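The relationship between the value returned by potential_value_fn and the vector returned by potential_gradient_fn can be illustrated with a plain-Python sketch that does not use the library. A toy quadratic potential stands in for the trained network, and a finite-difference helper stands in for automatic differentiation. The names toy_potential and numerical_gradient are hypothetical and chosen only for this example.

```python
from typing import Callable, List, Sequence


def toy_potential(x: Sequence[float]) -> float:
  """Hypothetical stand-in for a learned potential: f(x) = 0.5 * ||x||^2."""
  return 0.5 * sum(xi * xi for xi in x)


def numerical_gradient(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    eps: float = 1e-6,
) -> List[float]:
  """Central finite-difference gradient of a scalar function f at x."""
  grad = []
  for i in range(len(x)):
    x_plus = list(x)
    x_minus = list(x)
    x_plus[i] += eps
    x_minus[i] -= eps
    grad.append((f(x_plus) - f(x_minus)) / (2.0 * eps))
  return grad


point = [1.0, -2.0]
# The potential maps a point to a single scalar.
value = toy_potential(point)
# Its gradient is a vector with the same dimension as the input.
gradient = numerical_gradient(toy_potential, point)
print(value)
print(gradient)
```

For this toy potential the exact gradient at x is x itself, so the printed gradient should be close to the input point. In JAX code the gradient would more typically be obtained by automatic differentiation, for example with jax.grad, rather than by finite differences.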

Attributes

is_potential

name

parent

path

scope

variables

Returns the variables in this module.

dim_hidden