ranks(inputs, axis=-1, num_targets=None, target_weights=None, **kwargs)

Apply the soft rank operator to an input tensor.

For instance:

import jax
from ott.tools.soft_sort import ranks
x = jax.random.uniform(jax.random.PRNGKey(0), (100,))
x_ranks = ranks(x)

outputs differentiable approximations to the ranks of the entries in x. These should be compared to the non-differentiable rank vector, i.e. the inverse of the sorting permutation produced by jax.numpy.argsort(), with integer values in {0, ..., n-1}. It can be obtained as:

x_ranks = jax.numpy.argsort(jax.numpy.argsort(x))
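A small worked example of this non-differentiable reference, using only jax.numpy: argsort(x) lists the indices that would sort x, and applying argsort a second time inverts that permutation, so each entry receives its 0-based position in sorted order.

```python
import jax.numpy as jnp

x = jnp.array([0.5, 0.9, 0.1])
order = jnp.argsort(x)           # indices that sort x: [2, 0, 1]
hard_ranks = jnp.argsort(order)  # inverse permutation: [1, 2, 0]
# 0.1 is the smallest (rank 0), 0.5 is next (rank 1), 0.9 is the largest (rank 2).
```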
Parameters:

  • inputs (Array) – Array of any shape.

  • axis (int) – The axis along which to apply the soft rank operator.

  • target_weights (Optional[Array]) – Vector of weights (summing to 1) giving the amount of mass shipped to each target.

  • num_targets (Optional[int]) – If target_weights is None, num_targets defines the number of targets used to rank inputs. Each rank in the output will be a convex combination of {1, ..., num_targets}, and the targets are assumed to carry uniform weight. If neither num_targets nor target_weights is specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. inputs.shape[axis].

  • kwargs (Any) – Keyword arguments passed on to lower-level functions. Of particular interest are: squashing_fun, which maps the values in inputs into \([0,1]\) (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, the cost function object of the PointCloud, which defines the ground 1D cost for transporting inputs to the num_targets target values; and epsilon, the regularization parameter. The remaining kwargs parameterize the Sinkhorn solver.
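To make the roles of squashing_fun, the ground cost, epsilon, num_targets and the Sinkhorn solver concrete, here is a minimal, self-contained sketch of a soft rank computation for a 1D array. It is not OTT's implementation. It assumes the default squashing described above (sigmoid of whitened values), a squared-Euclidean ground cost, m targets evenly spaced in [0, 1] with uniform weights, and a plain log-domain Sinkhorn loop; the function name soft_ranks_sketch is hypothetical. Target values 0, ..., m-1 are an assumption chosen so that outputs fall in [0, n-1] when m = n. Each soft rank is the average of those target values, weighted by the corresponding row of the transport plan.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_ranks_sketch(x, num_targets=None, epsilon=1e-2, num_iters=500):
    """Hypothetical soft ranks of a 1D array via entropic optimal transport."""
    n = x.shape[0]
    m = n if num_targets is None else num_targets
    # Squash inputs into [0, 1]: sigmoid of whitened values (assumed default).
    z = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    # Assumed target locations: m points evenly spaced in [0, 1].
    y = jnp.linspace(0.0, 1.0, m)
    # Uniform log-weights for the n inputs and the m targets.
    log_a = jnp.full((n,), -jnp.log(n))
    log_b = jnp.full((m,), -jnp.log(m))
    # Squared-Euclidean ground cost between squashed inputs and targets.
    cost = (z[:, None] - y[None, :]) ** 2

    def body(_, fg):
        f, g = fg
        # Log-domain Sinkhorn: match column marginals b, then row marginals a.
        g = epsilon * log_b - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
        f = epsilon * log_a - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(m)))
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Each row of the plan has mass 1/n, so rescaling by n gives a convex
    # combination of the target values 0, ..., m-1.
    return n * plan @ jnp.arange(m, dtype=x.dtype)
```

For x = jnp.array([0.5, 0.9, 0.1]), a small epsilon gives values close to the hard ranks [1, 2, 0]. Increasing epsilon blurs the transport plan and pulls every soft rank toward the middle value (m - 1) / 2, which is the usual trade-off between smoothness and fidelity.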

Returns:

An Array of the same shape as the input, with soft-rank values normalized to lie in \([0, n-1]\), where \(n\) is inputs.shape[axis].

Return type:

Array
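The shape and range behaviour can be illustrated with the hard-rank reference on a 2D array. This sketch uses only jax.numpy; hard_ranks is a hypothetical helper, and for small epsilon the soft ranks are expected to approximate its output entry by entry. Ranking along the last axis gives values in [0, 2] (three entries per row), while ranking along axis 0 gives values in [0, 1] (two entries per column); in both cases the output keeps the input's shape.

```python
import jax.numpy as jnp


def hard_ranks(inputs, axis=-1):
    """Hypothetical helper: non-differentiable ranks along `axis`."""
    # argsort gives the sorting permutation; a second argsort inverts it.
    return jnp.argsort(jnp.argsort(inputs, axis=axis), axis=axis)


x = jnp.array([[0.5, 0.9, 0.1],
               [0.2, 0.4, 0.3]])

r_last = hard_ranks(x, axis=-1)  # ranks within each row: [[1, 2, 0], [0, 2, 1]]
r_first = hard_ranks(x, axis=0)  # ranks within each column: [[1, 1, 0], [0, 0, 1]]
```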