ott.tools.soft_sort.ranks
- ott.tools.soft_sort.ranks(inputs, axis=-1, num_targets=None, target_weights=None, **kwargs)
Apply the soft-rank operator to an input tensor.
For instance:

```python
x = jax.random.uniform(rng, (100,))
x_ranks = ranks(x)
```

will output values that are differentiable approximations to the ranks of the entries in `x`. These should be compared to the non-differentiable rank vector, namely the normalized inverse permutation produced by `jax.numpy.argsort()`, which can be obtained as:

```python
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x))
```
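The double-`argsort` trick for hard ranks can be checked on a small concrete array; this is a minimal sketch using only `jax.numpy` (the array values are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([0.3, 0.1, 0.9, 0.5])
# argsort gives the sorting permutation; applying argsort again
# inverts that permutation, yielding each entry's rank.
hard_ranks = jnp.argsort(jnp.argsort(x))
# hard_ranks is [1, 0, 3, 2]: 0.1 has rank 0, 0.3 has rank 1, etc.
```

These integer ranks are piecewise constant in `x`, so their gradient is zero almost everywhere, which is what motivates the soft surrogate.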
- Parameters:
  - inputs (`Array`) – Array of any shape.
  - axis (`int`) – the axis on which to apply the soft-sorting operator.
  - target_weights (`Optional[Array]`) – vector of weights (summing to 1) that describe the amount of mass shipped to each target.
  - num_targets (`Optional[int]`) – if `target_weights` is `None`, `num_targets` defines the number of targets used to rank inputs. Each rank in the output will be a convex combination of `{1, ..., num_targets}`; the weight of each of these points is assumed to be uniform. If neither `num_targets` nor `target_weights` is specified, `num_targets` defaults to the size of the slices of the input that are sorted, i.e. `inputs.shape[axis]`.
  - kwargs (`Any`) – keyword arguments passed on to lower-level functions. Of interest to the user are: `squashing_fun`, which redistributes the values in `inputs` to lie in \([0, 1]\) (sigmoid of whitened values by default) before solving the optimal transport problem; `cost_fn`, an object of `PointCloud` that defines the ground 1D cost function used to transport from `inputs` to the `num_targets` target values; and `epsilon`, the regularization parameter. Remaining `kwargs` are passed on to parameterize the `Sinkhorn` solver.
- Return type:
  `Array`
- Returns:
  An array of the same shape as the input, with soft-rank values normalized to lie in \([0, n-1]\), where \(n\) is `inputs.shape[axis]`.
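To illustrate why a differentiable rank surrogate is useful, the following sketch implements a much simpler pairwise-sigmoid soft rank. This is not OTT's Sinkhorn-based operator, and the `eps` temperature here is an illustrative assumption unrelated to this function's `epsilon` keyword argument:

```python
import jax
import jax.numpy as jnp

def sigmoid_soft_ranks(x, eps=1e-2):
    # r_i approximates #{j : x_j < x_i} by summing smoothed
    # pairwise comparisons; subtracting 0.5 removes the
    # diagonal term sigmoid(0) = 0.5.
    diff = x[:, None] - x[None, :]
    return jax.nn.sigmoid(diff / eps).sum(axis=1) - 0.5

x = jnp.array([0.3, 0.1, 0.9, 0.5])
soft = sigmoid_soft_ranks(x)  # close to the hard ranks [1., 0., 3., 2.]

# Unlike argsort-of-argsort, this surrogate admits gradients:
grad = jax.grad(lambda v: sigmoid_soft_ranks(v)[0])(x)
```

As `eps` shrinks, the sigmoids approach step functions and the soft ranks approach the hard ranks, at the cost of increasingly flat (vanishing) gradients; the same accuracy/smoothness trade-off is controlled by `epsilon` in the transport-based operator above.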