ott.neural.networks.potentials.PotentialTrainState
- class ott.neural.networks.potentials.PotentialTrainState(step, apply_fn, params, tx, opt_state, potential_value_fn, potential_gradient_fn)
Adds information about the model’s value and gradient to the state.
This extends TrainState to include the potential methods from the BasePotential used during training.
- Parameters:
  - potential_value_fn (Callable[[FrozenDict[str, Array], Optional[Callable[[Array], Array]]], Callable[[Array], Array]]) – the potential’s value function.
  - potential_gradient_fn (Callable[[FrozenDict[str, Array]], Callable[[Array], Array]]) – the potential’s gradient function.
  - apply_fn (Callable)
  - params (FrozenDict[str, Any])
  - opt_state (Union[Array, ndarray, bool, number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]])
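Both potential fields are factories: each takes the parameters and returns a function on arrays. To make those signatures concrete, here is a minimal sketch with a hypothetical toy potential f(x) = 0.5 · scale · ‖x‖², assuming a plain dict in place of FrozenDict and using jax.grad for the gradient. The function names only mirror the field names; this is not how OTT's BasePotential builds either function, and the optional second argument is simply ignored in the toy.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp

ArrayFn = Callable[[jax.Array], jax.Array]


def potential_value_fn(
    params: dict, other_potential_value_fn: Optional[ArrayFn] = None
) -> ArrayFn:
    """Hypothetical toy: returns x -> 0.5 * scale * ||x||^2."""
    del other_potential_value_fn  # unused in this toy example

    def value(x: jax.Array) -> jax.Array:
        return 0.5 * params["scale"] * jnp.sum(x**2)

    return value


def potential_gradient_fn(params: dict) -> ArrayFn:
    """Hypothetical toy: returns x -> gradient of the value above, via autodiff."""
    return jax.grad(potential_value_fn(params))


params = {"scale": jnp.asarray(2.0)}  # stands in for learned network params
x = jnp.array([1.0, -3.0])
f = potential_value_fn(params)
g = potential_gradient_fn(params)
print(f(x))  # 0.5 * 2 * (1 + 9) = 10
print(g(x))  # scale * x = [2, -6]
```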
Methods
- apply_gradients(*, grads, **kwargs) – Returns a new state with step incremented and params, opt_state and any **kwargs updated.
- create(*, apply_fn, params, tx, **kwargs) – Creates a new instance with step=0 and an initialized opt_state.
- replace(**updates) – Returns a new object replacing the specified fields with new values.
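These three methods follow an immutable, functional update pattern: each call returns a new state and leaves the old one unchanged. As a rough illustration, the sketch below uses a hypothetical pure-Python stand-in (ToyPotentialTrainState) that mimics the documented behaviour of create, apply_gradients and replace. It assumes plain float parameters and a hand-written momentum-SGD update in place of the optax transformation tx; the real class inherits these methods from TrainState (assumed here to be Flax's flax.training.train_state.TrainState) rather than defining them like this.

```python
import dataclasses
from typing import Any, Callable, Dict


@dataclasses.dataclass(frozen=True)
class ToyPotentialTrainState:
    """Hypothetical stand-in; fields loosely mirror PotentialTrainState."""

    step: int
    params: Dict[str, float]
    opt_state: Dict[str, float]  # momentum buffer, one entry per parameter
    potential_value_fn: Callable[..., Any]
    potential_gradient_fn: Callable[..., Any]
    learning_rate: float = 0.1  # these two fields replace the optax `tx`
    momentum: float = 0.9

    @classmethod
    def create(cls, *, params: Dict[str, float], **kwargs: Any) -> "ToyPotentialTrainState":
        # New state with step=0 and a freshly initialized (zero) optimizer state.
        opt_state = {k: 0.0 for k in params}
        return cls(step=0, params=params, opt_state=opt_state, **kwargs)

    def apply_gradients(self, *, grads: Dict[str, float], **kwargs: Any) -> "ToyPotentialTrainState":
        # Momentum SGD: v <- momentum * v + g, then p <- p - lr * v.
        velocity = {k: self.momentum * self.opt_state[k] + grads[k] for k in self.params}
        new_params = {k: self.params[k] - self.learning_rate * velocity[k] for k in self.params}
        # Return a new object; `self` is left unchanged.
        return self.replace(step=self.step + 1, params=new_params, opt_state=velocity, **kwargs)

    def replace(self, **updates: Any) -> "ToyPotentialTrainState":
        return dataclasses.replace(self, **updates)


state = ToyPotentialTrainState.create(
    params={"scale": 1.0},
    potential_value_fn=lambda p, other=None: (lambda x: 0.5 * p["scale"] * x * x),
    potential_gradient_fn=lambda p: (lambda x: p["scale"] * x),
)
new_state = state.apply_gradients(grads={"scale": 0.5})
print(state.step, new_state.step)  # 0 1
print(new_state.params["scale"])  # approximately 0.95
```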
Attributes