- MetaMLP.bind(variables, *args, rngs=None, mutable=False)
Creates an interactive Module instance by binding variables and RNGs.
bind() provides an “interactive” instance of a Module directly, without transforming a function with apply(). This is particularly useful for debugging and for interactive use cases like notebooks, where wrapping code in a function would limit the ability to split it across different cells.
Once the variables (and optionally RNGs) are bound to a Module, it becomes a stateful object. Note that idiomatic JAX is functional, so an interactive instance does not mix well with vanilla JAX APIs.
bind() should only be used for interactive experimentation; in all other cases we strongly encourage users to use apply() instead.
Example:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
    def setup(self):
        self.encoder = nn.Dense(3)
        self.decoder = nn.Dense(5)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

x = jnp.ones((16, 9))
ae = AutoEncoder()
variables = ae.init(jax.random.PRNGKey(0), x)
model = ae.bind(variables)
z = model.encoder(x)
x_reconstructed = model.decoder(z)
```
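- Parameters
variables – The variable collections to bind, e.g. the dictionary returned by init().
*args – Positional arguments (not used).
rngs – Optional PRNG keys to bind alongside the variables.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable:
bool: all or no collections are mutable.
str: the name of a single mutable collection.
list: a list of names of mutable collections.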
- Returns
A copy of this instance with bound variables and RNGs.
- Return type
M, a TypeVar bound to Module (the same Module subclass as the instance bind() is called on).