ott.neural.networks.potentials.PotentialMLP.make_rng
PotentialMLP.make_rng(name='params')
Returns a new RNG key from a given RNG sequence for this Module.
The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.
Note
If an invalid name is passed (i.e. no RNG key was passed by the user in .init or .apply for this name), then name will default to 'params'.
Example:
>>> import jax
>>> import flax.linen as nn
>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')
>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out
Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html