ott.tools.soft_sort.quantile
- ott.tools.soft_sort.quantile(inputs, q, axis=-1, weight=None, **kwargs)
Apply the soft quantile operator to the input tensor.
For instance:

    x = jax.random.uniform(rng, (100,))  # rng is a jax.random PRNG key
    x_quantiles = quantile(x, q=jnp.array([0.2, 0.8]))

x_quantiles will hold an approximation to the 20th and 80th percentiles in x, computed as a convex combination (a weighted mean, with weights summing to 1) of all values in x (and not, as for standard quantiles, the values x_sorted[20] and x_sorted[80] if x_sorted=jnp.sort(x)). These values offer a trade-off between accuracy (closeness to the true percentiles) and gradient information (the Jacobian of x_quantiles w.r.t. x involves all values listed in x, not just those indexed at 20 and 80).

The non-differentiable version is given by jax.numpy.quantile(), e.g. x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8])).
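To make the "convex combination" description concrete, here is a minimal NumPy sketch of one way a soft quantile can be read off an entropic optimal transport plan. It is an illustration under stated assumptions, not OTT's code: the inputs are squashed into (0, 1), transported onto target points that include one small-mass target at each level in q (plus "filler" targets holding the remaining mass), and each soft quantile is the plan-weighted mean of x. The helper names soft_quantile_sketch and _logsumexp, the target layout, the default epsilon and the iteration count are all hypothetical choices, and a plain log-domain Sinkhorn loop stands in for OTT's Sinkhorn solver.

```python
import numpy as np


def _logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def soft_quantile_sketch(x, q, weight=None, epsilon=5e-3, num_iters=3000):
    """Soft quantiles of a 1D array via entropic OT (hypothetical sketch).

    Not OTT's implementation: target layout, masses and solver are
    simplified assumptions. Returns quantiles for sorted(q).
    """
    x = np.asarray(x, dtype=float)
    q = np.sort(np.atleast_1d(np.asarray(q, dtype=float)))
    n, m = x.size, q.size
    w = 1.0 / n if weight is None else float(weight)

    # Squash inputs into (0, 1): sigmoid of whitened values.
    s = 1.0 / (1.0 + np.exp(-(x - x.mean()) / (x.std() + 1e-10)))

    # Targets in [0, 1]: a point of mass w at each level q_k, interleaved with
    # filler points (midpoints of [0, q_1], ..., [q_m, 1]) holding the rest.
    # Each filler gives up w/2 to every neighbouring quantile target.
    ext = np.concatenate([[0.0], q, [1.0]])
    fill_mass = np.diff(ext) - w
    fill_mass[[0, -1]] += w / 2  # end fillers have a single neighbour
    if np.any(fill_mass <= 0):
        raise ValueError("weight is too large for these quantile levels")
    y = np.empty(2 * m + 1)
    b = np.empty(2 * m + 1)
    y[0::2], b[0::2] = (ext[:-1] + ext[1:]) / 2, fill_mass
    y[1::2], b[1::2] = q, w

    a = np.full(n, 1.0 / n)
    C = (s[:, None] - y[None, :]) ** 2  # 1D squared-Euclidean ground cost

    # Log-domain Sinkhorn; after each g-update the column marginals equal b.
    f, g = np.zeros(n), np.zeros(2 * m + 1)
    for _ in range(num_iters):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - C) / epsilon, axis=0))
    P = np.exp((f[:, None] + g[None, :] - C) / epsilon)

    # Column k of the plan, rescaled by its mass w, gives convex weights:
    # each soft quantile is a weighted mean of *all* entries of x.
    weights = P[:, 1::2].T / w  # shape (m, n); each row sums to 1
    return weights @ x, weights


rng = np.random.default_rng(0)
x = rng.uniform(size=100)
soft, weights = soft_quantile_sketch(x, [0.2, 0.8])
hard = np.quantile(x, [0.2, 0.8])  # non-differentiable reference
```

In this sketch every entry of weights is positive, so each soft quantile depends on all of x; shrinking epsilon concentrates the weights around the sorted positions near 20 and 80 and moves soft toward hard, at the cost of a less spread-out gradient.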
- Parameters:
inputs (Array) – an Array of any shape.
q (Union[float, Array, None]) – values of the quantile levels to be computed, e.g. [0.5] for the median. These values should all lie in \([0,1]\).
axis (Union[int, Tuple[int, ...]]) – the axis on which to apply the operator.
weight (Union[float, Array, None]) – the weight assigned to each quantile target value in the OT problem. This weight should be small, typically of the order of 1/n, where n is the number of values being reduced (the size of x in the example above). Note: since the size of q times weight must be strictly smaller than 1, in order to leave enough mass for the other target values in the transport problem, the algorithm may enforce this by lowering weight when needed.
kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (sigmoid of whitened values by default) before solving the optimal transport problem; cost_fn, the cost function object passed to PointCloud, which defines the ground 1D cost used to transport from inputs to the num_targets target values; and epsilon, the regularization parameter. The remaining kwargs are passed on to parameterize the Sinkhorn solver.
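Two of the parameter behaviours described above can be sketched in a few lines of NumPy. Both functions are hypothetical illustrations, not OTT's API: capped_weight mimics the note on weight (a default of order 1/n, lowered so that len(q) * weight stays strictly below 1, with an assumed slack factor of 0.99), and default_squash mimics the documented default of squashing_fun (a sigmoid applied to whitened values).

```python
import numpy as np


def capped_weight(num_points, num_quantiles, weight=None, slack=0.99):
    """Hypothetical helper (not part of OTT): choose a per-quantile weight.

    Defaults to 1/num_points and is lowered, if needed, so that
    num_quantiles * weight < 1, leaving mass for the other targets.
    """
    if weight is None:
        weight = 1.0 / num_points
    return min(weight, slack / num_quantiles)


def default_squash(values):
    """Sigmoid of whitened values: maps inputs into (0, 1), preserving order."""
    values = np.asarray(values, dtype=float)
    z = (values - values.mean()) / (values.std() + 1e-10)
    return 1.0 / (1.0 + np.exp(-z))


x = np.array([3.0, -1.0, 10.0, 0.5])
s = default_squash(x)
w_default = capped_weight(num_points=100, num_quantiles=2)             # 0.01
w_capped = capped_weight(num_points=100, num_quantiles=4, weight=0.5)  # lowered below 1/4
```

Because the sigmoid is monotone, squashing changes the spacing of the values but not their order, which is what a sorting-based OT formulation relies on.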
- Returns:
An Array, which has the same shape as inputs, except along the axis that is passed, which has size q.shape[0], to collect the soft-quantile values.
- Return type:
Array
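The output layout described in Returns differs from numpy.quantile and jax.numpy.quantile, which place the quantile dimension first. The hypothetical helper below (an assumption for illustration, not part of OTT) computes non-differentiable quantiles with the same layout as the soft operator's output, which can be handy for comparing the two. It handles only an int axis; jnp.quantile and jnp.moveaxis follow the same conventions as their NumPy counterparts.

```python
import numpy as np


def hard_quantile_same_layout(inputs, q, axis=-1):
    """Hypothetical helper: hard quantiles laid out like the soft output.

    np.quantile puts the quantile dimension first, so it is moved back to
    `axis`. Only an int `axis` is supported in this sketch.
    """
    q = np.atleast_1d(np.asarray(q, dtype=float))
    out = np.quantile(inputs, q, axis=axis)  # shape: (len(q),) + reduced shape
    return np.moveaxis(out, 0, axis)


x = np.random.default_rng(0).normal(size=(4, 3, 5))
r1 = hard_quantile_same_layout(x, [0.25, 0.5], axis=1)  # shape (4, 2, 5)
r2 = hard_quantile_same_layout(x, [0.25, 0.5])          # shape (4, 3, 2)
```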