ott.core.icnn.ICNN.apply

ICNN.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set if one would like to call apply on a class method other than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

model = Transformer()
encoded = model.apply({'params': params}, x, method=Transformer.encode)

If a bound method (for example, model.encode) is provided, its underlying unbound function is used. For instance, the example below is equivalent to the one above:

encoded = model.apply({'params': params}, x, method=model.encode)

Note that method can also be a function that is not defined in Transformer. In that case, the function must take at least one argument; its first argument receives an instance of the Module class:

def other_fn(instance, x):
  # instance is the module bound to the variables passed to apply
  return instance.encode(x)

model.apply({'params': params}, x, method=other_fn)

Parameters
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments passed to the specified apply method.

  • rngs (Optional[Dict[str, Any]]) – A dict of PRNGKeys used to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.

  • method (Optional[Callable[..., Any]]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all or no collections are mutable; str: the name of a single mutable collection; list: a list of names of mutable collections.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the specified apply method.

Return type

Union[Any, Tuple[Any, FrozenDict[str, Mapping[str, Any]]]]

Returns

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.