ott.tools.soft_sort.topk_mask
- ott.tools.soft_sort.topk_mask(inputs, axis=-1, k=1, **kwargs)
Soft \(\text{top-}k\) selection mask.
For instance:
    k = 5
    x = jax.random.uniform(rng, (100,))
    mask = topk_mask(x, k=k)
will output a vector of shape x.shape, with values in \([0,1]\), that are differentiable approximations to the binary mask selecting the top \(k\) entries in x. These should be compared to the non-differentiable mask obtained with jax.numpy.sort(), which can be computed as:
    mask = x >= jax.numpy.flip(jax.numpy.sort(x))[k-1]
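The non-differentiable reference mask described above can be sketched with plain JAX, without calling into OTT. The helper name `hard_topk_mask` and the seed are illustrative assumptions, and the sketch only handles 1D inputs:

```python
import jax
import jax.numpy as jnp


def hard_topk_mask(x, k):
    """Hypothetical helper: boolean mask of the k largest entries of a 1D array."""
    # The k-th largest value is the k-th entry from the end of the ascending sort.
    threshold = jnp.sort(x)[-k]
    # Entries at or above the threshold are selected; ties can select more than k.
    return x >= threshold


rng = jax.random.PRNGKey(0)  # assumed seed, any key works
x = jax.random.uniform(rng, (100,))
k = 5

hard = hard_topk_mask(x, k)
num_selected = int(hard.sum())  # at least k entries are selected
```

The soft mask returned by `topk_mask(x, k=k)` has the same shape as `hard` and can be compared to it entry by entry; unlike `hard`, it admits gradients with respect to `x`.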
- Parameters:
  - inputs (Array) – Array of any shape.
  - axis (int) – The axis on which to apply the soft-sorting operator.
  - k (int) – Top-k parameter. Should be smaller than inputs.shape[axis].
  - kwargs (Any) – Keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which will redistribute the values in inputs to lie in \([0,1]\) (sigmoid of whitened values by default) to solve the optimal transport problem; the cost_fn object of PointCloud, which defines the ground 1D cost function to transport from inputs to the num_targets target values; and the epsilon regularization parameter. Remaining kwargs are passed on to parameterize the Sinkhorn solver.
- Return type:
- Returns:
The soft \(\text{top-}k\) selection mask.
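The kwargs description mentions that inputs are, by default, squashed into \([0,1]\) by a sigmoid of whitened values. A minimal sketch of what such a squashing step could look like is given below, assuming that whitening means subtracting the mean and dividing by the standard deviation along the chosen axis. The name `squash_sketch` and the small stabilizing constant are hypothetical; this is not OTT's implementation.

```python
import jax
import jax.numpy as jnp


def squash_sketch(x, axis=-1):
    """Hypothetical squashing function: sigmoid of whitened values along `axis`."""
    mean = jnp.mean(x, axis=axis, keepdims=True)
    std = jnp.std(x, axis=axis, keepdims=True)
    # Whiten, then map to (0, 1); the constant guards against a zero std (assumption).
    return jax.nn.sigmoid((x - mean) / (std + 1e-8))


values = jnp.array([0.1, 0.5, 0.9, 0.3])
squashed = squash_sketch(values)
```

Because the sigmoid is monotone, the ranking of the entries is unchanged by this step, so the top \(k\) entries of `squashed` are the top \(k\) entries of `values`. A custom function supplied through the squashing_fun keyword would presumably need to preserve ordering in the same way; the expected call signature should be checked against the library itself.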