ott.neural.networks.velocity_field.VelocityField.create_train_state


VelocityField.create_train_state(rng, optimizer, input_dim, condition_dim=None)

Create the training state.

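Parameters:

rng – Random number generator used to initialize the parameters.
optimizer – Optimizer (an Optax gradient transformation) used to update the parameters.
input_dim – Dimensionality of the velocity field.
condition_dim – Dimensionality of the condition of the velocity field. Defaults to None (no condition).

Return type: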

TrainState

Returns:

The training state.