ott.neural.networks.potentials.MLP.perturb

MLP.perturb(name, value, collection='perturbations')

Add a zero-valued variable (a ‘perturbation’) to the intermediate value.

The gradient with respect to value is the same as the gradient with respect to this perturbation variable. Therefore, if your loss function takes both params and perturbations as standalone arguments, you can obtain the intermediate gradients of value by running jax.grad on the perturbations argument.
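The equality of the two gradients can be checked without any module machinery. The following is a minimal sketch in plain JAX, not the library's implementation: the weights w1 and w2, the shapes, and the function names are all hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer computation; w1, w2 and all shapes are illustrative.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w1 = jax.random.normal(k1, (9, 3))
w2 = jax.random.normal(k2, (3, 2))
x = jnp.ones((2, 9))


def loss_with_perturbation(perturbation):
  # Intermediate value plus an additive perturbation.
  hidden = x @ w1 + perturbation
  return jnp.square(hidden @ w2).mean()


def loss_from_hidden(hidden):
  # The same tail of the computation, as a function of the intermediate value.
  return jnp.square(hidden @ w2).mean()


hidden = x @ w1
zero_perturbation = jnp.zeros_like(hidden)

# Differentiate once w.r.t. the zero perturbation, once w.r.t. the value itself.
grad_wrt_perturbation = jax.grad(loss_with_perturbation)(zero_perturbation)
grad_wrt_hidden = jax.grad(loss_from_hidden)(hidden)
gradients_match = bool(jnp.allclose(grad_wrt_perturbation, grad_wrt_hidden))
```

Because the perturbation enters additively and is evaluated at zero, differentiating with respect to it yields the gradient of the loss with respect to the intermediate value, and the perturbation has the same shape as that value.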

Note

This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory. Use it only to debug gradients during training.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
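True

Parameters:

name: Name of the perturbation variable, used as its key in the collection (for example 'dense3').
value: The intermediate value to which the perturbation is added.
collection: Name of the variable collection that holds the perturbation; defaults to 'perturbations'.

Return type:

TypeVar(T)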