ott.neural.networks.potentials.BasePotential.variable
- BasePotential.variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)
- Overloads:
self, col (str), name (str), init_fn (Callable[…, T] | None), *init_args → Variable[T]
self, col (str), name (str), init_fn (Callable[…, T] | None), *init_args, unbox (Literal[True]), **init_kwargs → Variable[T]
self, col (str), name (str), init_fn (Callable[…, T] | None), *init_args, unbox (Literal[False]), **init_kwargs → Variable[meta.AxisMetadata[T]]
self, col (str), name (str), init_fn (Callable[…, T] | None), *init_args, unbox (bool), **init_kwargs → Variable[T] | Variable[meta.AxisMetadata[T]]
Declares and returns a variable in this Module.
See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Unlike param(), which supplies the “params” PRNG key to init_fn automatically, variable() requires every argument of init_fn to be passed explicitly:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}
In the Foo example above, the initializer returned by lecun_normal() expects two arguments, key and shape, and both have to be passed on. The PRNG for stats has to be provided explicitly when calling init() and apply().
- Parameters:
col (str) – The variable collection name.
name (str) – The variable name.
init_fn (Callable[..., T] | None) – The function that will be called to compute the initial value of this variable. It is only called when the variable does not exist yet, i.e. the first time this variable is used in this module. If None, the variable must already be initialized; otherwise an error is raised.
*init_args – The positional arguments to pass to init_fn.
unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value; see flax.core.meta.unbox (default: True).
**init_kwargs – The keyword arguments to pass to init_fn.
- Returns:
A flax.core.variables.Variable whose value can be read or set via its .value attribute. Raises an error if the name is already in use in this module.
- Return type:
Variable[T] | Variable[AxisMetadata[T]]