, axis=-1, num_targets=None, **kwargs)

Apply the soft rank operator to the input tensor.

  • inputs (ndarray) – a jnp.ndarray<float> of any shape.

  • axis (int) – the axis along which to apply the soft rank operator.

  • num_targets (Optional[int]) – the number of target values used to compute a composite rank for each value in inputs. Each soft rank is a convex combination of the values in [0, …, (num_targets-2)/num_targets, 1], with weights given by the optimal transport from the values in inputs to those target values. If not specified, num_targets defaults to the size of the slices of the input that are sorted.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0,1] (sigmoid of whitened values by default) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default); and epsilon, as well as other parameters that shape the Sinkhorn algorithm.

Return type

A jnp.ndarray<float> of the same shape as inputs, with the ranks.
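
The mechanism described in the parameters above (squash the inputs into [0,1], transport them onto evenly spaced targets, then take a convex combination of the target values) can be sketched from scratch for a 1-D array. This is an illustrative sketch only, not the library's implementation: the name soft_ranks_1d, the default epsilon and num_iters, and the hand-written log-domain Sinkhorn loop are all assumptions made for the example. It returns values on the [0,1] scale of the targets.

```python
import jax
import jax.numpy as jnp


def soft_ranks_1d(x, num_targets=None, epsilon=1e-2, num_iters=500):
    """Hypothetical 1-D soft-rank sketch (not the library's code)."""
    n = x.shape[0]
    m = n if num_targets is None else num_targets

    # Squash the inputs into [0, 1]: sigmoid of whitened values.
    z = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))

    # Evenly spaced targets in [0, 1], uniform weights on both sides.
    y = jnp.linspace(0.0, 1.0, m)
    log_a = jnp.full((n,), -jnp.log(n))
    log_b = jnp.full((m,), -jnp.log(m))

    # Squared Euclidean ground cost between squashed inputs and targets.
    cost = (z[:, None] - y[None, :]) ** 2

    # Log-domain Sinkhorn updates of the dual potentials f and g.
    def body(_, fg):
        f, g = fg
        f = epsilon * (
            log_a - jax.nn.logsumexp((g[None, :] - cost) / epsilon, axis=1)
        )
        g = epsilon * (
            log_b - jax.nn.logsumexp((f[:, None] - cost) / epsilon, axis=0)
        )
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(m)))

    # Row-normalise the transport plan so that each row holds convex weights
    # over the targets, then average the target values with those weights.
    log_plan = (f[:, None] + g[None, :] - cost) / epsilon
    weights = jax.nn.softmax(log_plan, axis=1)
    return weights @ y


x = jnp.array([0.3, -1.2, 2.5, 0.9])
soft = soft_ranks_1d(x)
print(soft)
```

Under these assumptions, a smaller epsilon pushes the output towards hard ranks (rescaled to [0,1]), while a larger one smooths all ranks towards the middle of the range. A 1-D function like this could be mapped over a chosen axis of a higher-dimensional array with jnp.apply_along_axis, which is one way to obtain an output of the same shape as the input.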