ott.neural.networks.potentials.BasePotential.unbind
- BasePotential.unbind()
Returns an unbound copy of a Module and its variables.
unbind() helps create a stateless version of a bound Module. A common use case is extracting a sub-Module defined inside setup() together with its corresponding variables: 1) temporarily bind() the parent Module, and then 2) unbind() the desired sub-Module. (Recall that setup() is only called when the Module is bound.)

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()