ott.neural.networks.potentials.BasePotential.sow
- BasePotential.sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)
- Overloads:
  - self, col (str), name (str), value (Any) → bool
  - self, col (str), name (str), value (T), reduce_fn (Callable[[K, T], K]), init_fn (Callable[[], K]) → bool
Stores a value in a collection.
Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)
>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}
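The store-or-no-op contract can be modeled in plain Python without Flax. The sketch below is a hypothetical stand-in, not Flax or OTT code: the function sow_sketch and its mutable flag are illustrative assumptions (in Flax, mutability is controlled by the mutable argument to apply), and the defaults mirror the documented tuple-append behavior.

```python
from typing import Any, Callable, Dict


def _append_to_tuple(acc: tuple, value: Any) -> tuple:
    # Mirrors the documented default: append each new value to a tuple.
    return acc + (value,)


def sow_sketch(
    collections: Dict[str, Dict[str, Any]],
    mutable: bool,
    col: str,
    name: str,
    value: Any,
    reduce_fn: Callable[[Any, Any], Any] = _append_to_tuple,
    init_fn: Callable[[], Any] = tuple,
) -> bool:
    """Hypothetical model of sow's store/no-op semantics (not Flax code)."""
    if not mutable:
        # Target collection is not mutable: store nothing and report False.
        return False
    store = collections.setdefault(col, {})
    # First value is combined with init_fn(); later ones with the stored value.
    previous = store[name] if name in store else init_fn()
    store[name] = reduce_fn(previous, value)
    return True


state: Dict[str, Dict[str, Any]] = {}
print(sow_sketch(state, mutable=False, col="intermediates", name="h", value=1.0))  # False
print(state)  # {}
print(sow_sketch(state, mutable=True, col="intermediates", name="h", value=1.0))  # True
print(sow_sketch(state, mutable=True, col="intermediates", name="h", value=2.0))  # True
print(state)  # {'intermediates': {'h': (1.0, 2.0)}}
```

Calling with the same name twice yields a two-element tuple, which is how repeated calls of the same module remain traceable.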
By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:
>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x
>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
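Under the semantics described above, repeated sow calls with the same init_fn and reduce_fn behave like a left fold of reduce_fn over the sown values, seeded with init_fn(). The sketch below uses functools.reduce to reproduce both the default tuple behavior and the summation in Foo2, using plain Python floats instead of JAX arrays; the helper name accumulate is hypothetical.

```python
import functools
from typing import Any, Callable, Iterable


def accumulate(values: Iterable[Any],
               reduce_fn: Callable[[Any, Any], Any],
               init_fn: Callable[[], Any]) -> Any:
    """Left-fold reduce_fn over values, starting from init_fn() (sketch)."""
    return functools.reduce(reduce_fn, values, init_fn())


# Default behavior: start from an empty tuple and append each value.
print(accumulate([1.0, 2.0], lambda acc, v: acc + (v,), tuple))  # (1.0, 2.0)

# Foo2's custom pair: start from 0 and add, so x + 2 * x with x = 1 gives 3.
x = 1.0
print(accumulate([x, x * 2], lambda a, b: a + b, lambda: 0))  # 3.0
```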
- Parameters:
  - col (str) – The name of the variable collection.
  - name (str) – The name of the variable.
  - value (TypeVar(T)) – The value of the variable.
  - reduce_fn (Callable[[TypeVar(K), TypeVar(T)], TypeVar(K)]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.
  - init_fn (Callable[[], TypeVar(K)]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default returns an empty tuple.
- Returns:
  True if the value has been stored successfully, False otherwise.
- Return type:
  bool
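The signatures above allow the accumulator type K to differ from the value type T. As an illustration, the hypothetical pair below keeps a (count, total) accumulator for float values so that a mean can be computed afterwards. These are plain Python functions that could be passed as init_fn and reduce_fn; the fold at the end only simulates repeated calls and does not invoke Flax.

```python
import functools
from typing import Tuple

# K = Tuple[int, float] (count, running total); T = float (each sown value).
MeanState = Tuple[int, float]


def mean_init() -> MeanState:
    # Produces the accumulator that the first stored value is combined with.
    return (0, 0.0)


def mean_reduce(acc: MeanState, value: float) -> MeanState:
    # Combine the existing accumulator with one new value.
    count, total = acc
    return (count + 1, total + value)


# Simulate three sow calls under the same name (no Flax involved here).
state = functools.reduce(mean_reduce, [1.0, 2.0, 3.0], mean_init())
count, total = state
print(state)          # (3, 6.0)
print(total / count)  # 2.0
```

Since only the accumulator is stored, the division happens after the values have been retrieved from the collection.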