ott.neural.networks.velocity_field.VelocityField.apply

VelocityField.apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set if one would like to call apply on a class method other than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

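>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Transformer(nn.Module):
...   @nn.compact
...   def encode(self, x):
...     # Illustrative body: a single Dense layer.
...     return nn.Dense(features=8)(x)

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)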

If a bound method is provided instead (for example, model.encode), its underlying unbound function is used. For instance, the example below is equivalent to the one above:

>>> encoded = model.apply(variables, x, method=model.encode)

You can also pass the name of a callable attribute of the module as a string. For example, the previous call can be written as:

>>> encoded = model.apply(variables, x, method='encode')

Note that method can also be a function that is not defined in Transformer. In that case, the function should take at least one argument, through which it receives an instance of the Module class:

>>> def other_fn(instance, x):
...   # instance is the bound module, so its methods and attributes are accessible.
...   return instance.encode(x)

>>> encoded = model.apply(variables, x, method=other_fn)

Parameters:
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments passed to the specified apply method.

  • rngs (Union[Array, Dict[str, Array], None]) – A dict of PRNGKeys to initialize the PRNG sequences (for example, a “dropout” stream). The “params” PRNG sequence is used to initialize parameters.

  • method (Union[Callable[..., Any], str, None]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Specifies which collections should be treated as mutable. A bool makes all (True) or no (False) collections mutable; a str names a single mutable collection; a list names several mutable collections.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the specified apply method.

Return type:

Union[Any, Tuple[Any, Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]]]

Returns:

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.