ott.neural.solvers.neuraldual.BaseW2NeuralDual.param

BaseW2NeuralDual.param(name, init_fn, *init_args, unbox=True)

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}

In the example above, the initializer returned by lecun_normal() expects two arguments, key and shape, but only shape has to be provided explicitly; key is set automatically from the PRNG for params that is passed when the module is initialized with init().

Parameters:
  • name (str) – The parameter name.

  • init_fn (Callable[..., TypeVar(T)]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args – The arguments to pass to init_fn.

  • unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value, see flax.linen.meta.unbox (default: True).

Return type:

Union[TypeVar(T), AxisMetadata[TypeVar(T)]]

Returns:

The value of the initialized parameter. Raises an error if a parameter with this name already exists in the module.