ott.neural.networks.potentials.PotentialTrainState.apply_gradients


PotentialTrainState.apply_gradients(*, grads, **kwargs)

Updates step, params, opt_state and **kwargs in the return value.

Note that internally this function first calls .tx.update(), which computes the parameter updates and the new opt_state from grads, and then calls optax.apply_updates() to apply those updates to params.

Parameters:
  • grads – Gradients that have the same pytree structure as .params.

  • **kwargs – Additional dataclass attributes that should be .replace()-ed.

Returns:

An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.