ott.neural.networks.potentials.PotentialTrainState
- class ott.neural.networks.potentials.PotentialTrainState(step, apply_fn, params, tx, opt_state, potential_value_fn, potential_gradient_fn)
Adds information about the model's value and gradient to the state.
This extends Flax's TrainState to include the potential methods from the BasePotential used during training.
Parameters:
- potential_value_fn (Callable[[FrozenDict[str, Array], Optional[Callable[[Array], Array]]], Callable[[Array], Array]]) – the potential's value function: given the parameters and, optionally, another potential's value function, it returns the function x ↦ f(x).
- potential_gradient_fn (Callable[[FrozenDict[str, Array]], Callable[[Array], Array]]) – the potential's gradient function: given the parameters, it returns the function x ↦ ∇f(x).
- apply_fn (Callable)
- params (FrozenDict[str, Any])
- opt_state (Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]])
Methods
- apply_gradients(*, grads, **kwargs) – Updates step, params, opt_state and **kwargs in the returned state.
- create(*, apply_fn, params, tx, **kwargs) – Creates a new instance with step=0 and an initialized opt_state.
- replace(**updates) – Returns a new object replacing the specified fields with new values.
Attributes