ott.neural.networks.potentials.BasePotential.perturb
BasePotential.perturb(name, value, collection='perturbations')
Add a zero-valued variable (‘perturbation’) to the intermediate value.
The gradient of value is the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.
Note
This is an experimental API and may be changed later for better performance and usability. In its current form, it creates extra dummy variables that occupy additional memory. Use it only to debug gradients during training.
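The gradient equivalence stated above follows from the perturbation being added to value while starting at zero: the derivative of value + perturbation with respect to the perturbation is the identity, so both gradients coincide. The plain-JAX sketch below illustrates that mechanism without Flax; the weights W1 and W2 and the functions loss_with_perturbation and loss_of_h are hypothetical stand-ins, not part of this API.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed weights standing in for the layers around the intermediate value.
W1 = jnp.arange(6.0).reshape(2, 3) / 10.0
W2 = jnp.arange(6.0).reshape(3, 2) / 10.0

def loss_with_perturbation(eps, x):
    h = x @ W1        # intermediate value
    h = h + eps       # add a zero-valued "perturbation" to it
    return jnp.sum((h @ W2) ** 2)

def loss_of_h(h):
    # Same downstream computation, written as a function of the intermediate value.
    return jnp.sum((h @ W2) ** 2)

x = jnp.ones((4, 2))
h = x @ W1
eps = jnp.zeros_like(h)

g_eps = jax.grad(loss_with_perturbation)(eps, x)  # gradient w.r.t. the perturbation
g_h = jax.grad(loss_of_h)(h)                      # gradient w.r.t. the value itself
print(jnp.allclose(g_eps, g_h))  # True
```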
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-0.04684732  0.06573904 -0.3194327 ]
 [-0.04684732  0.06573904 -0.3194327 ]]
If perturbations are not passed to apply, perturb behaves like a no-op, so you can easily disable the behavior when it is not needed:
>>> model.apply(variables, x)  # works as expected
Array([[-0.04579116,  0.50412744],
       [-0.04579116,  0.50412744]], dtype=float32)
>>> model.apply({'params': variables['params']}, x)  # behaves like a no-op
Array([[-0.04579116,  0.50412744],
       [-0.04579116,  0.50412744]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True
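Because the perturbations collection is simply absent from the gradients when perturb acts as a no-op, a debugging helper that inspects intermediate gradients should tolerate a missing collection. The hypothetical perturbation_grad_norms below (not part of OTT or Flax) sketches one way to summarize gradients such as intm_grads from the first example as per-name L2 norms; the toy gradient values are made up for illustration.

```python
import jax
import jax.numpy as jnp

def perturbation_grad_norms(grads, collection='perturbations'):
    """Hypothetical helper: L2 norm of each intermediate gradient, keyed by path.

    Returns an empty dict when `collection` is missing, e.g. when the loss was
    differentiated without passing perturbations to apply.
    """
    if collection not in grads:
        return {}
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads[collection])
    return {
        jax.tree_util.keystr(path): float(jnp.linalg.norm(leaf))
        for path, leaf in leaves_with_paths
    }

# Made-up gradients shaped like the output of jax.grad on the full variables dict.
toy_grads = {'params': {}, 'perturbations': {'dense3': jnp.array([[3.0, 4.0]])}}
print(perturbation_grad_norms(toy_grads))
print(perturbation_grad_norms({'params': {}}))  # {}
```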