ott.neural.networks.icnn.ICNN.variable
- ICNN.variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)
Declares and returns a variable in this Module.
See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Contrary to param(), which passes a PRNG key to its init function automatically, all arguments to init_fn must be passed explicitly:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import flax.linen as nn
    >>> class Foo(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x):
    ...     x = nn.Dense(4)(x)
    ...     key = self.make_rng('stats')
    ...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
    ...     ...
    ...     return x * mean.value
    >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
    >>> jax.tree_util.tree_map(jnp.shape, variables)
    {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}
In the example above, the initializer returned by lecun_normal expects two arguments, key and shape, and both must be passed explicitly. The PRNG key for the stats collection must also be provided explicitly when calling init() and apply().

- Parameters:
  - col (str) – The variable collection name.
  - name (str) – The variable name.
  - init_fn (Optional[Callable[..., TypeVar(T)]]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized, otherwise an error is raised.
  - *init_args – The positional arguments to pass to init_fn.
  - unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value; see flax.core.meta.unbox (default: True).
  - **init_kwargs – The keyword arguments to pass to init_fn.
- Return type:
  Union[Variable[TypeVar(T)], Variable[AxisMetadata[TypeVar(T)]]]
- Returns:
  A flax.core.variables.Variable that can be read or set via its .value attribute. Raises an error if the variable already exists.