ott.solvers.nn.models.MLP.create_train_state

MLP.create_train_state(rng, optimizer, input, **kwargs)

Create initial training state.

Parameters:

rng
optimizer
input
kwargs

Return type:

NeuralTrainState