ott.tools.soft_sort.sort
- ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)
Apply the soft sort operator on a given axis of the input.
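One way to picture the operator is as entropic optimal transport from the squashed values of a 1-D slice onto a set of increasing target values. Each output is then the convex combination of the inputs transported to one target. The sketch below illustrates that idea for a single 1-D slice in JAX. It rests on our own assumptions: the hypothetical name `soft_sort_1d`, targets evenly spaced in [0, 1], and a plain log-domain Sinkhorn loop with a fixed iteration count. It is not ott's implementation, and it ignores `axis` handling and `topk`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_sort_1d(x, num_targets=None, epsilon=1e-2, num_iters=1000):
  """Hypothetical sketch: soft-sorts a 1-D float array via entropic OT."""
  n = x.shape[0]
  m = n if num_targets is None else num_targets
  # Squash inputs to [0, 1] with a sigmoid of the whitened values.
  z = jax.nn.sigmoid((x - jnp.mean(x)) / jnp.std(x))
  y = jnp.linspace(0.0, 1.0, m)            # assumed: evenly spaced, increasing targets
  a = jnp.full((n,), 1.0 / n)              # uniform mass on inputs
  b = jnp.full((m,), 1.0 / m)              # uniform mass on targets
  cost = (z[:, None] - y[None, :]) ** 2    # squared Euclidean ground cost

  def sinkhorn_step(_, fg):
    f, g = fg
    # Alternately rescale the dual potentials so rows sum to a, then columns to b.
    f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
    g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
    return f, g

  f, g = jax.lax.fori_loop(0, num_iters, sinkhorn_step, (jnp.zeros(n), jnp.zeros(m)))
  plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
  # Column j carries mass b_j, so dividing by b gives a convex combination of inputs.
  return (plan.T @ x) / b
```

With a small `epsilon`, the outputs approach `jnp.sort(x)`. Larger values blur neighbouring entries together, which is what makes the operator differentiable.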
- Parameters:
  - inputs (Array) – jnp.ndarray<float> of any shape.
  - axis (int) – the axis along which to apply the operator.
  - topk (int) – if set to a positive value, the returned vector contains only the top-k (largest) values. This also reduces the complexity of soft sorting.
  - num_targets (Optional[int]) – if topk is not specified, num_targets defines the number of (composite) sorted values computed from the inputs. Each value is a convex combination of the input values, and the values are provided in increasing order. If not specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. the number of composite sorted values equals the number of inputs being sorted.
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which redistributes the values in inputs to lie in [0, 1] (a sigmoid of the whitened values by default) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for details); and epsilon, as well as other parameters that shape the sinkhorn algorithm.
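- Returns:
  A jnp.ndarray with the soft sorted values along the given axis. By default it has the same shape as the input; when topk or num_targets is set, its size along axis is topk or num_targets, respectively.
- Return type:
  Array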
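The effect of `topk` and `num_targets` on cost and output size can be illustrated through the target masses of the transport problem. The following is a minimal sketch under our own assumptions; the helper `sort_target_weights` is hypothetical and is not part of ott. It uses one low "sink" target to absorb the `n - topk` smallest inputs and gives each of the `topk` remaining targets one input's worth of mass. Under that construction the problem has only `topk + 1` targets instead of `n`, and dropping the sink leaves `topk` output values.

```python
import jax.numpy as jnp


def sort_target_weights(n, topk=-1, num_targets=None):
  """Hypothetical helper: target masses for soft-sorting n inputs.

  Returns (b, start): b holds the target masses (summing to 1), and
  start is the number of leading targets dropped from the output.
  """
  if 0 < topk < n:
    # Sink target absorbs the n - topk smallest inputs; each of the
    # topk remaining targets receives exactly one input's mass, 1 / n.
    b = jnp.concatenate([jnp.array([(n - topk) / n]), jnp.full((topk,), 1.0 / n)])
    return b, 1
  # Otherwise: num_targets (default n) targets with uniform mass.
  m = n if num_targets is None else num_targets
  return jnp.full((m,), 1.0 / m), 0
```

For example, `sort_target_weights(1000, topk=5)` yields 6 targets rather than 1000, so the transport plan shrinks from 1000 x 1000 to 1000 x 6, and `b.shape[0] - start` gives the 5 returned values.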