ott.tools.soft_sort.sort_with

ott.tools.soft_sort.sort_with(inputs, criterion, topk=-1, **kwargs)

Sort a multidimensional array according to a real valued criterion.

Given batch vectors of dimension dim, each associated with a real-valued criterion, this function produces topk composite vectors of size dim (or batch composite vectors if topk is set to -1, the default). Each composite vector is a convex combination of all input vectors, and the composite vectors are ranked from smallest to largest criterion: composite vectors with the largest indices combine the input vectors with the highest criterion values, while those with smaller indices combine the input vectors with smaller criterion values.
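The mechanism described above can be illustrated with a minimal NumPy sketch. It is not OTT's implementation, and the name sort_with_sketch is hypothetical. It assumes uniform weights on the batch, evenly spaced target values in [0, 1], a squared-Euclidean ground cost, and a fixed number of log-domain Sinkhorn iterations. For topk, it assumes a scheme in which one extra first target absorbs the batch - topk smallest items and is then discarded.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log-sum-exp along `axis`.
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def sort_with_sketch(inputs, criterion, topk=-1, epsilon=1e-2, num_iters=1000):
    """Hypothetical soft 'sort inputs by criterion' via entropic OT.

    inputs: [batch, dim]; criterion: [batch] or [batch, 1].
    Returns [topk, dim] if 0 < topk < batch, else [batch, dim].
    """
    n = inputs.shape[0]
    c = np.reshape(criterion, (-1,))
    # Squash the criterion into (0, 1): sigmoid of whitened values.
    z = (c - c.mean()) / (c.std() + 1e-10)
    x = 1.0 / (1.0 + np.exp(-z))
    a = np.full(n, 1.0 / n)  # uniform mass on each input vector
    if 0 < topk < n:
        # Assumption: the first target soaks up the n - topk smallest items
        # and is dropped from the output.
        b = np.concatenate([[(n - topk) / n], np.full(topk, 1.0 / n)])
        start = 1
    else:
        b = np.full(n, 1.0 / n)
        start = 0
    y = np.linspace(0.0, 1.0, b.shape[0])  # evenly spaced targets in [0, 1]
    cost = (x[:, None] - y[None, :]) ** 2  # squared Euclidean ground cost
    # Log-domain Sinkhorn: alternately fit row sums (a) and column sums (b).
    f = np.zeros(n)
    g = np.zeros(b.shape[0])
    for _ in range(num_iters):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)  # [n, num_targets]
    # Column j of the plan sums to b[j], so dividing by b[j] turns each
    # output row into a convex combination of the input rows.
    out = (plan.T @ inputs) / b[:, None]
    return out[start:]


inputs = np.array([[3.0, 30.0], [1.0, 10.0], [2.0, 20.0]])
sorted_rows = sort_with_sketch(inputs, inputs[:, 0])
top_rows = sort_with_sketch(inputs, inputs[:, 0], topk=2)
```

With a small epsilon the rows come out close to a hard sort by the first column. With OTT itself, the documented call on JAX arrays would be sort_with(inputs, criterion, topk=2), with Sinkhorn-related settings such as epsilon passed through **kwargs.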

Parameters:
  • inputs (Array) – the inputs as a jnp.ndarray[batch, dim].

  • criterion (Array) – the values according to which to sort the inputs. It has shape [batch, 1].

  • topk (int) – The number of outputs to keep. If set to -1 (the default), all batch outputs are kept.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which maps the values of criterion into [0, 1] (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost between those squashed values and the target values (squared Euclidean distance by default; see pointcloud.py for details); and epsilon, as well as other parameters that shape the Sinkhorn algorithm.
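The default squashing described for squashing_fun ("sigmoid of whitened values") can be sketched in NumPy as follows. The function name default_squash_sketch and the 1e-10 guard against a zero standard deviation are assumptions for illustration, not OTT code. The sketch shows two properties: every squashed value lies strictly inside (0, 1), and the ordering of the criterion is preserved.

```python
import numpy as np


def default_squash_sketch(z):
    """Hypothetical default squashing: sigmoid of whitened values."""
    z = np.asarray(z, dtype=float)
    # Whitening: zero mean, unit variance (1e-10 is an assumed guard).
    whitened = (z - z.mean()) / (z.std() + 1e-10)
    # The sigmoid is strictly increasing, so the ordering is preserved.
    return 1.0 / (1.0 + np.exp(-whitened))


c = np.array([250.0, -3.0, 17.0, 40.0])
s = default_squash_sketch(c)
rescaled = default_squash_sketch(10.0 * c + 5.0)
```

Because whitening removes location and scale, a positive affine rescaling of the criterion (here 10 * c + 5) leaves the squashed values essentially unchanged. Under this default, the scale of epsilon is therefore tied to the squashed values rather than to the raw units of criterion.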

Return type:

Array

Returns:

A jnp.ndarray of shape [topk, dim] when 0 < topk < batch, and of shape [batch, dim] otherwise (including the default topk=-1).