ott.neural.layers.PosDefPotentials.sow

PosDefPotentials.sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': (Array([[-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ],
       [-1.503171  ,  0.7377704 , -0.59388214, -1.0079019 ]],      dtype=float32),)}

By default, the values are stored in a tuple, and each new value is appended to the end. This means all intermediates are kept when the same module is called multiple times. Alternatively, you can pass custom init_fn and reduce_fn functions:

>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}

Parameters:
  • col (str) – The name of the variable collection.

  • name (str) – The name of the variable.

  • value (TypeVar(T)) – The value of the variable.

  • reduce_fn (Callable[[TypeVar(K), TypeVar(T)], TypeVar(K)]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn (Callable[[], TypeVar(K)]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default returns an empty tuple.

Return type:

bool

Returns:

True if the value has been stored successfully, False otherwise.
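
The interaction of init_fn, reduce_fn and the boolean return value can be modelled in plain Python. The sketch below is a hypothetical, simplified stand-in: sow_model and its dict-based store are not part of Flax or OTT. It ignores module scopes and only mirrors the behaviour described in the parameter and return descriptions above:

```python
from typing import Any, Callable, Dict


def default_init_fn() -> tuple:
  # Mirrors the documented default: start from an empty tuple.
  return ()


def default_reduce_fn(xs: tuple, x: Any) -> tuple:
  # Mirrors the documented default: append the new value to the tuple.
  return xs + (x,)


def sow_model(store: Dict[str, Any], mutable: bool, name: str, value: Any,
              reduce_fn: Callable[[Any, Any], Any] = default_reduce_fn,
              init_fn: Callable[[], Any] = default_init_fn) -> bool:
  """Hypothetical, simplified model of sow acting on a plain dict."""
  if not mutable:
    # Non-mutable collection: nothing is stored and False is returned.
    return False
  # The first value for a name is combined with init_fn(); later ones with the stored value.
  current = store[name] if name in store else init_fn()
  store[name] = reduce_fn(current, value)
  return True


intermediates = {}
sow_model(intermediates, True, 'h', 1.0)
sow_model(intermediates, True, 'h', 2.0)
print(intermediates['h'])  # (1.0, 2.0)

counts = {}
for v in (1.0, 2.0):
  sow_model(counts, True, 'h', v, reduce_fn=lambda a, b: a + b, init_fn=lambda: 0)
print(counts['h'])  # 3.0

print(sow_model(counts, False, 'h', 5.0), counts['h'])  # False 3.0
```

The summing variant reproduces the arithmetic of the Foo2 example: 0 + 1 + 2 = 3.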