ott.utils.batched_vmap
- ott.utils.batched_vmap(fun, *, batch_size, in_axes=0, out_axes=0)
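Batched version of jax.vmap(): fun is vectorized over at most batch_size elements of the mapped axis at a time, rather than over the whole axis at once.
- Parameters: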
  - fun (Callable[[ParamSpec(P)], TypeVar(R)]) – Function to be mapped over additional axes.
  - batch_size (int) – Size of each batch, i.e., the number of elements along the mapped axis that are vectorized together.
  - in_axes (Union[int, Sequence[int], Any, None]) – Values specifying which input array axes to map over, as in jax.vmap().
  - out_axes (Any) – Values specifying where the mapped axis should appear in the output, as in jax.vmap().
- Returns:
  The vectorized function.
- Return type:
  Callable[[ParamSpec(P)], TypeVar(R)]
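The usual reason to batch a vmap is to bound peak memory: only batch_size elements are vectorized at once, and the batches are processed one after another. The sketch below illustrates that idea in plain JAX. It is a hypothetical helper (batched_vmap_sketch is not part of OTT and is not its actual implementation), and it assumes in_axes=0 and out_axes=0 for every argument and output, whereas the documented function accepts general in_axes and out_axes.

```python
import jax
import jax.numpy as jnp


def batched_vmap_sketch(fun, *, batch_size):
    """Hypothetical sketch of a batched vmap (not OTT's implementation).

    Maps `fun` over axis 0 of every positional argument, vectorizing at most
    `batch_size` elements at a time. Only in_axes=0 / out_axes=0 are handled,
    and the mapped axis is assumed to be non-empty.
    """
    vfun = jax.vmap(fun)  # vectorized over a single batch
    tree_map = jax.tree_util.tree_map

    def wrapper(*args):
        # Length of the mapped axis (shapes are static, so this is a Python int).
        n = jax.tree_util.tree_leaves(args)[0].shape[0]
        n_full = (n // batch_size) * batch_size  # rows covered by full batches
        pieces = []

        if n_full > 0:
            # Reshape the first n_full rows into (num_batches, batch_size, ...).
            batches = tree_map(
                lambda x: x[:n_full].reshape(
                    (n_full // batch_size, batch_size) + x.shape[1:]),
                args,
            )
            # lax.map runs vfun sequentially over the batches, so only
            # batch_size elements are vectorized at once.
            out = jax.lax.map(lambda b: vfun(*b), batches)
            # Merge (num_batches, batch_size, ...) back into (n_full, ...).
            pieces.append(
                tree_map(lambda y: y.reshape((n_full,) + y.shape[2:]), out))

        if n_full < n:
            # Leftover rows (fewer than batch_size) go through one last vmap call.
            pieces.append(vfun(*tree_map(lambda x: x[n_full:], args)))

        if len(pieces) == 1:
            return pieces[0]
        # Stitch the full-batch results and the remainder back together.
        return tree_map(lambda a, b: jnp.concatenate([a, b], axis=0), *pieces)

    return wrapper


# Usage: row-wise dot products over 10 rows, 4 rows at a time.
f = lambda x, y: jnp.sum(x * y)
x = jnp.arange(30.0).reshape(10, 3)
y = jnp.ones((10, 3))
result = batched_vmap_sketch(f, batch_size=4)(x, y)
```

Only the evaluation schedule changes, so for any batch_size the result should agree with jax.vmap(f)(x, y). With the documented API, the analogous call would be ott.utils.batched_vmap(f, batch_size=4)(x, y). Using jax.lax.map for the outer loop keeps the whole computation traceable under jax.jit, at the cost of running the batches sequentially.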