ott.solvers.nn.models.ICNN.param

ICNN.param(name, init_fn, *init_args, unbox=True)

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args:

mean = self.param('mean', lecun_normal(), (2, 2))

In the example above, lecun_normal() returns an initializer that expects two arguments, key and shape. Only shape has to be provided explicitly; key is supplied automatically from the “params” PRNG stream that is passed to init() when the module is initialized.

Parameters:
  • name (str) – The parameter name.

  • init_fn (Callable[..., TypeVar(T)]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args – Positional arguments passed to init_fn after the automatically supplied PRNG key.

  • unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value, see flax.core.meta.unbox (default: True).

Return type:

TypeVar(T)

Returns:

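The value of the parameter: the freshly initialized value during init(), or the stored value when the parameter already exists (for example during apply()). Raises an error if a parameter with the same name has already been declared in this module.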