ott.neural.solvers.neuraldual.W2NeuralDual.train_neuraldual_parallel

ott.neural.solvers.neuraldual.W2NeuralDual.train_neuraldual_parallel#

W2NeuralDual.train_neuraldual_parallel(trainloader_source, trainloader_target, validloader_source, validloader_target, callback=None)

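Train and validate the neural dual potentials with parallel updates, i.e., both potentials are updated in the same training step rather than alternately.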

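Parameters:
trainloader_source – training data from the source measure.
trainloader_target – training data from the target measure.
validloader_source – validation data from the source measure.
validloader_target – validation data from the target measure.
callback – optional callback invoked during training; defaults to None.

Return type:
Dict[Literal['train_logs', 'valid_logs'], Dict[str, List[float]]]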
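To illustrate what "parallel updates" means and what the returned logs look like, here is a minimal, self-contained sketch in NumPy. It does not use OTT and does not optimize the W2 dual objective: `train_parallel_sketch`, `batches`, and `toy_loss_and_grads` are hypothetical names, the coupled quadratic loss is a stand-in for the real dual loss, the log key `"loss"` is an assumption (the library's actual keys may differ), and the callback signature `(step, a, b)` is made up for this example. What it does show: both parameter vectors get their gradients from the same current state and are updated together in one step, and the result has the documented shape `Dict[Literal['train_logs', 'valid_logs'], Dict[str, List[float]]]`.

```python
from typing import Callable, Dict, Iterator, List, Literal, Optional

import numpy as np

Logs = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]


def batches(data: np.ndarray, batch_size: int, seed: int) -> Iterator[np.ndarray]:
    """Hypothetical stand-in data loader: yields random minibatches forever."""
    rng = np.random.default_rng(seed)
    while True:
        yield data[rng.integers(0, len(data), size=batch_size)]


def toy_loss_and_grads(a, b, x, y, coupling=1.0):
    """Toy coupled quadratic loss (a stand-in, NOT the W2 dual objective)."""
    diff = a - b
    loss = (np.mean(np.sum((x - a) ** 2, axis=1))
            + np.mean(np.sum((y - b) ** 2, axis=1))
            + coupling * np.sum(diff ** 2))
    grad_a = 2.0 * (a - x.mean(axis=0)) + 2.0 * coupling * diff
    grad_b = 2.0 * (b - y.mean(axis=0)) - 2.0 * coupling * diff
    return float(loss), grad_a, grad_b


def train_parallel_sketch(
    trainloader_source: Iterator[np.ndarray],
    trainloader_target: Iterator[np.ndarray],
    validloader_source: Iterator[np.ndarray],
    validloader_target: Iterator[np.ndarray],
    num_train_iters: int = 200,
    valid_freq: int = 20,
    lr: float = 0.1,
    callback: Optional[Callable[[int, np.ndarray, np.ndarray], None]] = None,
) -> Logs:
    logs: Logs = {"train_logs": {"loss": []}, "valid_logs": {"loss": []}}
    a = b = None
    for step in range(num_train_iters):
        x, y = next(trainloader_source), next(trainloader_target)
        if a is None:  # lazily size the parameters from the first batch
            a, b = np.zeros(x.shape[1]), np.zeros(y.shape[1])
        # Parallel update: both gradients are evaluated at the current (a, b),
        # then both parameters are updated in the same step.
        loss, grad_a, grad_b = toy_loss_and_grads(a, b, x, y)
        a, b = a - lr * grad_a, b - lr * grad_b
        logs["train_logs"]["loss"].append(loss)
        # Periodic validation on held-out batches (and at the final step).
        if step % valid_freq == 0 or step == num_train_iters - 1:
            vx, vy = next(validloader_source), next(validloader_target)
            vloss, _, _ = toy_loss_and_grads(a, b, vx, vy)
            logs["valid_logs"]["loss"].append(vloss)
            if callback is not None:
                callback(step, a, b)
    return logs


# Usage: two Gaussian point clouds; validation reuses the data with other seeds.
rng = np.random.default_rng(0)
source = rng.normal(3.0, 1.0, size=(1000, 2))
target = rng.normal(-3.0, 1.0, size=(1000, 2))
logs = train_parallel_sketch(
    batches(source, 256, seed=1), batches(target, 256, seed=2),
    batches(source, 256, seed=3), batches(target, 256, seed=4),
)
print(len(logs["train_logs"]["loss"]), len(logs["valid_logs"]["loss"]))
```

The alternative to this scheme would update one set of parameters first and compute the other's gradient from the already-updated state; the sketch deliberately evaluates both gradients before either update to match the "parallel" description.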