, axis=-1, level=0.5, weight=0.05, **kwargs)

Apply the soft quantile operator on the input tensor.

For instance:

    x = jax.random.uniform(rng, (1000,))
    q = quantile(x, level=0.5, weight=0.01)

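Then q will be computed as a mean over the 10 points of x closest to the median (weight=0.01 applied to 1000 points). There is therefore a trade-off between accuracy and gradient: a smaller weight tracks the exact quantile more closely, but the gradient flows through fewer input points.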
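To make the trade-off concrete, here is a minimal sketch of a hard analogue of the operator. It is not how the soft quantile is implemented (that relies on optimal transport); the helper windowed_quantile is hypothetical and simply averages the roughly weight * n sorted values centred on the requested level, so the number of inputs that receive a gradient can be counted directly.

```python
import jax
import jax.numpy as jnp


def windowed_quantile(x, level=0.5, weight=0.05):
    """Hypothetical hard analogue: mean of ~weight * n sorted values around `level`."""
    n = x.shape[0]
    k = max(1, int(round(weight * n)))      # number of points averaged
    start = int(round(level * n - k / 2))   # left edge of the window
    start = min(max(start, 0), n - k)       # keep the window inside the array
    return jnp.sort(x)[start:start + k].mean()


rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (1000,))

q_narrow = windowed_quantile(x, level=0.5, weight=0.01)  # averages 10 points
q_wide = windowed_quantile(x, level=0.5, weight=0.2)     # averages 200 points

# Only the averaged points receive a gradient.
grad_narrow = jax.grad(windowed_quantile)(x, 0.5, 0.01)
grad_wide = jax.grad(windowed_quantile)(x, 0.5, 0.2)
n_narrow = int((grad_narrow != 0).sum())
n_wide = int((grad_wide != 0).sum())
```

With the narrow window the estimate stays close to the sample median but only a handful of inputs get a nonzero gradient; widening the window spreads the gradient over more inputs at the cost of a blurrier estimate.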

  • inputs (Array) – a jnp.ndarray<float> of any shape.

  • axis (int) – the axis on which to apply the operator.

  • level (float) – the value of the quantile level to be computed. 0.5 for median.

  • weight (float) – the weight (mass) assigned to the quantile in the transport problem. In the example above, weight=0.01 on 1000 points averages about 10 points.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which redistributes the values in inputs to lie in [0,1] (sigmoid of whitened values by default) before solving the optimal transport problem; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default, see for more details); and epsilon, along with other parameters that shape the Sinkhorn algorithm.
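The default squashing described for squashing_fun ("sigmoid of whitened values") can be sketched as follows. This is an illustration under assumptions, not the library's own code: the name squash, the 1e-8 stabiliser, and the per-axis whitening are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp


def squash(x, axis=-1):
    """Hypothetical squashing: whiten along `axis`, then apply a sigmoid."""
    mean = x.mean(axis=axis, keepdims=True)
    std = x.std(axis=axis, keepdims=True)
    z = (x - mean) / (std + 1e-8)   # whitened values (zero mean, unit scale)
    return jax.nn.sigmoid(z)        # monotone map into the unit interval


x = jnp.array([-50.0, 0.0, 1.0, 2.0, 1000.0])
y = squash(x)
```

Because the map is monotone, the ordering of the inputs is preserved, which is what a sorting or quantile operator needs; only the scale is normalised before the transport problem is solved.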

Return type:



A jnp.ndarray, which has the same shape as the input, except on the given axis, where the dimension is 1.
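The shape convention described here matches what the hard quantile gives with keepdims=True. The sketch below uses jnp.quantile as a stand-in to show the expected shapes; it does not call the soft operator itself.

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# Reduced axis is kept with size 1; all other dimensions are unchanged.
hard_last = jnp.quantile(x, 0.5, axis=-1, keepdims=True)
hard_mid = jnp.quantile(x, 0.5, axis=1, keepdims=True)

print(hard_last.shape)  # (2, 3, 1)
print(hard_mid.shape)   # (2, 1, 4)
```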