ott.neural.networks.potentials.BasePotential.apply


BasePotential.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set if you want apply to call a class method other than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   @nn.compact
...   def encode(self, x):
...     return nn.Dense(4)(x)

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)

If a bound method (such as model.encode) is provided, its underlying unbound function is used. For instance, the example below is equivalent to the one above:

>>> encoded = model.apply(variables, x, method=model.encode)

You can also pass the name of a callable attribute of the module as a string. For example, the previous call can be written as:

>>> encoded = model.apply(variables, x, method='encode')

Note that method can also be a function that is not defined in Transformer. In that case, the function should take at least one argument, the first of which receives the Module instance:

>>> def other_fn(instance, x):
...   # `instance` is the bound Transformer, so its methods and attributes are available
...   return instance.encode(x)

>>> out = model.apply(variables, x, method=other_fn)

If you pass a single PRNGKey, Flax uses it to feed the 'params' RNG stream. If you want to use a different RNG stream or need multiple streams, pass apply a dictionary mapping each RNG stream name to its PRNGKey. If self.make_rng(name) is called with an RNG stream name that the user did not pass, it falls back to the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)

Parameters:
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments passed to the specified apply method.

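  • rngs (Union[Array, dict[str, Array], None]) – A PRNGKey, or a dict mapping RNG stream names to PRNGKeys, used to initialize the PRNG sequences. A single key is treated as {'params': key}. The “params” PRNG sequence is used to initialize parameters and is the fallback for any stream that is not passed.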

  • method (Union[Callable[..., Any], str, None]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Specifies which collections should be treated as mutable. It can be a bool, a str, or a list. bool: True makes all collections mutable, False makes none mutable. str: the name of a single mutable collection. list: a list of names of mutable collections.

  • capture_intermediates (bool | Callable[[Module, str], bool]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the specified apply method.

Return type:

Union[Any, tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]]

Returns:

If mutable is False, returns the output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.