ott.neural.networks.velocity_field.VelocityField.init

VelocityField.init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns modified variables.

init takes as its first argument either a single PRNGKey or a dictionary mapping RNG stream names to their PRNGKeys. It calls method (the module’s __call__ function by default), passing *args and **kwargs, and returns a dictionary of initialized variables.

Example:

>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> import jax

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, jnp.empty((1, 7)), train=False)

If you pass a single PRNGKey, Flax uses it to feed the 'params' RNG stream. If you want to use a different RNG stream, or need multiple streams, you must pass init a dictionary mapping each RNG stream name to its corresponding PRNGKey.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...
...     # Add Gaussian noise
...     noise_key = self.make_rng('noise')
...     x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0),
...         'noise': jax.random.key(1)}
>>> variables = module.init(rngs, jnp.empty((1, 7)), train=False)

Jitting init initializes a model lazily, using only the shapes of the provided arguments, and avoids computing the forward pass with actual values.

Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), jnp.empty((1, 7)))

init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

Parameters:
  • rngs (Union[Array, Dict[str, Array]]) – Either a single PRNGKey, which feeds the 'params' RNG stream, or a dictionary mapping RNG stream names to PRNGKeys.

  • *args – Positional arguments passed to the init function.

  • method (Union[Callable[..., Any], str, None]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Return type:

Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]

Returns:

The initialized variable dict.