ott.tools.soft_sort.ranks
- ott.tools.soft_sort.ranks(inputs, axis=-1, num_targets=None, **kwargs)
Apply the soft rank operator to an input tensor.
- Parameters:
  - inputs (Array) – a jnp.ndarray<float> of any shape.
  - axis (int) – the axis along which to apply the soft rank operator.
  - num_targets (Optional[int]) – the number of targets used to compute a composite rank for each value in inputs: that soft rank is a convex combination of the values in [0, …, (num_targets-2)/num_targets, 1], with weights given by the optimal transport between the values in inputs and those target values. If not specified, num_targets defaults to the length of each slice of inputs that is sorted (i.e. its size along axis).
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0, 1] (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for details); and epsilon, along with other parameters that shape the sinkhorn algorithm.
- Returns:
  A jnp.ndarray<float> of the same shape as inputs, containing the soft ranks.
- Return type:
  Array
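The mechanism described under num_targets and kwargs (squash the inputs into [0, 1], transport them to a grid of target values under a squared Euclidean cost, then average the targets with the transport weights) can be illustrated with a small NumPy sketch. It is not OTT's implementation: the names soft_ranks_1d and soft_ranks, the fixed Sinkhorn iteration count, and the evenly spaced target grid np.linspace(0, 1, num_targets) are assumptions made for illustration, and the same steps could be written with jax.numpy.

```python
import numpy as np


def _logsumexp(x, axis):
    # Numerically stable log-sum-exp along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def soft_ranks_1d(x, num_targets=None, epsilon=1e-2, num_iters=1000):
    """Hypothetical helper: soft ranks of a 1-D array, as values in [0, 1]."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    m = n if num_targets is None else num_targets
    # Squashing step: sigmoid of the whitened values, so every input lies in [0, 1].
    z = (x - x.mean()) / (x.std() + 1e-10)
    s = 1.0 / (1.0 + np.exp(-z))
    # Assumed target grid: m evenly spaced values from 0 to 1, uniform weights.
    t = np.linspace(0.0, 1.0, m)
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    # Squared Euclidean ground cost between squashed inputs and targets.
    cost = (s[:, None] - t[None, :]) ** 2
    # Log-domain Sinkhorn updates of the dual potentials f (inputs) and g (targets).
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(num_iters):
        f = epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Normalize each row of the plan into convex weights and average the targets.
    return (plan @ t) / plan.sum(axis=1)


def soft_ranks(inputs, axis=-1, num_targets=None, epsilon=1e-2):
    # Hypothetical wrapper: rank every 1-D slice along `axis`,
    # so the output has the same shape as `inputs`.
    return np.apply_along_axis(
        soft_ranks_1d, axis, np.asarray(inputs, dtype=float),
        num_targets=num_targets, epsilon=epsilon,
    )


x = np.array([0.3, -1.2, 2.5, 0.7])
print(soft_ranks(x))               # close to the hard ranks [1, 0, 3, 2] / 3
print(soft_ranks(x, epsilon=1.0))  # smoother: values pulled toward 0.5
```

In this sketch a small epsilon makes each row of the transport plan nearly one-hot, so the output approaches the hard ranks rescaled to [0, 1]; a larger epsilon spreads the weights over several targets and pulls every rank toward the middle of the grid, which is what makes the operator differentiable.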