ott.tools.soft_sort.sort_with
- ott.tools.soft_sort.sort_with(inputs, criterion, topk=-1, **kwargs)
Sort a multidimensional array according to a real-valued criterion.
Given batch vectors of dimension dim, each associated with a real-valued criterion, this function produces topk composite vectors of size dim (or batch of them if topk is set to -1, the default). Each composite vector is a convex combination of all input vectors, and the outputs are ranked from smallest to largest criterion. Composite vectors with the largest indices are mostly combinations of the input vectors with the highest criterion, while those with smaller indices are mostly combinations of the vectors with smaller criterion.
- Parameters:
  - inputs (Array) – Array of size [batch, dim].
  - criterion (Array) – The values according to which to sort the inputs. It has shape [batch, 1].
  - topk (int) – The number of outputs to keep.
  - kwargs (Any) – Keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (a sigmoid of whitened values by default) before solving the optimal transport problem; cost_fn, the cost function object used by PointCloud, which defines the ground 1D cost for transporting from inputs to the num_targets target values; and epsilon, the regularization parameter. The remaining kwargs are passed on to parameterize the Sinkhorn solver.
- Return type:
Array
- Returns:
An Array of size [topk, dim], or [batch, dim] when topk is -1.
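To make the mechanism above more concrete, here is a minimal NumPy sketch of an optimal-transport soft sort in the spirit of sort_with. It does not call OTT and is not its implementation: the names soft_sort_with_sketch and hard_sort_with are hypothetical, and the squared-difference ground cost, the plain log-domain Sinkhorn loop, and the topk handling (one extra target that absorbs the mass of the lowest-ranked entries and is then discarded) are assumptions made for illustration. The criterion is squashed with a sigmoid of whitened values, transported onto evenly spaced targets in [0, 1], and each output row is the transport-weighted average of the input rows, hence a convex combination of them.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along `axis`.
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def soft_sort_with_sketch(inputs, criterion, topk=-1, epsilon=1e-2, num_iters=1000):
    """Hypothetical sketch of an OT-based soft sort_with (not OTT's code).

    inputs: [batch, dim]; criterion: [batch, 1].
    Returns [topk, dim], or [batch, dim] when topk == -1.
    """
    batch = criterion.shape[0]
    z = criterion[:, 0]
    # Squash the criterion into (0, 1) with a sigmoid of whitened values,
    # mirroring the default described for squashing_fun.
    x = 1.0 / (1.0 + np.exp(-(z - z.mean()) / (z.std() + 1e-10)))

    a = np.full(batch, 1.0 / batch)  # uniform weight on each input vector
    if 0 < topk < batch:
        # Assumption: an extra first target absorbs the (batch - topk)
        # smallest entries and is dropped from the output.
        b = np.concatenate([[(batch - topk) / batch], np.full(topk, 1.0 / batch)])
        start = 1
    else:
        b = np.full(batch, 1.0 / batch)
        start = 0

    # Targets: evenly spaced points on [0, 1]; squared difference as 1D cost.
    y = np.linspace(0.0, 1.0, b.shape[0])
    cost = (x[:, None] - y[None, :]) ** 2

    # Log-domain Sinkhorn iterations on the dual potentials f and g.
    f = np.zeros(batch)
    g = np.zeros(b.shape[0])
    for _ in range(num_iters):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)

    # After the last g update, column j of the plan sums to b[j], so each
    # output row is a convex combination of the input rows.
    out = (plan.T @ inputs) / b[:, None]
    return out[start:]


def hard_sort_with(inputs, criterion, topk=-1):
    # Non-differentiable reference: rows ranked by increasing criterion,
    # keeping the last topk (highest-criterion) rows when topk != -1.
    ranked = inputs[np.argsort(criterion[:, 0])]
    return ranked if topk == -1 else ranked[-topk:]


inputs = np.arange(10.0).reshape(5, 2)  # batch=5, dim=2
criterion = np.array([[3.0], [1.0], [4.0], [2.0], [5.0]])
soft = soft_sort_with_sketch(inputs, criterion, epsilon=1e-3, num_iters=2000)
soft_top2 = soft_sort_with_sketch(inputs, criterion, topk=2, epsilon=1e-3, num_iters=2000)
print(soft.shape, soft_top2.shape)  # (5, 2) (2, 2)
print(np.round(soft, 2))
print(hard_sort_with(inputs, criterion))
```

Smaller epsilon values push the output toward hard_sort_with but make Sinkhorn converge more slowly, which is why the example raises num_iters; with larger epsilon, each output row blends more of the neighboring input rows.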