, q, axis=-1, weight=None, **kwargs)

Apply the soft-quantile operator to the input tensor.

For instance:

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))
x_quantiles = quantile(x, q=jnp.array([0.2, 0.8]))

x_quantiles will hold an approximation to the 20th and 80th percentiles of x, computed as a convex combination (a weighted mean, with weights summing to 1) of all values in x, rather than, as for standard quantiles, the values x_sorted[20] and x_sorted[80], where x_sorted = jnp.sort(x). These values trade off accuracy (closeness to the true percentiles) against gradient information: the Jacobian of x_quantiles w.r.t. x depends on all values in x, not only on those at sorted positions 20 and 80.
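The convex-combination view can be made concrete with a small entropic optimal transport computation. The sketch below is a hypothetical illustration, not the library's implementation: the name soft_quantile_sketch, the midpoint filler targets, the squared ground cost, the default values, and the fixed number of log-domain Sinkhorn iterations are all assumptions. It sends the squashed inputs (uniform mass 1/n) to targets on [0, 1], one per quantile level with mass weight, and reads each soft quantile off the normalized column of the transport plan, which is a set of nonnegative weights over all of x that sum to 1.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_quantile_sketch(x, q, weight=0.05, epsilon=5e-3, num_iters=2000):
    """Hypothetical soft quantiles of a 1D array `x` at sorted levels `q`.

    Illustration of the OT idea only; the target layout, cost, and solver
    settings are assumptions, not the library's implementation.
    """
    n, k = x.shape[0], q.shape[0]
    # Squash inputs into (0, 1): sigmoid of the whitened values.
    z = jax.nn.sigmoid((x - jnp.mean(x)) / jnp.std(x))
    # Assumed targets: each level q_j carries mass `weight`; filler targets at the
    # midpoints of [0, q_1], [q_1, q_2], ..., [q_k, 1] carry the remaining mass.
    edges = jnp.concatenate([jnp.zeros(1), q, jnp.ones(1)])
    filler_pos = 0.5 * (edges[:-1] + edges[1:])
    filler_mass = jnp.diff(edges) - weight
    filler_mass = filler_mass.at[0].add(weight / 2).at[-1].add(weight / 2)
    targets = jnp.concatenate([filler_pos, q])  # quantile targets are the last k
    log_a = jnp.full(n, -jnp.log(n))  # uniform mass 1/n on the inputs
    log_b = jnp.log(jnp.concatenate([filler_mass, jnp.full(k, weight)]))
    cost = (z[:, None] - targets[None, :]) ** 2  # squared 1D ground cost

    def sinkhorn_step(_, fg):
        f, g = fg
        # Log-domain Sinkhorn: update f to match row marginals, then g for columns.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros(n), jnp.zeros(targets.shape[0]))
    f, _ = jax.lax.fori_loop(0, num_iters, sinkhorn_step, init)
    # Normalizing a plan column gives nonnegative weights over x that sum to 1;
    # g_j cancels out, so a softmax over the inputs is enough.
    w = jax.nn.softmax((f[:, None] - cost[:, -k:]) / epsilon, axis=0)
    return w.T @ x  # one convex combination of all values of x per level
```

On x = jnp.arange(100.0) with q = jnp.array([0.2, 0.8]), the two outputs land near, but not exactly on, jnp.quantile(x, q). A smaller epsilon moves them closer to the hard values at the cost of slower Sinkhorn convergence; a larger epsilon spreads the weights over more of x.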

The standard, non-smooth version is given by jax.numpy.quantile(), e.g.

x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8]))
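To see the gradient contrast described above, the snippet below differentiates the standard quantile with plain JAX. jax.numpy.quantile is piecewise linear in x, so JAX can differentiate it, but with the default linear interpolation only the two order statistics that bracket the requested level receive a nonzero gradient.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(100.0)  # already sorted, so sorted positions equal indices

# Gradient of the standard 20th percentile with respect to every entry of x.
grad_hard = jax.grad(lambda v: jnp.quantile(v, 0.2))(x)

# With linear interpolation, only the two neighbouring order statistics
# (sorted positions 19 and 20) receive a nonzero gradient.
print(jnp.nonzero(grad_hard)[0])  # [19 20]
```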

Parameters:

  • inputs (Array) – an Array of any shape.

  • q (Union[float, Array, None]) – the quantile levels to compute, e.g. [0.5] for the median. These values should all lie in [0, 1].

  • axis (Union[int, Tuple[int, ...]]) – the axis (or axes) along which to apply the operator.

  • weight (Union[float, Array, None]) – the weight assigned to each quantile target value in the OT problem. This weight should be small, typically on the order of 1/n, where n is the size of x. Note: the size of q times weight must be strictly smaller than 1, to leave enough mass for the other target values in the transport problem; when needed, the algorithm may enforce this by using a lower weight.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of particular interest are squashing_fun, which maps the values in inputs into [0, 1] (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, the cost function object passed to PointCloud, which defines the 1D ground cost for transporting inputs to the num_targets target values; and epsilon, the entropic regularization parameter. The remaining kwargs parameterize the Sinkhorn solver.
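The mass bookkeeping behind the weight note can be illustrated with a few lines of Python. The function quantile_target_budget and its cap argument are hypothetical: the text above only says that a lower weight may be used when needed, not how that lower value is chosen, so the capping rule here is an assumption made for illustration.

```python
def quantile_target_budget(num_quantiles, weight, cap=0.99):
    """Hypothetical split of the unit OT mass between quantile targets and the rest.

    Each of the `num_quantiles` quantile targets gets `weight`; the remainder is
    left for the other target values. `cap` (< 1) is an assumed bound used only
    for illustration, not the library's actual rule.
    """
    # Lower the weight when num_quantiles * weight would exceed the cap.
    w = min(weight, cap / num_quantiles)
    other_mass = 1.0 - num_quantiles * w  # stays >= 1 - cap > 0
    return w, other_mass


# Two levels (e.g. q = [0.2, 0.8]) and weight 1/n for n = 100 inputs: weight is kept.
print(quantile_target_budget(2, 0.01))
# A weight of 0.7 would need 1.4 units of mass, so it is lowered to 0.99 / 2.
print(quantile_target_budget(2, 0.7))
```

With weight of order 1/n and only a few levels, the cap never triggers; it matters only when the number of levels times weight approaches 1.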

Returns:

An Array with the same shape as inputs, except along axis, which has size q.shape[0] and collects the soft-quantile values.