ott.tools.soft_sort.ranks

ott.tools.soft_sort.ranks(inputs, axis=-1, num_targets=None, **kwargs)

Apply the soft rank operator to an input tensor along a given axis.
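The operator can be illustrated with a self-contained JAX sketch that builds soft ranks from an entropic optimal transport plan computed with log-domain Sinkhorn iterations. This is not OTT's implementation: the helper soft_ranks_1d is hypothetical, and the sigmoid-of-whitened-values squashing, the evenly spaced targets in [0, 1] (via jnp.linspace), the uniform weights, the squared Euclidean cost and the fixed iteration count are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_ranks_1d(x, num_targets=None, epsilon=1e-2, num_iters=1000):
    """Soft ranks of a 1-D array via entropic optimal transport (sketch).

    Assumptions, not taken from OTT: sigmoid-of-whitened squashing, targets
    evenly spaced in [0, 1], uniform weights, squared Euclidean cost and a
    fixed number of log-domain Sinkhorn iterations.
    """
    n = x.shape[0]
    m = n if num_targets is None else num_targets
    # Squash inputs into (0, 1) so they live on the same scale as the targets.
    z = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    targets = jnp.linspace(0.0, 1.0, m)
    cost = (z[:, None] - targets[None, :]) ** 2
    log_a = jnp.full((n,), -jnp.log(n))  # uniform weights on the inputs
    log_b = jnp.full((m,), -jnp.log(m))  # uniform weights on the targets

    def sinkhorn_step(_, fg):
        f, g = fg
        # Alternate dual updates so that the plan's row sums match a
        # and its column sums match b.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(
        0, num_iters, sinkhorn_step, (jnp.zeros(n), jnp.zeros(m))
    )
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Each soft rank is a convex combination of the target values, weighted
    # by the row-normalized mass that input i sends to each target.
    return (plan @ targets) / plan.sum(axis=1)


x = jnp.array([3.0, 1.0, 2.0])
r = soft_ranks_1d(x)  # hard ranks rescaled to [0, 1] would be [1.0, 0.0, 0.5]
r_composite = soft_ranks_1d(jnp.array([4.0, 1.0, 3.0, 2.0]), num_targets=2)
grad = jax.grad(lambda v: soft_ranks_1d(v).sum())(x)  # differentiable end to end
```

With a small epsilon the result approaches the hard ranks rescaled to [0, 1]; larger values of epsilon give smoother ranks. Passing num_targets smaller than the slice length makes each rank a mixture of fewer target values, as in r_composite above.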

Parameters:
  • inputs (Array) – a jnp.ndarray<float> of any shape.

  • axis (int) – the axis along which to apply the soft rank operator.

  • num_targets (Optional[int]) – the number of targets used to compute a composite rank for each value in inputs. Each soft rank is a convex combination of the target values [0, …, (num_targets-2)/num_targets, 1], with weights given by the optimal transport from the values in inputs towards those targets. If not specified, num_targets defaults to the size of the slices of inputs that are sorted, i.e. inputs.shape[axis].

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of particular interest are: squashing_fun, which maps the values in inputs into [0,1] before the optimal transport problem is solved (a sigmoid of the whitened values by default); cost_fn, used in PointCloud to define the ground cost of transporting inputs to the num_targets target values (squared Euclidean distance by default, see pointcloud.py for details); and epsilon, along with other parameters that control the Sinkhorn algorithm.
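The convex combination described for num_targets can be written out explicitly. The symbols here are introduced for illustration only: a_i is the weight of input i, t_1, …, t_m are the m = num_targets target values, and P is the regularized transport plan, whose rows sum to the input weights a_i as for any transport plan.

```latex
r_i \;=\; \frac{1}{a_i}\sum_{j=1}^{m} P_{ij}\, t_j
\;=\; \sum_{j=1}^{m} \lambda_{ij}\, t_j,
\qquad \lambda_{ij} = \frac{P_{ij}}{a_i} \ge 0,
\qquad \sum_{j=1}^{m} \lambda_{ij} = \frac{1}{a_i}\sum_{j=1}^{m} P_{ij} = 1 .
```

Because the weights \lambda_{ij} are nonnegative and sum to one, each soft rank r_i lies between the smallest and the largest target value.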
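The default squashing and ground cost described for kwargs can be sketched as below. The helper names sigmoid_whiten and squared_euclidean_cost are hypothetical, and the small constant added to the standard deviation is an assumption to avoid division by zero; in OTT these behaviors are configured through squashing_fun and cost_fn, not through these functions.

```python
import jax
import jax.numpy as jnp


def sigmoid_whiten(x, eps=1e-10):
    """Squash values into (0, 1): sigmoid of whitened (zero-mean, unit-std) values."""
    z = (x - jnp.mean(x)) / (jnp.std(x) + eps)
    return jax.nn.sigmoid(z)


def squared_euclidean_cost(x, y):
    """Pairwise cost matrix C[i, j] = (x[i] - y[j]) ** 2 for 1-D point clouds."""
    return (x[:, None] - y[None, :]) ** 2


x = jnp.array([10.0, -3.0, 250.0])
s = sigmoid_whiten(x)             # values in (0, 1), same ordering as x
t = jnp.linspace(0.0, 1.0, 3)     # illustrative target values
C = squared_euclidean_cost(s, t)  # shape (3, 3)
```

Whitening before the sigmoid makes the squashed values insensitive to the scale and offset of inputs, so the same epsilon behaves similarly across differently scaled data.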

Return type:

Array

Returns:

A jnp.ndarray<float> of the same shape as inputs, containing the soft ranks.
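The axis argument and the same-shape guarantee can be illustrated with a common pattern: move the chosen axis last, flatten the remaining axes, and map a 1-D operator over each slice. In this sketch apply_along_axis_1d is a hypothetical helper, not part of OTT, and hard ranks computed with a double jnp.argsort stand in for the soft rank operator so that the output is easy to check.

```python
import jax
import jax.numpy as jnp


def hard_ranks_1d(x):
    # Stand-in 1-D operator: 0-based hard ranks via a double argsort.
    return jnp.argsort(jnp.argsort(x)).astype(x.dtype)


def apply_along_axis_1d(fn, inputs, axis=-1):
    """Apply a 1-D operator `fn` to every slice of `inputs` along `axis`."""
    moved = jnp.moveaxis(inputs, axis, -1)        # bring the ranked axis last
    flat = moved.reshape(-1, moved.shape[-1])     # (num_slices, slice_length)
    out = jax.vmap(fn)(flat)                      # one call per slice
    return jnp.moveaxis(out.reshape(moved.shape), -1, axis)  # restore layout


x = jnp.array([[3.0, 1.0, 2.0],
               [0.0, 5.0, 4.0]])
rows = apply_along_axis_1d(hard_ranks_1d, x, axis=-1)  # ranks within each row
cols = apply_along_axis_1d(hard_ranks_1d, x, axis=0)   # ranks within each column
```

In both cases the output has the same shape as x; only the direction along which values are compared changes.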