ott.neural.networks.potentials.PotentialTrainState.apply_gradients
- PotentialTrainState.apply_gradients(*, grads, **kwargs)
Returns a new instance in which step, params, opt_state and any attributes passed via **kwargs are updated; self is left unchanged. Internally, this function calls self.tx.update() followed by optax.apply_updates() to compute the new params and opt_state.
- Parameters:
grads – Gradients with the same pytree structure as self.params.
**kwargs – Additional dataclass attributes to set through .replace().
- Returns:
An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.