ott.solvers.nn.neuraldual.NeuralDualSolver.train_neuraldual

NeuralDualSolver.train_neuraldual(trainloader_source, trainloader_target, validloader_source, validloader_target)

Run the solver's training and validation loop, drawing source and target batches from the given loaders.

Return type

Dict[Literal['train_logs', 'valid_logs'], Dict[str, List[float]]]
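The return type is a two-level dictionary: the outer keys select the training or validation logs, and each inner dictionary maps a metric name to the list of values recorded over training. The sketch below reads the latest value of every metric from such a structure. The helper `last_values` and the metric names `loss_f` and `loss_g` are hypothetical placeholders, not necessarily the keys the solver actually records.

```python
from typing import Dict, List, Literal

# Mirrors the documented return type of train_neuraldual.
Logs = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]


def last_values(
    logs: Logs, split: Literal["train_logs", "valid_logs"]
) -> Dict[str, float]:
    """Hypothetical helper: most recent value of each metric in ``split``."""
    # Skip metrics that have not been recorded yet (empty lists).
    return {name: values[-1] for name, values in logs[split].items() if values}


# Hypothetical logs; metric names are placeholders, not the solver's real keys.
logs: Logs = {
    "train_logs": {"loss_f": [0.9, 0.5, 0.3], "loss_g": [1.1, 0.7, 0.4]},
    "valid_logs": {"loss_f": [0.6, 0.35], "loss_g": []},
}
print(last_values(logs, "train_logs"))  # {'loss_f': 0.3, 'loss_g': 0.4}
```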

Parameters
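  • trainloader_source (Iterable[jax.Array]) – batches of source samples used for training.

  • trainloader_target (Iterable[jax.Array]) – batches of target samples used for training.

  • validloader_source (Iterable[jax.Array]) – batches of source samples used for validation.

  • validloader_target (Iterable[jax.Array]) – batches of target samples used for validation.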
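The four loaders only need to yield `jax.Array` batches of shape `(batch_size, dim)`. Assuming the loop keeps pulling batches with `next()` for as many iterations as it runs, an infinite generator (which is both an iterable and an iterator) is a safe choice. The sketch below builds such loaders from two Gaussians; the generator `gaussian_batches`, the means and the batch sizes are all hypothetical choices for illustration, not part of the OTT API.

```python
from typing import Iterator

import jax
import jax.numpy as jnp


def gaussian_batches(
    key: jax.Array, mean: jax.Array, batch_size: int
) -> Iterator[jax.Array]:
    """Hypothetical infinite loader yielding (batch_size, dim) Gaussian samples."""
    dim = mean.shape[0]
    while True:
        # Split the key so every batch uses fresh randomness.
        key, subkey = jax.random.split(key)
        yield mean + jax.random.normal(subkey, (batch_size, dim))


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
# Source centred at the origin, target centred at (3, 3).
trainloader_source = gaussian_batches(k1, jnp.zeros(2), batch_size=128)
trainloader_target = gaussian_batches(k2, jnp.full(2, 3.0), batch_size=128)
validloader_source = gaussian_batches(k3, jnp.zeros(2), batch_size=512)
validloader_target = gaussian_batches(k4, jnp.full(2, 3.0), batch_size=512)

print(next(trainloader_source).shape)  # (128, 2)
```

These four iterators would then be passed positionally to `train_neuraldual` in the order shown in the signature above.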