ott.solvers.nn.models.MLP.unbind
- MLP.unbind()
Returns an unbound copy of a Module and its variables.
unbind helps create a stateless version of a bound Module.

A common use case is extracting a sub-Module defined inside setup() together with its corresponding variables. The procedure has two steps: 1) temporarily bind the parent Module; then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.)

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Encoder and Decoder are assumed to be nn.Module subclasses defined elsewhere.
    class AutoEncoder(nn.Module):
        def setup(self):
            self.encoder = Encoder()
            self.decoder = Decoder()

        def __call__(self, x):
            return self.decoder(self.encoder(x))

    module = AutoEncoder()
    variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
    ...
    # Extract the Encoder sub-Module and its variables
    encoder, encoder_vars = module.bind(variables).encoder.unbind()