ott.solvers.nn.icnn.ICNN.init
- ICNN.init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)
Initializes a module method with variables and returns modified variables.
Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:
    jit_init = jax.jit(SomeModule(...).init)
    jit_init(rng, jnp.ones(input_shape, jnp.float32))
- Parameters
  - rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections: either a dict mapping rng stream names to PRNG keys, or a single PRNG key, which is used as the "params" rng.
  - *args – Positional arguments passed to the init function.
  - method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable:
    - bool: all/no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
    By default all collections except “intermediates” are mutable.
  - capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
  - **kwargs – Keyword arguments passed to the init function.
- Return type
  FrozenDict[str, Mapping[str, Any]]
- Returns
  The initialized variable dict.