ott.solvers.nn.models.MLP.unbind

MLP.unbind()

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

A common use case is extracting a sub-Module defined inside setup(), together with its variables. To do so, 1) temporarily bind the parent Module, then 2) unbind the desired sub-Module (recall that setup() is only called when the Module is bound):

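import jax
import jax.numpy as jnp
import flax.linen as nn

# Encoder and Decoder are assumed to be nn.Module subclasses defined elsewhere.
class AutoEncoder(nn.Module):
  def setup(self):
    # Sub-Modules assigned in setup() only exist once the Module is bound.
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x):
    return self.decoder(self.encoder(x))

module = AutoEncoder()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
...
# Extract the Encoder sub-Module and its variables
encoder, encoder_vars = module.bind(variables).encoder.unbind()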
Return type:

Tuple[TypeVar(M, bound=Module), Mapping[str, Mapping[str, Any]]]

Returns:

A tuple with an unbound copy of this Module and its variables.

Parameters:

self (TypeVar(M, bound=Module))