ott.solvers.nn.layers.PositiveDense.bind
- PositiveDense.bind(variables, *args, rngs=None, mutable=False)
Creates an interactive Module instance by binding variables and RNGs.
bind() provides an “interactive” instance of a Module directly, without transforming a function with apply(). This is particularly useful for debugging and for interactive use cases such as notebooks, where wrapping the code in a function would make it harder to split it across cells.

Once the variables (and optionally RNGs) are bound to a Module, it becomes a stateful object. Note that idiomatic JAX is functional, so an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation; in all other cases we strongly encourage users to use apply() instead.

Example:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
    def setup(self):
        self.encoder = nn.Dense(3)
        self.decoder = nn.Dense(5)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

x = jnp.ones((16, 9))
ae = AutoEncoder()
variables = ae.init(jax.random.PRNGKey(0), x)

# Bind the variables to get a stateful instance whose submodules
# can be called directly, e.g. one notebook cell at a time.
model = ae.bind(variables)
z = model.encoder(x)
x_reconstructed = model.decoder(z)
```
- Parameters
  - variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collection. See flax.core.variables for more details about variables.
  - *args – Positional arguments (not used).
  - rngs (Optional[Dict[str, Any]]) – A dict mapping PRNG sequence names to the PRNGKeys used to initialize those sequences.
  - mutable (Union[bool, str, Collection[str], DenyList]) – Can be a bool, str, or list. Specifies which collections should be treated as mutable:
    - bool: all/no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
- Returns
A copy of this instance with bound variables and RNGs.