- ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)
Apply the soft sort operator on a given axis of the input.
For instance,

```python
x = jax.random.uniform(rng, (100,))
x_sorted = sort(x)
```

will output sorted convex combinations of the values contained in `x`, which are differentiable approximations to the sorted vector of entries in `x`. These can be compared with the values produced by

```python
x_sorted = jax.numpy.sort(x)
```
- Parameters:
  - **inputs** (`Array`) – Array of any shape.
  - **axis** (`int`) – the axis on which to apply the soft-sorting operator.
  - **topk** (`int`) – if set to a positive value, the returned vector will only contain the top-k values. This also reduces the complexity of soft-sorting, since the number of target points to which the slice of the `inputs` tensor will be mapped will be equal to `topk + 1`.
  - **num_targets** (`Optional[int]`) – if `topk` is not specified, a vector of size `num_targets` is returned. This defines the number of (composite) sorted values computed from the inputs (each value is a convex combination of values recorded in the inputs, provided in increasing order). If neither `topk` nor `num_targets` is specified, `num_targets` defaults to the size of the slices of the input that are sorted, i.e. `inputs.shape[axis]`, and the number of composite sorted values equals the size of the slice of the inputs that is sorted. As a result, the output is of the same size as `inputs`.
  - **kwargs** (`Any`) – keyword arguments passed on to lower-level functions. Of interest to the user are `squashing_fun`, which will redistribute the values in `inputs` to lie in \([0,1]\) (sigmoid of whitened values by default) before solving the optimal transport problem; the cost function of the `PointCloud` geometry, which defines the ground 1D cost used to transport from `inputs` to the target values; and the `epsilon` regularization parameter. Remaining `kwargs` are passed on to parameterize the Sinkhorn solver.
- Return type: `Array`
- Returns: An Array of the same shape as the input, except on `axis`, where its size will equal `topk` or `num_targets`, with soft-sorted values on the given axis. Same size as `inputs` if both these parameters are left to their default values.
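To make the mechanics concrete, the following is a minimal, dependency-free sketch of 1D soft sorting via entropy-regularized optimal transport. It is an illustrative re-implementation, not the library's code: the name `soft_sort_1d`, its `epsilon` default, and the fixed number of Sinkhorn iterations are assumptions chosen for readability, and it covers only the `num_targets` behavior (not `topk` or batched axes).

```python
import math

def soft_sort_1d(x, num_targets=None, epsilon=1e-2, num_iters=200):
    """Illustrative Sinkhorn-based soft sort (hypothetical helper, not OTT's API)."""
    n = len(x)
    m = num_targets or n  # default: as many composite values as inputs
    # Squash inputs to [0, 1]: sigmoid of whitened values, mirroring the
    # default squashing_fun described above.
    mean = sum(x) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in x) / n) or 1.0
    xh = [1.0 / (1.0 + math.exp(-(v - mean) / std)) for v in x]
    # Evenly spaced, increasing target locations in [0, 1].
    y = [j / (m - 1) for j in range(m)] if m > 1 else [0.5]
    # Uniform weights on inputs and targets.
    a = [1.0 / n] * n
    b = [1.0 / m] * m
    # Gibbs kernel for the squared-distance ground cost, K = exp(-C / epsilon).
    K = [[math.exp(-((xi - yj) ** 2) / epsilon) for yj in y] for xi in xh]
    # Sinkhorn iterations: alternately rescale rows and columns.
    u, v = [1.0] * n, [1.0] * m
    for _ in range(num_iters):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Transport plan P[i][j] = u[i] * K[i][j] * v[j]. Each soft-sorted value is
    # the barycenter of the original inputs under one target's column of P,
    # i.e. a convex combination of the entries of x, returned in increasing order.
    return [sum(u[i] * K[i][j] * v[j] * x[i] for i in range(n)) / b[j]
            for j in range(m)]
```

With `num_targets=None` the output has the same length as the input and approaches `sorted(x)` as `epsilon` shrinks; with a smaller `num_targets`, each output is a composite of several input values, which is the shape-reduction behavior described for `num_targets` above.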