ott.neural.networks.potentials.BasePotential.lazy_init


BasePotential.lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)

Initializes a module without computing on an actual input.

lazy_init initializes the variables without performing unnecessary computation. The input data should be passed as a jax.ShapeDtypeStruct, which specifies the shape and dtype of the input but carries no concrete data.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

The positional and keyword arguments passed to lazy_init can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword argument that enables or disables a subpart of the model. In this case, an explicit value (True/False) should be passed; otherwise, lazy_init cannot infer which variables should be initialized.

Parameters:
  • rngs (Union[Array, Dict[str, Array]]) – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Specifies which collections should be treated as mutable. Can be a bool, a str, or a list. bool: all or no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • **kwargs – Keyword arguments passed to the init function.

Return type:

FrozenDict[str, Mapping[str, Any]]

Returns:

The initialized variable dict.