ott.neural.networks.potentials.BasePotential.make_rng

BasePotential.make_rng(name='params')

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Note

If no RNG key was passed for name in .init or .apply (that is, name is not a valid RNG sequence for this call), make_rng falls back to the 'params' sequence.

Example:

>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out

Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html

Parameters:

name (str) – The RNG sequence name.

Return type:

Array

Returns:

The newly generated RNG key.