ott.tools.soft_sort.sort
- ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)
Apply the soft sort operator on a given axis of the input.
For instance:
    x = jax.random.uniform(rng, (100,))
    x_sorted = sort(x)
will output sorted convex combinations of values contained in x, which are differentiable approximations to the sorted vector of entries in x. These can be compared with the values produced by jax.numpy.sort():
    x_sorted = jax.numpy.sort(x)
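To make the mechanism behind these differentiable approximations concrete, here is a minimal pure-JAX sketch under stated assumptions: inputs are squashed into \([0,1]\) by a sigmoid of whitened values (the default described under kwargs below), then matched by entropic optimal transport (log-domain Sinkhorn, squared-distance ground cost) to n evenly spaced, uniformly weighted targets, and each output is the plan-weighted average of the inputs. The name soft_sort_sketch, the target placement, the default epsilon and the fixed iteration count are illustrative assumptions; this is not ott's implementation, and in practice you would call sort itself.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_sort_sketch(x, epsilon=1e-2, num_iters=1000):
    """Soft-sort a 1D array via entropic OT (illustrative sketch, not ott's code)."""
    n = x.shape[0]
    # Squash inputs into [0, 1]: sigmoid of the whitened values.
    z = jax.nn.sigmoid((x - x.mean()) / (x.std() + 1e-8))
    # Assumption: n evenly spaced targets in [0, 1], all with equal weight.
    y = jnp.linspace(0.0, 1.0, n)
    a = jnp.full(n, 1.0 / n)  # source weights
    b = jnp.full(n, 1.0 / n)  # target weights
    cost = (z[:, None] - y[None, :]) ** 2  # squared 1D ground cost

    def sinkhorn_step(_, potentials):
        f, g = potentials
        # Log-domain updates: match row sums to a, then column sums to b.
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros(n), jnp.zeros(n))
    f, g = jax.lax.fori_loop(0, num_iters, sinkhorn_step, init)
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Column j of plan, divided by b[j], sums to 1: each output is a convex
    # combination of the entries of x, ordered by increasing target position.
    return (plan.T @ x) / b


x = jnp.array([3.0, 1.0, 2.0, 5.0, 4.0])
print(soft_sort_sketch(x))  # approximately [1, 2, 3, 4, 5]
# Gradients flow through the Sinkhorn iterations (static trip count).
print(jax.grad(lambda v: soft_sort_sketch(v).sum())(x))
```

Smaller values of epsilon bring the output closer to jax.numpy.sort(x), at the price of a sharper transport plan, steeper gradients and slower Sinkhorn convergence.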
- Parameters:
  - inputs (Array) – Array of any shape.
  - axis (int) – the axis on which to apply the soft-sorting operator.
  - topk (int) – if set to a positive value, the returned vector will only contain the top-k values. This also reduces the complexity of soft-sorting, since the number of target points to which each slice of the inputs tensor is mapped will be equal to topk + 1.
  - num_targets (Optional[int]) – if topk is not specified, a vector of size num_targets is returned. This defines the number of (composite) sorted values computed from the inputs (each value is a convex combination of values recorded in the inputs, provided in increasing order). If neither topk nor num_targets is specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. inputs.shape[axis], so the number of composite sorted values equals the length of the sorted slices. As a result, the output has the same size as inputs.
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (by default, a sigmoid of the whitened values) before solving the optimal transport problem; cost_fn, the cost function object of the PointCloud geometry, which defines the 1D ground cost used to transport inputs to the num_targets target values; and epsilon, the regularization parameter. The remaining kwargs are passed on to parameterize the Sinkhorn solver.
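The difference between topk and num_targets can be pictured through the target weights of the transport problem. The hypothetical helper target_weights below shows one construction consistent with the topk + 1 count stated above (an assumption, not ott's code): with topk set, a first target absorbs the mass (n - topk) / n of the smallest values and each of the remaining topk targets receives 1/n, so dropping the first soft-sorted value leaves the top-k values in increasing order; otherwise, num_targets targets share the mass uniformly.

```python
import jax.numpy as jnp


def target_weights(n, topk=-1, num_targets=None):
    """Hypothetical target weights for soft-sorting a slice of length n."""
    if 0 < topk < n:
        # topk + 1 targets: the first absorbs the n - topk smallest values,
        # each of the others takes the mass of exactly one top-k value.
        head = jnp.array([(n - topk) / n])
        return jnp.concatenate([head, jnp.full(topk, 1.0 / n)])
    # Otherwise: num_targets (default n) uniformly weighted targets.
    m = n if num_targets is None else num_targets
    return jnp.full(m, 1.0 / m)


print(target_weights(100, topk=5).shape)   # (6,)
print(target_weights(100).shape)           # (100,)
print(target_weights(100, num_targets=10))  # ten weights of 0.1
```

Because only topk + 1 targets are used, the transport plan has n x (topk + 1) entries instead of n x n, which is where the complexity reduction mentioned for topk comes from.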
- Return type: Array
- Returns:
An Array of the same shape as the input, except along axis, where the size will be equal to topk or num_targets, with soft-sorted values on the given axis. It has the same size as inputs if neither of these parameters is set.
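The shape rule above follows from applying a 1D operator independently to every slice along axis. A minimal sketch of that pattern, using a hypothetical helper apply_along_axis_sketch and a hard top-2 jax.numpy.sort slice as a stand-in for the soft operator (an illustration of the shape logic, not ott's implementation):

```python
import jax
import jax.numpy as jnp


def apply_along_axis_sketch(fn, inputs, axis=-1):
    """Apply a 1D operator `fn` to every slice of `inputs` along `axis`."""
    # Move the target axis last and flatten all other axes into one batch axis.
    moved = jnp.moveaxis(inputs, axis, -1)
    batch_shape = moved.shape[:-1]
    flat = moved.reshape((-1, moved.shape[-1]))
    out = jax.vmap(fn)(flat)  # each row becomes a 1D output of length m
    # Restore the batch axes, then move the (possibly resized) axis back.
    out = out.reshape(batch_shape + out.shape[-1:])
    return jnp.moveaxis(out, -1, axis)


x = jnp.arange(24.0).reshape(2, 3, 4)
top2 = lambda v: jnp.sort(v)[-2:]  # stand-in for a topk=2 operator
print(apply_along_axis_sketch(top2, x, axis=1).shape)  # (2, 2, 4)
```

Only the size of axis changes (from 3 to 2 here); every other dimension is preserved, matching the description of the returned Array.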