ott.neural.networks.velocity_field.VelocityField.sow
- VelocityField.sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)
Stores a value in a collection.
Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)
>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}
By default, the values are stored in a tuple, and each newly stored value is appended to the end. This way, all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init_fn/reduce_fn pair can be passed:
>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...              init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...              init_fn=init_fn, reduce_fn=reduce_fn)
...     return x
>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
- Parameters:
  - col (str) – The name of the variable collection.
  - name (str) – The name of the variable.
  - value (TypeVar(T)) – The value of the variable.
  - reduce_fn (Callable[[TypeVar(K), TypeVar(T)], TypeVar(K)]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.
  - init_fn (Callable[[], TypeVar(K)]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default init_fn returns an empty tuple.
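The interplay between init_fn and reduce_fn can be modeled in plain Python. This is a simplified sketch of the semantics described above, not Flax's implementation: a dict stands in for the variable collection, and the helper sow_into and its mutable flag are hypothetical. On the first store for a name, reduce_fn receives init_fn(); afterwards it receives the previously stored value.

```python
from typing import Any, Callable, Dict, TypeVar

T = TypeVar('T')
K = TypeVar('K')


def sow_into(store: Dict[str, Any], name: str, value: T,
             reduce_fn: Callable[[K, T], K] = lambda xs, x: xs + (x,),
             init_fn: Callable[[], K] = lambda: (),
             mutable: bool = True) -> bool:
  """Hypothetical model of sow's storage semantics (not the Flax API)."""
  if not mutable:
    return False  # mirrors sow's no-op on a non-mutable collection
  # First value for this name starts from init_fn(); later ones reuse the store.
  current = store[name] if name in store else init_fn()
  store[name] = reduce_fn(current, value)
  return True


# Defaults: values accumulate in a tuple.
default_store = {}
sow_into(default_store, 'h', 1)
sow_into(default_store, 'h', 2)
print(default_store)  # {'h': (1, 2)}

# Custom init_fn/reduce_fn: a running sum, as in the Foo2 example.
sum_store = {}
for v in (1, 2):
  sow_into(sum_store, 'h', v, reduce_fn=lambda a, b: a + b, init_fn=lambda: 0)
print(sum_store)  # {'h': 3}
```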
- Return type: bool
- Returns: True if the value has been stored successfully, False otherwise.