ott.core.icnn.ICNN.init

ICNN.init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns modified variables.

Jitting init initializes a model lazily, using only the shapes of the provided arguments and avoiding computing the forward pass with actual values. For example:

```python
import jax
import jax.numpy as jnp
jit_init = jax.jit(SomeModule(...).init)
jit_init(rng, jnp.ones(input_shape, jnp.float32))
```

Parameters
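  • rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections: either a single PRNG key (used for the “params” collection) or a dict mapping RNG stream names to PRNG keys.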

  • *args – Positional arguments passed to the init function.

  • method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Specifies which collections should be treated as mutable. Can be a bool (all or no collections are mutable), a str (the name of a single mutable collection), a list of collection names, or a DenyList (every collection except the listed ones is mutable). By default, all collections except “intermediates” are mutable.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, stores the return values of every Module's __call__ method in the “intermediates” collection. To change which outputs are stored, pass a filter function instead: it takes the Module instance and the method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Return type

FrozenDict[str, Mapping[str, Any]]

Returns

The initialized variable dict, mapping each collection name (for example, “params”) to that collection's variables.