ott.neural.networks.potentials.BasePotential.apply
- BasePotential.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)
Applies a module method to variables and returns output and modified variables.
Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np
>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...
...
>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)
>>> encoded = model.apply(variables, x, method=Transformer.encode)
If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:
>>> encoded = model.apply(variables, x, method=model.encode)
You can also pass a string naming a callable attribute of the module. For example, the previous call can be written as:
>>> encoded = model.apply(variables, x, method='encode')
Note
method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:
>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...
...
>>> model.apply(variables, x, method=other_fn)
If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to apply. If self.make_rng(name) is called on an RNG stream name that isn't passed by the user, it will default to using the 'params' RNG stream.
Example:
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)
...
>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)
>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)
>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
- Parameters:
  - variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
  - *args – Positional arguments passed to the specified apply method.
  - rngs (Array | dict[str, Array] | None) – A single PRNGKey or a dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.
  - method (Callable[..., Any] | str | None) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable:
    - bool: all/no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
  - capture_intermediates (bool | Callable[[Module, str], bool]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
  - **kwargs – Keyword arguments passed to the specified apply method.
- Return type:
  Any | tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]
- Returns:
  If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.