ott.neural.networks.potentials.BasePotential.param
- BasePotential.param(name, init_fn, *init_args, unbox=True, **init_kwargs)
- Overloads:
self, name (str), init_fn (Callable[…, T]), *init_args → T
self, name (str), init_fn (Callable[…, meta.AxisMetadata[T]] | Callable[…, T]), *init_args, unbox (Literal[True]), **init_kwargs → T
self, name (str), init_fn (Callable[…, T]), *init_args, unbox (Literal[False]), **init_kwargs → T
self, name (str), init_fn (Callable[…, T | meta.AxisMetadata[T]]), *init_args, unbox (bool), **init_kwargs → T | meta.AxisMetadata[T]
Declares and returns a parameter in this Module.
Parameters are read-only variables in the collection named “params”. See
flax.core.variables for more details on variables. The first argument of
init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args or init_kwargs:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}
In the example above, the initializer returned by lecun_normal() takes two required arguments, key and shape, but only shape has to be provided explicitly; key is set automatically from the "params" PRNG key that is passed when initializing the module with init().
- Parameters:
name (str) – The parameter name.
init_fn (Callable[..., T | AxisMetadata[T]]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.
*init_args – The positional arguments to pass to init_fn.
unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value; see flax.nn.meta.unbox (default: True).
**init_kwargs – The keyword arguments to pass to init_fn.
- Returns:
The value of the initialized parameter. Throws an error if the parameter exists already.
- Return type:
T | AxisMetadata[T]