ott.neural.networks.potentials.PotentialMLP.apply
- PotentialMLP.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)
Applies a module method to variables and returns output and modified variables.
Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:
>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...
>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)
>>> encoded = model.apply(variables, x, method=Transformer.encode)
If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:
>>> encoded = model.apply(variables, x, method=model.encode)
You can also pass a string naming a callable attribute of the module. For example, the previous call can be written as:
>>> encoded = model.apply(variables, x, method='encode')
Note
method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:
>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...
>>> model.apply(variables, x, method=other_fn)
- Parameters:
  variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
  *args – Positional arguments passed to the specified apply method.
  rngs (Union[Array, Dict[str, Array], None]) – A dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.
  method (Union[Callable[..., Any], str, None]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.
  mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable:
    bool: all/no collections are mutable.
    str: The name of a single mutable collection.
    list: A list of names of mutable collections.
  capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
  **kwargs – Keyword arguments passed to the specified apply method.
- Return type:
  Union[Any, Tuple[Any, Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]]]
- Returns:
  If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.