ott.solvers.nn.models.ICNN.sow

ICNN.sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    self.sow('intermediates', 'h', h)
    return nn.Dense(2)(h)

x = jnp.ones((16, 9))
model = Foo()
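variables = model.init(jax.random.PRNGKey(0), x)
# 'intermediates' must be listed as mutable, otherwise sow stores nothing.
y, state = model.apply(variables, x, mutable=['intermediates'])
print(state['intermediates'])  # {'h': (h,)}: a one-element tuple holding the (16, 4) activation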

By default, values are stored in a tuple and each new value is appended to the end, so every intermediate is kept when the same module is called multiple times. Alternatively, a custom init_fn and reduce_fn can be passed:

class Foo2(nn.Module):
  @nn.compact
  def __call__(self, x):
    init_fn = lambda: 0
    reduce_fn = lambda a, b: a + b
    self.sow('intermediates', 'h', x,
             init_fn=init_fn, reduce_fn=reduce_fn)
    self.sow('intermediates', 'h', x * 2,
             init_fn=init_fn, reduce_fn=reduce_fn)
    return x

x = jnp.ones((1, 1))
model = Foo2()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, x, mutable=['intermediates'])
print(state['intermediates'])  # 'h' holds 0 + x + 2 * x, i.e. [[3.]]

Parameters:
  • col (str) – The name of the variable collection.

  • name (str) – The name of the variable.

  • value (TypeVar(T)) – The value of the variable.

  • reduce_fn (Callable[[TypeVar(K), TypeVar(T)], TypeVar(K)]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn (Callable[[], TypeVar(K)]) – Called for the first value stored under name: reduce_fn is passed the result of init_fn together with the value to be stored. The default returns an empty tuple.
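The interplay between init_fn and reduce_fn can be modeled without Flax. The sketch below uses a hypothetical helper, sow_into, operating on a plain dict; it mirrors the documented rule (init_fn supplies the starting value on the first store, reduce_fn folds in each new value) and is not Flax's implementation.

```python
from typing import Any, Callable, Dict


def sow_into(store: Dict[str, Any], name: str, value: Any,
             reduce_fn: Callable[[Any, Any], Any] = lambda xs, x: xs + (x,),
             init_fn: Callable[[], Any] = lambda: ()) -> None:
  """Hypothetical stand-in for the storage rule documented for sow.

  On the first store under `name`, reduce_fn receives init_fn();
  afterwards it receives the previously stored value.
  """
  current = store[name] if name in store else init_fn()
  store[name] = reduce_fn(current, value)


store = {}

# Defaults: values accumulate in a tuple.
sow_into(store, 'h', 1)
sow_into(store, 'h', 2)
print(store['h'])  # (1, 2)

# Keep only the most recent value.
sow_into(store, 'last', 1, reduce_fn=lambda _, x: x, init_fn=lambda: None)
sow_into(store, 'last', 5, reduce_fn=lambda _, x: x, init_fn=lambda: None)
print(store['last'])  # 5

# Running maximum, starting from negative infinity.
for v in (3, 7, 4):
  sow_into(store, 'peak', v, reduce_fn=max, init_fn=lambda: float('-inf'))
print(store['peak'])  # 7
```

Because reduce_fn may change the type (K) relative to the stored values (T), init_fn must return something reduce_fn accepts as its first argument, such as an empty tuple, None, or an identity element like 0 or negative infinity.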

Return type:

bool

Returns:

True if the value has been stored successfully, False otherwise.