- ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)
Apply the soft sort operator on a given axis of the input.
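The following is a minimal NumPy sketch of soft sorting a single vector. It assumes the construction that the kwargs entry below describes:

- The inputs are squashed into [0, 1] with a sigmoid of their whitened values.
- They are transported with entropic regularization `epsilon` onto evenly spaced targets in [0, 1] under a squared Euclidean cost.
- Each sorted output is read off as a convex combination of the original inputs.

The helpers `sinkhorn_plan`, `squash` and `soft_sort_1d` are hypothetical. They use a plain log-domain Sinkhorn loop and are not OTT's implementation.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along one axis.
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def sinkhorn_plan(x, y, a, b, epsilon, num_iters=1000):
    # Hypothetical helper: entropic OT plan between 1-D points x (weights a)
    # and y (weights b) with a squared Euclidean ground cost, computed with
    # log-domain Sinkhorn updates of the dual potentials f and g.
    cost = (x[:, None] - y[None, :]) ** 2
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(num_iters):
        f = epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def squash(v):
    # Hypothetical helper: sigmoid of whitened values, mapping inputs into (0, 1).
    z = (v - v.mean()) / (v.std() + 1e-10)
    return 1.0 / (1.0 + np.exp(-z))


def soft_sort_1d(v, num_targets=None, epsilon=1e-2):
    # Hypothetical helper: soft sort of a 1-D array v.
    n = v.shape[0]
    m = n if num_targets is None else num_targets
    a = np.full(n, 1.0 / n)           # uniform mass on the inputs
    b = np.full(m, 1.0 / m)           # uniform mass on the targets
    y = np.linspace(0.0, 1.0, m)      # increasing target positions in [0, 1]
    plan = sinkhorn_plan(squash(v), y, a, b, epsilon)
    # Column j of the plan carries mass b[j]. Dividing by it turns each output
    # into a convex combination of the original (unsquashed) inputs.
    return plan.T @ v / b


v = np.array([3.0, 1.0, 2.0, 5.0, 4.0])
print(soft_sort_1d(v))  # compare with np.sort(v)
```

A smaller `epsilon` pushes the outputs closer to `np.sort(v)`, but Sinkhorn then typically needs more iterations. A larger `epsilon` blurs neighboring values together.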
- Parameters:
inputs (Array) – jnp.ndarray<float> of any shape.
axis (int) – the axis on which to apply the operator.
topk (int) – if set to a positive value, the returned array only contains the top-k values along axis. This also reduces the complexity of soft sorting.
num_targets (Optional[int]) – if topk is not specified, num_targets defines the number of (composite) sorted values computed from the inputs. Each value is a convex combination of the input values, and the values are returned in increasing order. If not specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. the number of composite sorted values equals the number of inputs being sorted.
kwargs (Any) – keyword arguments passed on to lower-level functions. The ones of interest to the user are:
squashing_fun, which redistributes the values in inputs to lie in [0,1] (by default, a sigmoid of the whitened values) so that the optimal transport problem can be solved;
cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for more details);
epsilon, along with other parameters that shape the Sinkhorn algorithm.
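The topk and num_targets parameters change the target measure onto which the inputs are transported. Below is a sketch of one way to build that measure. It assumes evenly spaced target positions, and it assumes that the topk case uses a single extra "dump" target that absorbs the mass of the n - topk smallest values. This page does not confirm either choice. Under these assumptions, the topk case needs only topk + 1 targets instead of n, so the cost matrix shrinks from n-by-n to n-by-(topk + 1). That is one plausible source of the complexity reduction mentioned for topk. `target_measure` is a hypothetical helper.

```python
import numpy as np


def target_measure(n, topk=-1, num_targets=None):
    """Hypothetical helper: target positions and weights for soft sorting n values.

    Returns (y, b, start), where only outputs[start:] are kept.
    """
    if 0 < topk < n:
        # Assumed construction: one "dump" target absorbs the n - topk smallest
        # values, and each of the topk remaining targets receives the mass of a
        # single input (1/n).
        b = np.concatenate([[(n - topk) / n], np.full(topk, 1.0 / n)])
        start = 1
    else:
        m = n if num_targets is None else num_targets
        b = np.full(m, 1.0 / m)  # uniform mass on num_targets targets
        start = 0
    y = np.linspace(0.0, 1.0, b.shape[0])  # increasing target positions in [0, 1]
    return y, b, start


y, b, start = target_measure(10, topk=3)
print(y.shape, b, start)
```

With topk set, the first output (the dump target) would be discarded, and the remaining topk outputs would hold the largest values.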
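- Returns:
A jnp.ndarray with soft sorted values along the given axis. It has the same shape as the input, except that when topk or num_targets is set, the size along axis becomes topk or num_targets, respectively.
- Return type:
Array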
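The axis argument and the output shape under Returns can be illustrated with a small helper. It moves axis to the end, applies a 1-D operator to every slice, and moves the axis back. `apply_on_axis` is a hypothetical name, and this is only one way such an axis argument can be handled. `np.sort` stands in for a 1-D soft sort so the sketch runs on its own. The top-2 function shows how the size along axis changes when fewer outputs are produced per slice.

```python
import numpy as np


def apply_on_axis(fn_1d, x, axis=-1):
    """Hypothetical helper: apply a 1-D operator to every slice of x along axis.

    fn_1d maps a length-n vector to a length-m vector. m may differ from n
    (as with topk or num_targets), in which case the size along axis becomes m.
    """
    moved = np.moveaxis(x, axis, -1)              # put the sorted axis last
    flat = moved.reshape(-1, moved.shape[-1])     # one row per slice
    out = np.stack([fn_1d(row) for row in flat])  # operate on each slice
    out = out.reshape(moved.shape[:-1] + (out.shape[-1],))
    return np.moveaxis(out, -1, axis)             # restore the axis position


x = np.arange(24.0)[::-1].reshape(2, 3, 4)
print(apply_on_axis(np.sort, x, axis=1).shape)                    # same shape as x
print(apply_on_axis(lambda r: np.sort(r)[-2:], x, axis=1).shape)  # axis 1 shrinks to 2
```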