ott.core.icnn.ICNN.bind

ICNN.bind(variables, *args, rngs=None, mutable=False)

Creates an interactive Module instance by binding variables and RNGs.

bind provides an “interactive” instance of a Module directly, without transforming a function with apply. This is particularly useful for debugging and for interactive use cases such as notebooks, where wrapping the code in a function would make it hard to split it across cells.

Once the variables (and, optionally, RNGs) are bound to a Module, it becomes a stateful object. Note that idiomatic JAX is functional, so an interactive instance does not mix well with vanilla JAX APIs such as jax.jit and jax.grad, which expect pure functions. bind() should only be used for interactive experimentation; in all other cases we strongly encourage users to use apply() instead.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(3)
    self.decoder = nn.Dense(5)

  def __call__(self, x):
    return self.decoder(self.encoder(x))

x = jnp.ones((16, 9))
ae = AutoEncoder()
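variables = ae.init(jax.random.PRNGKey(0), x)
# Bind the initialized variables to get a stateful, interactive copy.
model = ae.bind(variables)
# Submodules defined in setup() can now be called directly.
z = model.encoder(x)
x_reconstructed = model.decoder(z)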

Parameters
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Additional positional arguments (not used).

  • rngs (Optional[Dict[str, Any]]) – A dict mapping RNG stream names (e.g. 'dropout') to PRNGKeys, used to initialize the PRNG sequences.

  • mutable (Union[bool, str, Collection[str], DenyList]) –

    Specifies which collections should be treated as mutable. Can be a bool, str, or list:

      • bool: all or no collections are mutable.
      • str: the name of a single mutable collection.
      • list: a list of names of mutable collections.

Returns

A copy of this instance with bound variables and RNGs.