ott.neural.networks.potentials.PotentialMLP.lazy_init
- PotentialMLP.lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)
Initializes a module without computing on an actual input.
lazy_init initializes the variables without performing unnecessary computation. The input data should be passed as a jax.ShapeDtypeStruct, which specifies the shape and dtype of the input but contains no concrete data.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))
The args and kwargs passed to lazy_init can be a mix of concrete values (JAX arrays, scalars, bools) and abstract values (ShapeDtypeStruct). Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword argument that enables or disables a submodule. In this case, an explicit value (True/False) must be passed; otherwise lazy_init cannot infer which variables should be initialized.
- Parameters:
  - rngs (Array | dict[str, Array]) – The rngs for the variable collections.
  - *args – Arguments passed to the init function.
  - method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable; by default, all collections except "intermediates" are mutable:
    - bool: all/no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
  - **kwargs – Keyword arguments passed to the init function.
- Return type:
  FrozenDict[str, Mapping[str, Any]]
- Returns:
  The initialized variable dict.