, axis=-1, topk=-1, num_targets=None, **kwargs)

Apply the soft-sort operator along a given axis of the input.

For instance:

import jax
rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))
x_sorted = sort(x)

returns sorted convex combinations of the values in x, which are differentiable approximations of the sorted entries of x. These can be compared with the values produced by jax.numpy.sort():

x_sorted = jax.numpy.sort(x)
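To make the idea concrete, here is a minimal sketch in plain JAX of soft sorting as entropic optimal transport. It follows the mechanism described under kwargs (squash the inputs into [0, 1] with a sigmoid of the whitened values, then transport them onto sorted targets under an epsilon regularization), but the rest is assumed for illustration: n evenly spaced targets in [0, 1] with uniform weights, a squared 1D ground cost, a fixed number of log-domain Sinkhorn iterations, and epsilon=1e-2. The name soft_sort_sketch is hypothetical; this is not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_sort_sketch(x, epsilon=1e-2, num_iters=1000):
    """Hypothetical soft sort of a 1D array via entropic optimal transport.

    Illustrative only: the target grid, cost, and solver are assumptions.
    """
    n = x.shape[0]
    # Squash inputs into (0, 1) with a sigmoid of the whitened values.
    z = jax.nn.sigmoid((x - x.mean()) / x.std())
    # n target points evenly spaced in [0, 1], each carrying mass 1 / n.
    t = jnp.linspace(0.0, 1.0, n)
    a = jnp.full(n, 1.0 / n)
    b = jnp.full(n, 1.0 / n)
    cost = (z[:, None] - t[None, :]) ** 2  # squared 1D ground cost

    # Log-domain Sinkhorn updates of the dual potentials f (inputs) and g (targets).
    def body(_, fg):
        f, g = fg
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(n)))
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # After the last g update, column j sums to b_j, so each output is a
    # convex combination of the entries of x.
    return plan.T @ x / b


x = jnp.array([3.0, 1.0, 2.0, 5.0, 4.0])
x_soft = soft_sort_sketch(x)   # soft-sorted values
x_hard = jnp.sort(x)           # exact sort, for comparison
# jax.grad differentiates through the whole Sinkhorn loop.
grads = jax.grad(lambda v: soft_sort_sketch(v)[-1])(x)
print(x_soft, x_hard, grads)
```

Smaller epsilon values bring the output closer to jax.numpy.sort() but generally need more Sinkhorn iterations to converge.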

Parameters:

  • inputs (Array) – Array of any shape.

  • axis (int) – the axis on which to apply the soft-sorting operator.

  • topk (int) – if set to a positive value, the returned vector contains only the top-k values. This also reduces the cost of soft sorting, since each slice of inputs is then mapped onto only topk + 1 target points.

  • num_targets (Optional[int]) – if topk is not specified, a vector of size num_targets is returned. This defines the number of (composite) sorted values computed from the inputs; each value is a convex combination of values in inputs, and the values are returned in increasing order. If neither topk nor num_targets is specified, num_targets defaults to the size of the sorted slices of the input, i.e. inputs.shape[axis], so the output has the same size as inputs.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which maps the values in inputs into \([0,1]\) (by default, a sigmoid of the whitened values) before solving the optimal transport problem; cost_fn, the cost-function object of the PointCloud, which defines the ground 1D cost used to transport inputs onto the num_targets target values; and epsilon, the entropic regularization parameter. The remaining kwargs are passed on to parameterize the Sinkhorn solver.
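The following sketch illustrates how topk and num_targets could change the transport problem, reusing the same hypothetical Sinkhorn construction (sigmoid squashing, squared cost, fixed iteration count). For topk, it assumes a layout of topk + 1 targets in which the first target absorbs the mass of the n - topk smallest inputs and each remaining target receives the mass of exactly one input, which is consistent with the topk + 1 target count stated above. For num_targets, it assumes uniformly weighted targets, so each output averages roughly n / num_targets inputs. The names soft_sort_1d and _sinkhorn_plan are hypothetical, and the library's actual target placement may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _sinkhorn_plan(a, b, cost, epsilon, num_iters):
    """Log-domain Sinkhorn iterations; returns the entropic transport plan."""

    def body(_, fg):
        f, g = fg
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def soft_sort_1d(x, topk=-1, num_targets=None, epsilon=1e-2, num_iters=1000):
    """Hypothetical 1D soft sort showing how topk / num_targets shape the targets."""
    n = x.shape[0]
    z = jax.nn.sigmoid((x - x.mean()) / x.std())  # squash inputs into (0, 1)
    a = jnp.full(n, 1.0 / n)  # each input carries mass 1 / n
    if topk > 0:
        # topk + 1 targets: the first absorbs the n - topk smallest inputs,
        # each of the other topk targets receives exactly one input's mass.
        t = jnp.linspace(0.0, 1.0, topk + 1)
        b = jnp.concatenate([jnp.array([(n - topk) / n]), jnp.full(topk, 1.0 / n)])
    else:
        m = n if num_targets is None else num_targets
        t = jnp.linspace(0.0, 1.0, m)
        b = jnp.full(m, 1.0 / m)  # each output averages about n / m inputs
    cost = (z[:, None] - t[None, :]) ** 2
    plan = _sinkhorn_plan(a, b, cost, epsilon, num_iters)
    out = plan.T @ x / b  # one convex combination of x per target
    return out[1:] if topk > 0 else out  # drop the absorbing target for topk


x = jnp.array([3.0, 1.0, 2.0, 5.0, 4.0])
print(soft_sort_1d(x, topk=2))         # approximately the two largest values, ascending
print(soft_sort_1d(x, num_targets=3))  # three composite sorted values
```

With num_targets=1 the single target receives all the mass, so the output collapses to the mean of the slice; this is a quick sanity check of the convex-combination view.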

Returns:

An Array with the same shape as inputs, except along axis, whose size equals topk (if topk is positive) or num_targets, containing the soft-sorted values along that axis. If topk is not positive and num_targets is None, the output has the same shape as inputs.
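The shape rule for the returned array can be illustrated with a sketch that applies a 1D soft sort to every slice along axis: move that axis last, flatten the leading dimensions, map over slices with jax.vmap, and move the axis back. The 1D routine repeats the hypothetical Sinkhorn construction (sigmoid squashing, evenly spaced uniformly weighted targets, squared cost); _soft_sort_1d and soft_sort_along_axis are hypothetical names, and topk is omitted to keep the example short.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _soft_sort_1d(x, num_targets, epsilon=1e-2, num_iters=1000):
    """Soft-sort one 1D slice onto `num_targets` uniformly weighted targets (sketch)."""
    n = x.shape[0]
    z = jax.nn.sigmoid((x - x.mean()) / x.std())  # squash into (0, 1)
    t = jnp.linspace(0.0, 1.0, num_targets)
    a = jnp.full(n, 1.0 / n)
    b = jnp.full(num_targets, 1.0 / num_targets)
    cost = (z[:, None] - t[None, :]) ** 2

    def body(_, fg):
        f, g = fg
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(num_targets)))
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    return plan.T @ x / b  # convex combinations of the slice


def soft_sort_along_axis(inputs, axis=-1, num_targets=None):
    """Apply the 1D sketch to every slice of `inputs` along `axis`."""
    moved = jnp.moveaxis(inputs, axis, -1)  # put the sorted axis last
    m = moved.shape[-1] if num_targets is None else num_targets
    flat = moved.reshape(-1, moved.shape[-1])  # (num_slices, slice_size)
    out = jax.vmap(functools.partial(_soft_sort_1d, num_targets=m))(flat)
    out = out.reshape(moved.shape[:-1] + (m,))
    return jnp.moveaxis(out, -1, axis)  # restore the original axis position


xs = jax.random.uniform(jax.random.PRNGKey(0), (4, 6))
print(soft_sort_along_axis(xs).shape)                          # (4, 6)
print(soft_sort_along_axis(xs, axis=0, num_targets=2).shape)   # (2, 6)
```

Only the size along axis changes; all other dimensions are carried through unchanged.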