ott.solvers.nn.models.ICNN.apply
- ICNN.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)
Applies a module method to variables and returns output and modified variables.
Note that method should be set if one would like to call apply on a class method other than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:
model = Transformer()
encoded = model.apply({'params': params}, x, method=Transformer.encode)
If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:
encoded = model.apply({'params': params}, x, method=model.encode)
You can also pass the name of a callable attribute of the module as a string. For example, the previous call can be written as:
encoded = model.apply({'params': params}, x, method='encode')
Note
method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:
def other_fn(instance, ...):
    instance.some_module_attr(...)
    ...

model.apply({'params': params}, x, method=other_fn)
- Parameters:
variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
*args – Positional arguments passed to the specified apply method.
rngs (Optional[Dict[str, Any]]) – A dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.
method (Union[Callable[..., Any], str, None]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.
mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable:
  bool: all/no collections are mutable.
  str: The name of a single mutable collection.
  list: A list of names of mutable collections.
capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
**kwargs – Keyword arguments passed to the specified apply method.
- Return type:
Union[Any, Tuple[Any, Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]]]
- Returns:
If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.