ott.neural.networks.icnn.ICNN.variable


ICNN.variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Contrary to param(), all arguments that init_fn expects must be passed explicitly:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}

In the example above, the function lecun_normal expects two arguments, key and shape, and both have to be passed explicitly as init_args. Because the module calls make_rng('stats'), a PRNG key for stats has to be provided explicitly when calling init() and apply().

Parameters:
  • col (str) – The variable collection name.

  • name (str) – The variable name.

  • init_fn (Optional[Callable[..., TypeVar(T)]]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized; otherwise an error is raised.

  • *init_args – The positional arguments to pass to init_fn.

  • unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

  • **init_kwargs – The keyword arguments to pass to init_fn.

Return type:

Union[Variable[TypeVar(T)], Variable[AxisMetadata[TypeVar(T)]]]

Returns:

A flax.core.variables.Variable that can be read or set via its “.value” attribute. Raises an error if the variable already exists.