ott.neural.networks.potentials.PotentialMLP.init

PotentialMLP.init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns modified variables.

init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys. It calls method (the module’s __call__ function by default) with *args and **kwargs, and returns a dictionary of initialized variables.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)

init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

Parameters:
  • rngs (Array | dict[str, Array]) – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method (Union[Callable[..., Any], str, None]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Specifies which collections should be treated as mutable. bool: all or no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • capture_intermediates (bool | Callable[[Module, str], bool]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Return type:

FrozenDict[str, Mapping[str, Any]] | dict[str, Any]

Returns:

The initialized variable dict.