ott.utils.batched_vmap

ott.utils.batched_vmap(fun, *, batch_size, in_axes=0, out_axes=0)

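Batched version of jax.vmap(). Instead of vectorizing over the whole mapped axis in one call, the mapped axis is processed in chunks of batch_size elements, which can lower peak memory usage.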
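The idea of processing a mapped axis in chunks can be sketched with plain JAX as shown here. This is a minimal illustration, not the library's implementation. `chunked_vmap` and `row_norm` are hypothetical names. The sketch assumes a single array argument mapped over axis 0. The real `batched_vmap` also accepts `in_axes`/`out_axes` and may handle pytrees and a trailing partial batch differently.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fun, batch_size):
    """Hypothetical sketch: vmap `fun` over axis 0 of one array input,
    processing at most `batch_size` rows per vectorized call."""
    vfun = jax.vmap(fun)

    def wrapped(x):
        # Apply the vectorized function to consecutive slices of the input;
        # the last slice may hold fewer than `batch_size` rows.
        chunks = [
            vfun(x[start:start + batch_size])
            for start in range(0, x.shape[0], batch_size)
        ]
        # Stitch the per-chunk outputs back together along the mapped axis.
        return jnp.concatenate(chunks, axis=0)

    return wrapped


def row_norm(v):
    # Euclidean norm of a single row.
    return jnp.sqrt(jnp.sum(v ** 2))


x = jnp.arange(21.0).reshape(7, 3)
out = chunked_vmap(row_norm, batch_size=3)(x)  # chunks of 3, 3 and 1 rows
```

The result matches a single `jax.vmap(row_norm)(x)` call. Only the number of rows vectorized at once differs.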

Parameters:
  • fun (Callable[[ParamSpec(P)], TypeVar(R)]) – Function to be mapped over additional axes.

  • batch_size (int) – Number of elements along the mapped axis processed in each vectorized call.

  • in_axes (Union[int, Sequence[int], Any, None]) – Values specifying which input array axes to map over.

  • out_axes (Any) – Values specifying where the mapped axis should appear in the output.
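The types of in_axes and out_axes match those of jax.vmap(), which suggests they carry the same meaning. That is an assumption here. This example calls plain jax.vmap(), not batched_vmap, and shows those conventions. `scale` is a hypothetical function used only for illustration.

```python
import jax
import jax.numpy as jnp


def scale(v, c):
    return v * c


xs = jnp.ones((4, 3))
c = jnp.array(2.0)

# in_axes=(0, None): map over axis 0 of `xs`; pass `c` unchanged to every call.
y = jax.vmap(scale, in_axes=(0, None))(xs, c)

# out_axes=1: the mapped axis (size 4) is placed at axis 1 of the output.
z = jax.vmap(scale, in_axes=(0, None), out_axes=1)(xs, c)
```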

Return type:

Callable[[ParamSpec(P)], TypeVar(R)]

Returns:

The vectorized function.