ott.utils.batched_vmap
- ott.utils.batched_vmap(fun, *, batch_size, in_axes=0, out_axes=0)
Batched version of jax.vmap().
- Parameters:
  - fun (Callable[[ParamSpec(P, bound=None)], TypeVar(R)]) – Function to be mapped over additional axes.
  - batch_size (int) – Size of the batch.
  - in_axes (Union[int, Sequence[int], Any, None]) – Values specifying which input array axes to map over.
  - out_axes (Any) – Values specifying where the mapped axis should appear in the output.
- Returns:
  The vectorized function.
- Return type:
  Callable[[ParamSpec(P, bound=None)], TypeVar(R)]
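This page does not say how the batching is carried out. The sketch below is one rough illustration of a batched vmap in plain JAX. It splits the mapped axis into chunks of batch_size, runs jax.vmap(fun) on the chunks one after another with jax.lax.map, and handles any leftover rows with one final vmap call. This is not OTT's implementation. The helper name chunked_vmap_sketch is hypothetical, and the helper only supports a single array argument mapped over axis 0 with a single array output. That matches in_axes=0, out_axes=0.

```python
import jax
import jax.numpy as jnp


def chunked_vmap_sketch(fun, batch_size):
    """Hypothetical helper (not part of ott.utils): map `fun` over axis 0 of
    a single array, `batch_size` rows at a time.

    Assumes one array input mapped over axis 0 and one array output stacked
    along axis 0, i.e. the equivalent of in_axes=0, out_axes=0.
    """
    vfun = jax.vmap(fun)

    def wrapped(x):
        n = x.shape[0]
        n_full = (n // batch_size) * batch_size  # rows that fill whole batches
        outs = []
        if n_full > 0:
            # Reshape the leading axis to (num_batches, batch_size, ...) and
            # apply the vmapped function to each batch sequentially.
            batched = x[:n_full].reshape((n_full // batch_size, batch_size) + x.shape[1:])
            out = jax.lax.map(vfun, batched)
            # Merge the (num_batches, batch_size) axes back into one axis.
            outs.append(out.reshape((n_full,) + out.shape[2:]))
        if n_full < n:
            # Leftover rows that do not fill a whole batch.
            outs.append(vfun(x[n_full:]))
        return jnp.concatenate(outs, axis=0)

    return wrapped


def row_norm(v):
    # Euclidean norm of a single row.
    return jnp.sqrt(jnp.sum(v ** 2))


x = jnp.arange(21.0).reshape(7, 3)  # 7 rows: two full batches of 3, plus 1 leftover
f = chunked_vmap_sketch(row_norm, batch_size=3)
print(jnp.allclose(f(x), jax.vmap(row_norm)(x)))  # True
```

Processing a fixed number of rows at a time usually trades some parallelism for lower peak memory. With the documented API, the corresponding call follows the signature above, for example ott.utils.batched_vmap(row_norm, batch_size=3)(x).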