ott.neural.networks.potentials.PotentialMLP.sow
- PotentialMLP.sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)
Store intermediate values during module execution for later extraction.
Used with the nnx.capture() decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified name in a collection associated with variable_type. By default, values are appended to a tuple, so multiple values can be tracked when the same module is called multiple times.
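The default accumulation rule can be modeled in plain Python. The sketch below is a hypothetical stand-in: the ToyCollection class and its sow method are invented for illustration and are not part of Flax or OTT. It assumes the defaults behave like reduce_fn=lambda prev, new: (*prev, new) and init_fn=lambda: (), which matches the documented "append to a tuple" behavior.

```python
from typing import Any, Callable


# Hypothetical stand-in for a sow collection: maps name -> accumulated value.
# This is NOT the Flax implementation; it only models the documented defaults.
class ToyCollection:
    def __init__(self) -> None:
        self.values: dict[str, Any] = {}

    def sow(
        self,
        name: str,
        value: Any,
        reduce_fn: Callable[[Any, Any], Any] = lambda prev, new: (*prev, new),
        init_fn: Callable[[], Any] = lambda: (),
    ) -> None:
        # First call for `name`: start from init_fn(). Later calls: reuse the stored value.
        prev = self.values[name] if name in self.values else init_fn()
        self.values[name] = reduce_fn(prev, value)


store = ToyCollection()
store.sow("features", [1.0, 2.0])  # first call of the "module"
store.sow("features", [3.0, 4.0])  # second call appends instead of overwriting
print(store.values["features"])  # ([1.0, 2.0], [3.0, 4.0])
```

Because reduce_fn receives the previously stored accumulator, the second call extends the tuple rather than replacing the first value, which is what lets repeated calls to the same module all be recorded.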
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'features', x)
...     x = self.linear2(x)
...     return x
>>> # With the capture decorator, forward returns the output and the sown intermediates
>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   return model(x)
>>> result, intermediates = forward(model, jnp.ones(2))
>>> assert 'features' in intermediates
Custom init_fn and reduce_fn functions can be passed to control how values accumulate. For example, to keep a running sum instead of a tuple:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     return x
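For a single name, successive sow calls can be read as a left fold over the sown values, seeded with init_fn(). The plain-Python sketch below models that reading; the accumulate helper is hypothetical and not Flax code, and scalars stand in for the arrays that the model above would sow. It also shows that the accumulator need not have the same type as the values.

```python
from functools import reduce


# Hypothetical helper: models repeated sow calls for one name as a left fold,
# seeded with init_fn(). This is an illustration, not the Flax implementation.
def accumulate(values, reduce_fn, init_fn):
    return reduce(reduce_fn, values, init_fn())


sown = [1.0, 2.0, 3.0]  # e.g. a scalar sown on three separate calls

# Mirrors the example above: init_fn=lambda: 0 and an additive reduce_fn.
total = accumulate(sown, lambda prev, curr: prev + curr, lambda: 0)

# The accumulator type need not match the value type:
# here it is a (sum, count) pair from which a running mean can be read off.
s, n = accumulate(sown, lambda acc, x: (acc[0] + x, acc[1] + 1), lambda: (0.0, 0))
mean = s / n
print(total, mean)  # 6.0 2.0
```

Seeding with init_fn() rather than with the first sown value is what allows the accumulator type (TypeVar(B) in the parameter list) to differ from the value type (TypeVar(A)).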
- Parameters:
  - variable_type (type[Variable[TypeVar(B)]] | str) – The Variable type for the stored value. Typically Intermediate or a subclass is used.
  - name (str) – A string key for storing the value in the collection.
  - value (TypeVar(A)) – The value to be stored.
  - reduce_fn (Callable[[TypeVar(B), TypeVar(A)], TypeVar(B)]) – Function that combines the existing and new values. The default appends to a tuple.
  - init_fn (Callable[[], TypeVar(B)]) – Function providing the initial value for the first reduce_fn call. The default returns an empty tuple.
- Return type: