ott.neural.networks.potentials.BasePotential.apply
- BasePotential.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)
Applies a module method to variables and returns output and modified variables.
Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)
If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:
>>> encoded = model.apply(variables, x, method=model.encode)
You can also pass the name of a callable attribute of the module as a string. For example, the previous call can be written as:
>>> encoded = model.apply(variables, x, method='encode')
Note: method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)
If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to apply. If self.make_rng(name) is called on an RNG stream name that isn't passed by the user, it will default to using the 'params' RNG stream.

Example:
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
- Parameters:
  - variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
  - *args – Positional arguments passed to the specified apply method.
  - rngs (Union[Array, dict[str, Array], None]) – A single PRNGKey or a dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.
  - method (Union[Callable[..., Any], str, None]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable:
    - bool: all/no collections are mutable.
    - str: The name of a single mutable collection.
    - list: A list of names of mutable collections.
  - capture_intermediates (bool | Callable[[Module, str], bool]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
  - **kwargs – Keyword arguments passed to the specified apply method.
- Return type:
  Union[Any, tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]]
- Returns:
  If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.