ott.tools.soft_sort.sort
- ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)
Apply the soft sort operator on a given axis of the input.
For instance:
    x = jax.random.uniform(rng, (100,))
    x_sorted = sort(x)
will output sorted convex combinations of the values contained in x, which are differentiable approximations to the sorted vector of entries in x. These can be compared with the values produced by jax.numpy.sort():
    x_sorted = jax.numpy.sort(x)
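To make "sorted convex combinations" concrete, here is a minimal sketch in plain JAX. It is not OTT's implementation: it assumes evenly spaced targets in [0, 1], uniform weights on inputs and targets, a squared-Euclidean cost, and a fixed number of log-domain Sinkhorn iterations. The function name soft_sort_sketch and its num_iters argument are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_sort_sketch(x, num_targets=None, epsilon=1e-2, num_iters=200):
    """Illustrative 1D soft sort via entropic optimal transport (not OTT's code)."""
    n = x.shape[0]
    m = n if num_targets is None else num_targets

    # Squash inputs into (0, 1): sigmoid of whitened values (assumed default).
    z = jax.nn.sigmoid((x - x.mean()) / (x.std() + 1e-8))
    # Assumed targets: m increasing points, evenly spaced in [0, 1].
    y = jnp.linspace(0.0, 1.0, m)
    cost = (z[:, None] - y[None, :]) ** 2

    # Uniform marginals, in log space.
    log_a = jnp.full((n,), -jnp.log(n))
    log_b = jnp.full((m,), -jnp.log(m))

    def body(_, potentials):
        f, g = potentials
        # Alternately match the row marginals, then the column marginals.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(m)))

    # Transport plan; column j holds the weights that target j puts on the inputs.
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Normalise each column so every output is a convex combination of x.
    return (plan.T @ x) / plan.sum(axis=0)


x = jax.random.uniform(jax.random.PRNGKey(0), (20,))
x_soft = soft_sort_sketch(x)                    # same size as x
x_soft_5 = soft_sort_sketch(x, num_targets=5)   # 5 composite sorted values
x_hard = jnp.sort(x)                            # hard sort, for comparison
```

In this sketch, smaller epsilon values typically bring x_soft closer to x_hard, at the price of needing more Sinkhorn iterations to converge.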
- Parameters:
inputs (Array) – Array of any shape.
axis (int) – The axis on which to apply the soft-sorting operator.
topk (int) – If set to a positive value, the returned vector will only contain the top-k values. This also reduces the complexity of soft-sorting, since the number of target points to which the slice of the inputs tensor is mapped will be equal to topk + 1.
num_targets (Optional[int]) – If topk is not specified, a vector of size num_targets is returned. This defines the number of (composite) sorted values computed from the inputs (each value is a convex combination of values recorded in the inputs, provided in increasing order). If neither topk nor num_targets is specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. inputs.shape[axis], and the number of composite sorted values equals the size of the sorted slices. As a result, the output is of the same size as inputs.
kwargs (Any) – Keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (sigmoid of whitened values by default) before solving the optimal transport problem; cost_fn, the cost function object of PointCloud, which defines the ground 1D cost function used to transport from inputs to the num_targets target values; and epsilon, the regularization parameter. Remaining kwargs are passed on to parameterize the Sinkhorn solver.
- Return type:
- Returns:
An Array of the same shape as the input, except on axis, where the size will be equal to topk or num_targets, with soft-sorted values on the given axis. Same size as inputs if neither of these parameters is specified.
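The shape rule described under Returns can be written out as a small helper. This is only a sketch: expected_output_shape is a hypothetical function, not part of OTT, and it assumes that a positive topk takes precedence over num_targets when both are given.

```python
def expected_output_shape(input_shape, axis=-1, topk=-1, num_targets=None):
    """Shape the documentation describes for the soft-sorted output (hypothetical helper)."""
    if topk > 0:
        # A positive topk keeps only the top-k values along `axis`.
        size_on_axis = topk
    elif num_targets is not None:
        # Otherwise num_targets composite sorted values are returned.
        size_on_axis = num_targets
    else:
        # Neither is set: the output has the same size as the input.
        size_on_axis = input_shape[axis]
    shape = list(input_shape)
    shape[axis] = size_on_axis
    return tuple(shape)


print(expected_output_shape((100,)))                         # defaults
print(expected_output_shape((8, 100), topk=5))               # top-k on the last axis
print(expected_output_shape((8, 100), axis=0, num_targets=3))
```

All other axes are left unchanged, which matches the statement that the output has the same shape as the input except on axis.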