ott.neural.networks.icnn.ICNN.param
ICNN.param(name, init_fn, *init_args, unbox=True, **init_kwargs)
Declares and returns a parameter in this Module.
Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables. The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args or init_kwargs:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}
In the example above, the initializer returned by lecun_normal() expects two arguments, key and shape, but only shape has to be provided explicitly; key is filled in automatically from the 'params' PRNG key that is passed to init() when the module is initialized.

- Parameters:
  - name (str) – The parameter name.
  - init_fn (Callable[..., T]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.
  - *init_args – The positional arguments to pass to init_fn.
  - unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value; see flax.core.meta.unbox (default: True).
  - **init_kwargs – The keyword arguments to pass to init_fn.
- Return type:
- Returns:
The value of the initialized parameter. Raises an error if a parameter with this name already exists in the module.