ott.neural.networks.potentials.PotentialTrainState

class ott.neural.networks.potentials.PotentialTrainState(step, apply_fn, params, tx, opt_state, potential_value_fn, potential_gradient_fn)

Adds information about the model’s value and gradient to the state.

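This extends Flax's TrainState (flax.training.train_state.TrainState) so that the state also carries the value and gradient functions of the BasePotential being trained, making them available to the training code directly from the state.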

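Parameters:

step: counter of update steps; create() sets it to 0 and every call to apply_gradients() increments it.

apply_fn: the function used to apply the model, usually the Flax module's apply method.

params: the model parameters updated by tx.

tx: the Optax gradient transformation used as the optimizer.

opt_state: the state of tx.

potential_value_fn: the potential's value function.

potential_gradient_fn: the potential's gradient function.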

Methods

apply_gradients(*, grads, **kwargs)

Returns a new state with step incremented by one, params and opt_state updated from grads, and any fields passed in **kwargs replaced.

create(*, apply_fn, params, tx, **kwargs)

Creates a new instance with step=0 and initialized opt_state.

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

potential_value_fn

potential_gradient_fn

step

apply_fn

params

tx

opt_state