ott.neural.networks.potentials.PotentialTrainState.apply_gradients
- PotentialTrainState.apply_gradients(*, grads, **kwargs)
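Updates step, params, opt_state, and any attributes passed in **kwargs, and returns them in a new state object.
Internally, this method calls self.tx.update() to turn grads into optimizer updates, then calls optax.apply_updates() to apply those updates to params; the optimizer state returned by tx.update() becomes the new opt_state.
- Parameters: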
  grads – Gradients with the same pytree structure as .params.
  **kwargs – Additional dataclass attributes to set on the returned state via .replace().
- Returns:
  An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and any additional attributes replaced as specified by kwargs.
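The sequence described above (optimizer update, then parameter update, then a functional replace that also bumps step) can be illustrated with a minimal, dependency-free sketch. Everything here is hypothetical: MiniTrainState, HypotheticalSGD, and the local apply_updates helper are stand-ins for the real train state, an optax GradientTransformation, and optax.apply_updates, and parameters are a flat dict of floats rather than an arbitrary JAX pytree. The tag field plays the role of an extra attribute set through **kwargs.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple

Params = Dict[str, float]  # simplification: a flat dict instead of a JAX pytree


@dataclass(frozen=True)
class HypotheticalSGD:
    """Hypothetical stand-in for an optax GradientTransformation (plain SGD)."""
    learning_rate: float = 0.1

    def update(self, grads: Params, opt_state: Any,
               params: Optional[Params] = None) -> Tuple[Params, Any]:
        # Like optax, return *updates* (scaled and negated), plus the new state.
        updates = {k: -self.learning_rate * g for k, g in grads.items()}
        return updates, opt_state  # plain SGD carries no optimizer state


def apply_updates(params: Params, updates: Params) -> Params:
    # Hypothetical mirror of optax.apply_updates: add updates leaf by leaf.
    return {k: params[k] + updates[k] for k in params}


@dataclass(frozen=True)
class MiniTrainState:
    step: int
    params: Params
    tx: HypotheticalSGD
    opt_state: Any
    tag: str = ""  # hypothetical extra attribute, settable through **kwargs

    def apply_gradients(self, *, grads: Params, **kwargs) -> "MiniTrainState":
        # 1. Turn gradients into updates and a new optimizer state.
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        # 2. Apply the updates to the current parameters.
        new_params = apply_updates(self.params, updates)
        # 3. Return a new frozen instance; self is left unchanged.
        return replace(self, step=self.step + 1, params=new_params,
                       opt_state=new_opt_state, **kwargs)


state = MiniTrainState(step=0, params={"w": 1.0},
                       tx=HypotheticalSGD(0.5), opt_state=None)
new_state = state.apply_gradients(grads={"w": 2.0}, tag="after-one-step")
print(new_state.step, new_state.params, new_state.tag)  # 1 {'w': 0.0} after-one-step
print(state.step, state.params)                         # 0 {'w': 1.0}
```

Because the state is immutable, the caller must rebind the result (state = state.apply_gradients(grads=...)); the original object keeps its old step and params.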