, axis=-1, k=1, **kwargs)

Soft \(\text{top-}k\) selection mask.

For instance:

import jax
from ott.tools.soft_sort import topk_mask
rng = jax.random.PRNGKey(0)
k = 5
x = jax.random.uniform(rng, (100,))
mask = topk_mask(x, k=k)

will output an array of shape x.shape with values in \([0,1]\) that form a differentiable approximation of the binary mask selecting the top \(k\) entries of x. Compare this with the non-differentiable mask built with jax.numpy.sort():

mask = x >= jax.numpy.flip(jax.numpy.sort(x))[k - 1]
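The sort-based snippet above only handles 1-D inputs. A minimal sketch of the same non-differentiable reference mask along an arbitrary axis, mirroring the axis parameter of topk_mask, could look as follows. hard_topk_mask is a hypothetical helper written for illustration, not part of OTT; as with the snippet, ties at the \(k\)-th largest value can select more than \(k\) entries.

```python
import jax.numpy as jnp


def hard_topk_mask(x, k, axis=-1):
    """Hypothetical hard (non-differentiable) top-k mask along `axis`."""
    # Move the selection axis last so we can index it with `...`.
    x_last = jnp.moveaxis(x, axis, -1)
    # k-th largest value along the last axis: ascending sort, index -k.
    kth_largest = jnp.sort(x_last, axis=-1)[..., -k]
    # Entries at or above the threshold are selected (ties may add extras).
    mask = x_last >= kth_largest[..., None]
    # Restore the original axis order and return a float mask like the soft one.
    return jnp.moveaxis(mask, -1, axis).astype(x.dtype)
```

Returning a float array makes it easy to compare against the soft mask, for example with jnp.abs(soft - hard).max().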

Parameters:

  • inputs (Array) – Array of any shape.

  • axis (int) – The axis along which to apply the soft-sorting operator.

  • k (int) – Number of entries to select along axis (the \(k\) in \(\text{top-}k\)). Should be smaller than inputs.shape[axis].

  • kwargs (Any) – Keyword arguments passed on to lower-level functions. Of particular interest are squashing_fun, which maps the values in inputs into \([0,1]\) (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, the cost function object passed to PointCloud, which defines the 1D ground cost used to transport inputs onto the num_targets target values; and epsilon, the entropic regularization parameter. Any remaining kwargs parameterize the Sinkhorn solver.
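To make the keyword arguments above more concrete, here is a from-scratch sketch of the underlying idea for a 1-D input. It is our own simplified reconstruction under stated assumptions, not OTT's code: the inputs are squashed with a sigmoid of whitened values, transported onto two hypothetical target values 0 ("not selected") and 1 ("selected") with weights \(1-k/n\) and \(k/n\) under a squared-Euclidean cost, and the coupling is computed with a fixed number of log-domain Sinkhorn iterations. The names soft_topk_mask_sketch and num_iters are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_topk_mask_sketch(x, k, epsilon=1e-2, num_iters=500):
    """Hypothetical soft top-k mask for a 1-D array via entropic OT (sketch only)."""
    n = x.shape[0]
    # Squash inputs into (0, 1): sigmoid of whitened values.
    z = jax.nn.sigmoid((x - x.mean()) / (x.std() + 1e-10))
    # Assumed targets: 0 ("not selected") and 1 ("selected").
    y = jnp.array([0.0, 1.0])
    a = jnp.full(n, 1.0 / n)               # uniform weights on the n inputs
    b = jnp.array([1.0 - k / n, k / n])    # k/n of the mass must reach target 1
    cost = (z[:, None] - y[None, :]) ** 2  # squared-Euclidean 1-D ground cost

    def sinkhorn_step(_, fg):
        f, g = fg
        # Log-domain updates: match row marginals a, then column marginals b.
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, sinkhorn_step, (jnp.zeros(n), jnp.zeros(2)))
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Mass each input sends to target 1, rescaled from about [0, 1/n] to [0, 1].
    return n * plan[:, 1]
```

Because the last update matches the column marginals, the returned mask sums to \(k\) up to floating-point error. Lowering epsilon sharpens the mask toward the hard top-k mask but slows Sinkhorn convergence, so num_iters may need to grow with it.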

Returns:

The soft \(\text{top-}k\) selection mask, with the same shape as inputs and values in \([0,1]\).