# ott.tools.soft_sort.ranks

ott.tools.soft_sort.ranks(inputs, axis=-1, num_targets=None, target_weights=None, **kwargs)

Apply the soft rank operator to an input tensor.

For instance, the snippet

```python
import jax
from ott.tools.soft_sort import ranks
x = jax.random.uniform(jax.random.PRNGKey(0), (100,))  # 100 values in [0, 1)
x_ranks = ranks(x)
```

outputs values that are differentiable approximations to the ranks of the entries in x. These should be compared to the non-differentiable hard ranks, namely the inverse of the permutation produced by jax.numpy.argsort() (0-based integer ranks), which can be obtained as:

```python
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x))
```
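The double-argsort construction of hard ranks can be checked on a small example. This sketch uses NumPy instead of jax.numpy (np.argsort behaves the same way for these inputs), the array values are made up, and the 2D part only illustrates what "the axis on which to apply the operator" means for hard ranks.

```python
import numpy as np

x = np.array([0.5, 0.1, 0.9, 0.3])
perm = np.argsort(x)            # indices that sort x: [1, 3, 0, 2]
hard_ranks = np.argsort(perm)   # inverse permutation, i.e. 0-based ranks: [2, 0, 3, 1]

# Indexing the ranks with the sorting permutation gives back 0, 1, ..., n - 1.
assert np.array_equal(hard_ranks[perm], np.arange(x.size))

# For a 2D array, `axis` selects which slices are ranked independently.
m = np.array([[3.0, 1.0, 2.0],
              [0.5, 0.9, 0.1]])
row_ranks = np.argsort(np.argsort(m, axis=-1), axis=-1)  # [[2, 0, 1], [1, 2, 0]]
col_ranks = np.argsort(np.argsort(m, axis=0), axis=0)    # [[1, 1, 1], [0, 0, 0]]
```

Following the description of axis, ranks(m, axis=0) should analogously return soft ranks within each column.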

Parameters:
• inputs (Array) – Array of any shape.

• axis (int) – The axis along which to apply the soft rank operator.

• target_weights (Optional[Array]) – Vector of weights (summing to 1) describing the amount of mass shipped to each target. When provided, it takes precedence over num_targets.

• num_targets (Optional[int]) – If target_weights is None, num_targets defines the number of targets used to rank inputs. Each rank in the output will be a convex combination of {1, ..., num_targets}, with uniform weight on each of these points. If neither num_targets nor target_weights is specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. inputs.shape[axis].

• kwargs (Any) – Keyword arguments passed on to lower-level functions. Of particular interest are squashing_fun, which maps the values in inputs into $$[0,1]$$ (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, the cost function object of the PointCloud geometry, which defines the ground 1D cost used to transport inputs to the num_targets target values; and epsilon, the entropic regularization parameter. Remaining kwargs are used to parameterize the Sinkhorn solver.
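The pipeline described by these parameters (squash the inputs into [0, 1], transport them onto target points under a 1D ground cost with entropic regularization epsilon, solve with Sinkhorn, then average target ranks under the transport plan) can be sketched in plain NumPy. This is a hypothetical re-implementation for illustration, not ott-jax's code: `soft_ranks_sketch`, `squash` and `_logsumexp` are made-up names; the evenly spaced targets in [0, 1], the squared cost, and the rescaling of target ranks to [0, n-1] are assumptions; and it handles 1D inputs only (no axis argument).

```python
import numpy as np


def _logsumexp(mat, axis):
    # Numerically stable log(sum(exp(mat))) along `axis`.
    mx = np.max(mat, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(mat - mx), axis=axis))


def squash(x):
    # Whiten to zero mean / unit variance, then apply a sigmoid -> values in (0, 1).
    z = (x - x.mean()) / (x.std() + 1e-12)
    return 1.0 / (1.0 + np.exp(-z))


def soft_ranks_sketch(x, num_targets=None, target_weights=None,
                      epsilon=1e-2, num_iters=1000):
    """Hypothetical soft ranks of a 1D array via entropic OT (not the ott-jax API)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    if target_weights is None:
        m = n if num_targets is None else num_targets
        b = np.full(m, 1.0 / m)  # uniform mass on each target
    else:
        b = np.asarray(target_weights, dtype=float)  # target_weights wins
        m = b.shape[0]
    a = np.full(n, 1.0 / n)  # uniform mass on each input

    # Assumption: targets are m evenly spaced points in [0, 1].
    y = np.linspace(0.0, 1.0, m)
    cost = (squash(x)[:, None] - y[None, :]) ** 2  # squared 1D ground cost

    # Log-domain Sinkhorn iterations on the dual potentials f and g.
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(num_iters):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)

    # Assumption: target j carries a rank value rescaled into [0, n - 1].
    target_ranks = np.arange(m) * (n - 1) / max(m - 1, 1)
    # Plan-weighted average of target ranks; row sums equal a at convergence,
    # and normalizing by them keeps each output an exact convex combination.
    return plan @ target_ranks / plan.sum(axis=1)
```

For example, soft_ranks_sketch(np.array([0.5, 0.1, 0.9, 0.3])) is close to the hard ranks [2, 0, 3, 1]. A larger epsilon blurs the values toward each other; in this sketch their order is still preserved, because scaling a Gaussian kernel keeps the plan's rows stochastically ordered.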

Return type:

Array

Returns:

An Array of the same shape as the input, with soft-rank values normalized to lie in $$[0, n-1]$$, where $$n$$ is inputs.shape[axis].
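The bounded output range follows from the convex-combination property stated for num_targets. As a hedged illustration (the notation is introduced here and is not taken from the ott-jax source), let $$P$$ be the regularized transport plan between the $$n$$ inputs, with weights $$a_i$$, and $$m$$ targets with weights $$b_j$$ and values $$t_j$$. Each soft rank is then a weighted average of target values:

$$
\tilde r_i \;=\; \sum_{j=1}^{m} \frac{P_{ij}}{a_i}\, t_j,
\qquad P_{ij} \ge 0,\quad \sum_{j=1}^{m} P_{ij} = a_i
\quad\Longrightarrow\quad
\min_j t_j \;\le\; \tilde r_i \;\le\; \max_j t_j .
$$

The weights $$P_{ij}/a_i$$ are nonnegative and sum to 1, so whenever the target values span $$[0, n-1]$$ every soft rank stays in that interval. For distinct inputs with $$m = n$$ and uniform weights, $$P$$ tends to a permutation matrix scaled by $$1/n$$ as $$\epsilon \to 0$$, so each $$\tilde r_i$$ tends to the target value matched to $$x_i$$ by sorting.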