ott.neural.networks.potentials.BasePotential.make_rng
- BasePotential.make_rng(name='params')
Returns a new RNG key from a given RNG sequence for this Module.
The new RNG key is split from the previous one. Thus, every call to
make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Note
If an invalid name is passed (i.e. no RNG key was passed by the user in
.init or .apply for this name), then name will default to 'params'.

Example:
>>> import jax
>>> import flax.linen as nn
>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')
>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out
Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html