ott.tools.soft_sort.ranks
- ott.tools.soft_sort.ranks(inputs, axis=-1, num_targets=None, target_weights=None, **kwargs)
Apply the soft rank operator to an input tensor.
For instance:
x = jax.random.uniform(rng, (100,))
x_ranks = ranks(x)
will output values that are differentiable approximations to the ranks of the entries in
x. These should be compared to the non-differentiable rank vectors, namely the inverse of the permutation produced by jax.numpy.argsort(), which can be obtained as:
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x))
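A quick look at the non-differentiable reference, using only jax.numpy and no OTT call. The PRNG seed below is an arbitrary choice standing in for the undefined rng of the example; the small hand-checkable array is likewise illustrative.

```python
import jax
import jax.numpy as jnp

# Same setup as the example: 100 uniform samples (seed 0 is an arbitrary choice).
rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))

# Hard (exact) ranks: argsort of argsort is the inverse permutation,
# so entry i holds the 0-based position of x[i] in sorted order.
hard_ranks = jnp.argsort(jnp.argsort(x))

# A tiny case that is easy to check by hand.
small = jnp.array([0.3, 0.1, 0.2])
small_ranks = jnp.argsort(jnp.argsort(small))  # [2, 0, 1]

# Nudging the inputs without changing their order leaves the ranks unchanged:
# hard ranks are piecewise constant in x, so they carry no useful gradient.
perturbed_ranks = jnp.argsort(jnp.argsort(small + jnp.array([0.01, -0.01, 0.02])))
```

This piecewise-constant behaviour is what the soft rank operator smooths out.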
- Parameters:
  - inputs (Array) – Array of any shape.
  - axis (int) – The axis along which to apply the soft rank operator.
  - num_targets (Optional[int]) – If target_weights is None, num_targets defines the number of targets used to rank inputs. Each rank in the output is then a convex combination of {0, ..., num_targets - 1}, and the targets are assumed to carry uniform weight. If neither num_targets nor target_weights is specified, num_targets defaults to the size of the slices of the input that are ranked, i.e. inputs.shape[axis].
  - target_weights (Optional[Array]) – Weights (summing to 1) that describe the amount of mass shipped to each target.
  - kwargs (Any) – Keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, the cost function object of the PointCloud geometry, which defines the 1D ground cost used to transport inputs to the num_targets target values; and epsilon, the regularization parameter. Remaining kwargs are used to parameterize the Sinkhorn solver.
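To make the roles of squashing_fun, cost_fn, epsilon and num_targets more concrete, here is a minimal pure-JAX sketch of soft ranks computed with entropic optimal transport and log-domain Sinkhorn iterations. It is not OTT's implementation and handles only 1D inputs. The function name soft_ranks_sketch, the fixed iteration count num_iters, the uniform target weights, the targets evenly spaced in [0, 1], the squared 1D ground cost and the 1e-10 stabilizer in the whitening step are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_ranks_sketch(x, num_targets=None, epsilon=1e-2, num_iters=2000):
    """Hypothetical soft-rank sketch for a 1D array; not OTT's implementation."""
    n = x.shape[0]
    m = n if num_targets is None else num_targets
    # Squashing: sigmoid of whitened inputs, so every value lies in [0, 1].
    z = jax.nn.sigmoid((x - x.mean()) / (x.std() + 1e-10))
    # Uniform weights on the n inputs and on the m targets (assumption).
    log_a = jnp.full((n,), -jnp.log(n))
    log_b = jnp.full((m,), -jnp.log(m))
    # Targets evenly spaced in [0, 1]; target j stands for rank value j.
    y = jnp.linspace(0.0, 1.0, m)
    cost = (z[:, None] - y[None, :]) ** 2  # squared 1D ground cost, shape (n, m)

    def sinkhorn_step(_, potentials):
        f, g = potentials
        # Update f so the rows of the plan sum to a, then g so the columns sum to b.
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, sinkhorn_step, (jnp.zeros(n), jnp.zeros(m)))
    # Plan P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / epsilon).
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon + log_a[:, None] + log_b[None, :])
    # Soft rank of x_i: (1 / a_i) * sum_j P_ij * j, i.e. n times row i applied to 0..m-1.
    return n * (plan @ jnp.arange(m, dtype=plan.dtype))


x = jnp.array([0.3, -1.2, 2.0, 0.7, -0.1])
soft = soft_ranks_sketch(x)  # close to the hard ranks [2, 0, 4, 3, 1]
coarse = soft_ranks_sketch(x, num_targets=3)  # values in [0, 2]
```

Because the last Sinkhorn update matches the target marginal exactly, the outputs sum to n(m - 1)/2 for n inputs and m targets. For well-separated inputs and a small epsilon they sit close to the hard ranks; increasing epsilon blurs them toward the mean rank (m - 1)/2.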
- Return type:
Array
- Returns:
An Array of the same shape as the input, with soft-rank values in \([0, n-1]\), where \(n\) is the number of targets (inputs.shape[axis] by default).
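The shape contract above (output shape equals input shape, values bounded by the slice length along axis) can be illustrated with exact ranks. This is a minimal sketch assuming a hypothetical helper rank_along_axis that lifts any 1D ranking function to an arbitrary axis with jax.vmap; it does not claim to mirror how OTT handles axis internally, and it uses hard double-argsort ranks so that it runs with jax alone.

```python
import jax
import jax.numpy as jnp


def hard_ranks_1d(v):
    # Exact (non-differentiable) 0-based ranks of a 1D vector.
    return jnp.argsort(jnp.argsort(v))


def rank_along_axis(fn_1d, x, axis=-1):
    """Hypothetical helper: apply a 1D ranking function to every slice along `axis`."""
    moved = jnp.moveaxis(x, axis, -1)  # bring the ranked axis last
    flat = moved.reshape(-1, moved.shape[-1])  # one row per 1D slice
    out = jax.vmap(fn_1d)(flat)  # rank each slice independently
    return jnp.moveaxis(out.reshape(moved.shape), -1, axis)  # restore the original layout


x = jax.random.uniform(jax.random.PRNGKey(0), (3, 4, 5))
r = rank_along_axis(hard_ranks_1d, x, axis=1)  # ranks along the length-4 axis
```

Moving the ranked axis last and flattening the others is one common way to handle "array of any shape"; a 1D soft rank function could be dropped in for hard_ranks_1d without changing the helper.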