ott.tools.soft_sort.sort

ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)

Apply the soft sort operator on a given axis of the input.

Parameters:
  • inputs (Array) – jnp.ndarray<float> of any shape.

  • axis (int) – the axis on which to apply the operator.

  • topk (int) – if set to a positive value, the returned vector contains only the top-k values. This also reduces the computational cost of soft sorting. With the default, -1, all values are returned.

  • num_targets (Optional[int]) – if topk is not specified, num_targets defines the number of (composite) sorted values computed from the inputs. Each value is a convex combination of values from the inputs, and the values are returned in increasing order. If None (the default), num_targets is set to the size of the slices of the input being sorted, so there are as many composite sorted values as sorted inputs.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which maps the values in inputs to [0, 1] (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting the inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for more details); and epsilon, along with other parameters that configure the Sinkhorn algorithm.
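The roles of the squashing function, the ground cost, epsilon and num_targets described above can be illustrated with a small NumPy sketch of soft sorting as an entropic optimal transport problem. This is an illustration only, not the library's implementation: the function name soft_sort_sketch, the evenly spaced targets in [0, 1], the uniform weights and the fixed iteration count are all assumptions made for the example.

```python
import numpy as np


def soft_sort_sketch(x, num_targets=None, epsilon=0.05, num_iters=500):
    """Illustrative 1-D soft sort via entropic optimal transport (not OTT's code)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    m = n if num_targets is None else num_targets

    # Squash the inputs into [0, 1]: sigmoid of the whitened values.
    z = (x - x.mean()) / (x.std() + 1e-12)
    squashed = 1.0 / (1.0 + np.exp(-z))

    # Assumption: targets are evenly spaced in [0, 1], in increasing order.
    targets = np.linspace(0.0, 1.0, m)

    # Uniform weights on the inputs (a) and on the targets (b).
    a = np.full(n, 1.0 / n)
    b = np.full(m, 1.0 / m)

    # Squared Euclidean ground cost and the corresponding Gibbs kernel.
    cost = (squashed[:, None] - targets[None, :]) ** 2
    kernel = np.exp(-cost / epsilon)

    # Plain Sinkhorn iterations; ending on the v update makes columns sum to b.
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(num_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    plan = u[:, None] * kernel * v[None, :]

    # Each output is a convex combination of the original inputs.
    return (plan.T @ x) / b
```

Calling soft_sort_sketch(np.array([3.0, -1.0, 2.0, 0.5])) returns four values in increasing order, each lying between the smallest and largest input. Passing num_targets=2 returns two composite values instead.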

Return type:

Array

Returns:

A jnp.ndarray with soft-sorted values along the given axis. It has the same shape as the input, except along axis, where its size is topk or num_targets when either of them is set.