ott.neural.networks.potentials.PotentialMLP.init
- PotentialMLP.init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)
Initializes a module method with variables and returns modified variables.
init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys. It calls method (the module's __call__ function by default), passing *args and **kwargs, and returns a dictionary of initialized variables.
Example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)
>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)
If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init. If self.make_rng(name) is called on an RNG stream name that isn't passed by the user, it will default to using the 'params' RNG stream.
Example:
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)
>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )
>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )
>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )
Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values.
Example:
>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)
init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.
- Parameters:
  - rngs (Array | dict[str, Array]) – The rngs for the variable collections.
  - *args – Positional arguments passed to the init function.
  - method (Union[Callable[..., Any], str, None]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except "intermediates" are mutable.
  - capture_intermediates (bool | Callable[[Module, str], bool]) – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
  - **kwargs – Keyword arguments passed to the init function.
- Return type:
- Returns:
The initialized variable dict.