ott.solvers.nn.models.ICNN.sow
- ICNN.sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)
Stores a value in a collection.
Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
If the target collection is not mutable, sow behaves like a no-op and returns False.
Example:
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Foo(nn.Module):
        @nn.compact
        def __call__(self, x):
            h = nn.Dense(4)(x)
            self.sow('intermediates', 'h', h)
            return nn.Dense(2)(h)

    x = jnp.ones((16, 9))
    model = Foo()
    variables = model.init(jax.random.PRNGKey(0), x)
    y, state = model.apply(variables, x, mutable=['intermediates'])
    print(state['intermediates'])  # {'h': (...,)}
By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:
    class Foo2(nn.Module):
        @nn.compact
        def __call__(self, x):
            init_fn = lambda: 0
            reduce_fn = lambda a, b: a + b
            self.sow('intermediates', 'h', x,
                     init_fn=init_fn, reduce_fn=reduce_fn)
            self.sow('intermediates', 'h', x * 2,
                     init_fn=init_fn, reduce_fn=reduce_fn)
            return x

    model = Foo2()
    variables = model.init(jax.random.PRNGKey(0), x)
    y, state = model.apply(variables, jnp.ones((1, 1)), mutable=['intermediates'])
    print(state['intermediates'])  # ==> {'h': [[3.]]}
- Parameters:
  - col (str) – The name of the variable collection.
  - name (str) – The name of the variable.
  - value (TypeVar(T)) – The value of the variable.
  - reduce_fn (Callable[[TypeVar(K), TypeVar(T)], TypeVar(K)]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.
  - init_fn (Callable[[], TypeVar(K)]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.
- Return type:
  bool
- Returns:
True if the value has been stored successfully, False otherwise.
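The interplay between reduce_fn and init_fn described in the parameter list can be modelled in plain Python. This is only a sketch of the documented semantics, not the library's implementation: the helper `sow_sketch` and the dictionary standing in for a mutable collection are hypothetical.

```python
def sow_sketch(store, name, value,
               reduce_fn=lambda xs, x: xs + (x,),  # default: append to a tuple
               init_fn=lambda: ()):                # default: empty tuple
    """Hypothetical stand-in for sow, with a dict playing the collection."""
    # On the first store, reduce_fn receives init_fn() as the existing value.
    current = store[name] if name in store else init_fn()
    store[name] = reduce_fn(current, value)
    return True


# Default functions: values accumulate in a tuple.
collected = {}
sow_sketch(collected, 'h', 1.0)
sow_sketch(collected, 'h', 2.0)
print(collected)

# Custom functions: values are summed, starting from 0.
summed = {}
add = lambda a, b: a + b
sow_sketch(summed, 'h', 1.0, reduce_fn=add, init_fn=lambda: 0)
sow_sketch(summed, 'h', 2.0, reduce_fn=add, init_fn=lambda: 0)
print(summed)
```

Note that init_fn is consulted only when nothing has been stored under the given name yet; on later calls the existing value is passed to reduce_fn instead.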