ott.tools.soft_sort.topk_mask
- ott.tools.soft_sort.topk_mask(inputs, axis=-1, k=1, **kwargs)
Soft \(\text{top-}k\) selection mask.
For instance:
    import jax
    from ott.tools.soft_sort import topk_mask
    rng = jax.random.PRNGKey(0)
    k = 5
    x = jax.random.uniform(rng, (100,))
    mask = topk_mask(x, k=k)
will output a vector of shape x.shape, with values in \([0,1]\), that are differentiable approximations of the binary mask selecting the top \(k\) entries in x. These should be compared to the non-differentiable mask obtained with jax.numpy.sort(), which can be computed as:
    mask = x >= jax.numpy.flip(jax.numpy.sort(x))[k - 1]
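For reference, the hard mask can also be reproduced in plain NumPy along an arbitrary axis. This is an illustrative sketch independent of OTT: `hard_topk_mask` is a hypothetical helper whose `k` and `axis` arguments only mimic those of `topk_mask`. Its output is piecewise constant in `x`, so its gradient is zero almost everywhere, which is the limitation the soft mask is meant to address.

```python
import numpy as np


def hard_topk_mask(x: np.ndarray, k: int = 1, axis: int = -1) -> np.ndarray:
    """Hypothetical helper: boolean mask of the top-k entries of x along axis.

    Every entry equal to the k-th largest value is selected, so ties can
    produce more than k True entries.
    """
    # Sort ascending along `axis`; position -k holds the k-th largest value.
    # Taking with the list [-k] keeps the axis (length 1), so the threshold
    # broadcasts back against x.
    threshold = np.take(np.sort(x, axis=axis), [-k], axis=axis)
    return x >= threshold


x = np.array([0.3, 0.9, 0.1, 0.7, 0.5])
print(hard_topk_mask(x, k=2))  # [False  True False  True False]

# Row-wise top-2 on a 2D array (axis=-1 is the default).
X = np.array([[1.0, 4.0, 2.0, 3.0],
              [8.0, 5.0, 7.0, 6.0]])
print(hard_topk_mask(X, k=2).astype(int))
```

The same tie behavior applies to the one-line JAX expression above, since it also relies on a `>=` comparison against the \(k\)-th largest value.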
- Parameters:
  - inputs (Array) – Array of any shape.
  - axis (int) – The axis along which to apply the soft-sorting operator.
  - k (int) – The top-\(k\) parameter. Should be smaller than inputs.shape[axis].
  - kwargs (Any) – Keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (by default, a sigmoid of the whitened values) in order to solve the optimal transport problem; cost_fn, the cost function object of the PointCloud geometry, which defines the 1D ground cost used to transport inputs to the num_targets target values; and epsilon, the regularization parameter. The remaining kwargs are passed on to parameterize the Sinkhorn solver.
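To make the roles of squashing_fun, cost_fn and epsilon more concrete, here is a plain NumPy sketch of a soft top-\(k\) mask computed with entropic optimal transport on a 1D array. It does not call OTT and is not its implementation: the function name `soft_topk_mask_1d`, the two targets at 0 and 1 with masses \(1 - k/n\) and \(k/n\), the squared ground cost, and the fixed iteration budget are assumptions chosen for illustration.

```python
import numpy as np


def _logsumexp(m: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log(sum(exp(m))) along `axis`.
    m_max = np.max(m, axis=axis, keepdims=True)
    out = m_max + np.log(np.sum(np.exp(m - m_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def soft_topk_mask_1d(x, k, epsilon=1e-2, num_iters=1000, tol=1e-9):
    """Hypothetical NumPy sketch of an OT-based soft top-k mask (not OTT's code)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    # Squash: sigmoid of whitened values, so every input lies in (0, 1).
    z = (x - x.mean()) / (x.std() + 1e-12)
    s = 1.0 / (1.0 + np.exp(-z))
    # Assumed targets: 0 ("not selected") and 1 ("selected"), with masses
    # 1 - k/n and k/n; each input carries mass 1/n.
    y = np.array([0.0, 1.0])
    a = np.full(n, 1.0 / n)
    b = np.array([1.0 - k / n, k / n])
    cost = (s[:, None] - y[None, :]) ** 2  # squared 1D ground cost, shape (n, 2)
    # Log-domain Sinkhorn iterations on the dual potentials f and g.
    f, g = np.zeros(n), np.zeros(2)
    for _ in range(num_iters):
        f = epsilon * np.log(a) - epsilon * _logsumexp((g[None, :] - cost) / epsilon, axis=1)
        g = epsilon * np.log(b) - epsilon * _logsumexp((f[:, None] - cost) / epsilon, axis=0)
        plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
        # Column sums match b after the g-update; stop once rows also match a.
        if np.max(np.abs(plan.sum(axis=1) - a)) < tol:
            break
    # Share of each input's mass sent to the "selected" target. Rows sum to
    # about 1/n, so after scaling by n the values lie (approximately) in [0, 1].
    return n * plan[:, 1]


soft = soft_topk_mask_1d(np.array([0.1, 5.0, -3.0, 2.0, 0.5]), k=2)
print(np.round(soft, 3))  # [0. 1. 0. 1. 0.]
```

A smaller epsilon brings the output closer to the hard mask but makes Sinkhorn converge much more slowly, so with a fixed num_iters the plan may only be approximately converged; in topk_mask, such solver settings are what the remaining kwargs forwarded to the Sinkhorn solver are for.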
- Return type: Array
- Returns: The soft \(\text{top-}k\) selection mask.
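One way to see why the returned values lie in \([0,1]\) and add up to \(k\) is to write out the entropic transport plan in a simplified setting. Assume, for illustration only and not as a description of OTT's internals, that the squashed inputs \(s_i \in [0,1]\) each carry mass \(1/n\), that there are two targets \(y_0 = 0\) and \(y_1 = 1\) with masses \(1 - k/n\) and \(k/n\), and that the ground cost is \(c(s, y) = (s - y)^2\). Write the regularized plan as \(P_{ij} = \exp((f_i + g_j - c(s_i, y_j))/\varepsilon)\), whose rows sum to \(1/n\) at the optimum, and let \(\sigma\) denote the logistic sigmoid:

```latex
m_i \;=\; n\,P_{i1} \;=\; \frac{P_{i1}}{P_{i0} + P_{i1}}
    \;=\; \sigma\!\left(\frac{g_1 - g_0 + c(s_i, 0) - c(s_i, 1)}{\varepsilon}\right)
    \;=\; \sigma\!\left(\frac{g_1 - g_0 + 2 s_i - 1}{\varepsilon}\right),
\qquad
\sum_{i=1}^{n} m_i \;=\; n \sum_{i=1}^{n} P_{i1} \;=\; n \cdot \frac{k}{n} \;=\; k .
```

Under these assumptions the soft mask is a logistic function of the squashed scores with slope \(2/\varepsilon\), shifted by \(g_1 - g_0\) so that it sums to \(k\); as \(\varepsilon \to 0\), and for distinct inputs, it tends to the hard \(\text{top-}k\) mask.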