Set up all components required to train the network.
Parameters:
rng (jax.Array)
neural_f (ott.solvers.nn.icnn.ICNN)
neural_g (ott.solvers.nn.icnn.ICNN)
input_dim (int)
optimizer_f (Union[jax.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]])
optimizer_g (Union[jax.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]])
Return type: None