ott.neural.solvers.neuraldual.BaseW2NeuralDual.unbind

BaseW2NeuralDual.unbind()

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

A common use case is extracting a sub-Module defined inside setup(), together with its variables: 1) temporarily bind the parent Module, then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.)

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()

Return type:

Tuple[TypeVar(M, bound=Module), Mapping[str, Mapping[str, Any]]]

Returns:

A tuple with an unbound copy of this Module and its variables.

Parameters:

self (TypeVar(M, bound=Module))