Create the initial TrainState.

Parameters:
rng (jax.Array)
optimizer (Union[jax.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]])
input (Union[int, Tuple[int, ...]])

Returns:
TrainState
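The reference above gives only the signature. The following is a minimal sketch of what creating an initial train state typically involves, and it is not this library's implementation. It uses plain JAX only. Everything except the parameter names `rng` and `input` is hypothetical: the three-field `TrainState`, the single dense layer, the `num_outputs` default, and `momentum_init`. The sketch passes the optimizer as an init callable (`optimizer_init`) because the documented type of `optimizer` does not say how its state is built. `input` is treated as either a feature size or a full input shape, which is an assumption based on its `int | Tuple[int, ...]` type.

```python
from typing import Any, Callable, NamedTuple, Tuple, Union

import jax
import jax.numpy as jnp


class TrainState(NamedTuple):
    """Hypothetical minimal train state: step counter, params, optimizer state."""
    step: jax.Array
    params: Any
    opt_state: Any


def momentum_init(params: Any) -> Any:
    # Hypothetical optimizer init: one zero-filled momentum buffer per parameter leaf.
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def create_train_state(
    rng: jax.Array,
    optimizer_init: Callable[[Any], Any],
    input: Union[int, Tuple[int, ...]],  # name mirrors the documented parameter
    num_outputs: int = 4,  # hypothetical output width of the toy dense layer
) -> TrainState:
    # Accept either a single feature size or a full input shape (assumption).
    in_shape = (input,) if isinstance(input, int) else tuple(input)
    in_features = in_shape[-1]
    # Initialize a single dense layer: scaled-normal weights drawn from rng, zero bias.
    params = {
        "w": jax.random.normal(rng, (in_features, num_outputs)) / jnp.sqrt(in_features),
        "b": jnp.zeros((num_outputs,)),
    }
    # Training starts at step 0 with freshly initialized optimizer state.
    return TrainState(
        step=jnp.zeros((), dtype=jnp.int32),
        params=params,
        opt_state=optimizer_init(params),
    )


state = create_train_state(jax.random.PRNGKey(0), momentum_init, input=(8, 3))
print(state.params["w"].shape)  # (3, 4)
```

Keeping the state as an immutable `NamedTuple` means it is already a JAX pytree. It can therefore be passed through `jax.jit` and returned, updated, from a training step.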