ott.tools.soft_sort.quantile
- ott.tools.soft_sort.quantile(inputs, axis=-1, level=0.5, weight=0.05, **kwargs)
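Apply the soft quantile operator to the input tensor.
For instance, given a JAX PRNG key rng:
    x = jax.random.uniform(rng, (1000,))
    q = quantile(x, level=0.5, weight=0.01)
Then q is computed as a mean over roughly the 10 values of x ranked around the median (0.01 × 1000 = 10). This creates a trade-off between accuracy and gradient. A smaller weight averages fewer values and stays closer to the exact quantile. A larger weight spreads the gradient over more entries of x.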
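The sketch below makes the accuracy/gradient trade-off concrete with a hard analogue of the "mean over the median points" description. window_quantile is a hypothetical helper, not part of ott. It replaces the regularized transport plan with a sharp window of round(weight × n) sorted values centred on level. It therefore shows the arithmetic and where gradients flow, not the soft operator itself. Like quantile, it keeps the reduced axis with size 1.

```python
import jax
import jax.numpy as jnp


def window_quantile(x, level=0.5, weight=0.05, axis=-1):
    """Hypothetical hard analogue of the soft quantile (illustration only).

    Sorts `x` along `axis` and averages the ~weight * n values whose ranks
    are centred on `level`; no optimal transport is involved.
    """
    n = x.shape[axis]
    k = max(1, int(round(weight * n)))           # number of values averaged
    center = int(round(level * (n - 1)))         # rank of the exact quantile
    start = min(max(center - k // 2, 0), n - k)  # keep the window inside [0, n)
    ranked = jnp.sort(x, axis=axis)
    window = jnp.take(ranked, jnp.arange(start, start + k), axis=axis)
    return window.mean(axis=axis, keepdims=True)  # reduced axis kept with size 1


x = jax.random.uniform(jax.random.PRNGKey(0), (1000,))
for w in (0.01, 0.1):
    # Gradient of the (size-1) output with respect to every entry of x.
    grad = jax.grad(lambda v: window_quantile(v, 0.5, w).sum())(x)
    print(w, int((grad != 0).sum()))  # number of entries receiving gradient
```

With weight=0.01 the gradient reaches only 10 entries of x (0.1 each). With weight=0.1 it reaches 100 entries, but the average then also includes values farther from the median. In the actual soft operator, the regularized Sinkhorn transport plan blurs the window edges instead of cutting them sharply.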
- Parameters:
inputs (Array) – a jnp.ndarray<float> of any shape.
axis (int) – the axis along which to apply the operator.
level (float) – the quantile level to compute, e.g. 0.5 for the median.
weight (float) – the weight of the quantile in the transport problem; a larger weight averages over more values of inputs.
kwargs (Any) – keyword arguments passed on to lower-level functions. The most relevant ones are:
squashing_fun, which maps the values in inputs into [0,1] (by default, a sigmoid of the whitened values) before the optimal transport problem is solved;
cost_fn, used in PointCloud, which defines the ground cost for transporting from inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for more details);
epsilon, along with other parameters that shape the sinkhorn algorithm.
- Return type:
Array
- Returns:
A jnp.ndarray with the same shape as inputs, except along the given axis, where the dimension is 1.
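The kwargs description gives the default squashing_fun as a sigmoid of whitened values. The sketch below shows that transform, assuming "whitened" means standardized to zero mean and unit standard deviation along the chosen axis. The library may differ in details such as the small constant guarding against a zero standard deviation. sigmoid_whiten is a hypothetical name, not an ott function. The exact signature ott expects for a custom squashing_fun has not been checked here, so treat this as an illustration of the transform rather than a drop-in argument.

```python
import jax
import jax.numpy as jnp


def sigmoid_whiten(x, axis=-1, eps=1e-10):
    """Hypothetical stand-in for the default squashing: standardise `x`
    along `axis`, then apply a sigmoid so every value lies in (0, 1)."""
    mean = jnp.mean(x, axis=axis, keepdims=True)
    std = jnp.std(x, axis=axis, keepdims=True)
    return jax.nn.sigmoid((x - mean) / (std + eps))


# Inputs on an arbitrary scale: 3 rows of 50 values each.
x = 10.0 * jax.random.normal(jax.random.PRNGKey(1), (3, 50)) + 5.0
z = sigmoid_whiten(x)
# The map is strictly increasing within each row, so ranks along the axis are unchanged.
same_order = bool((jnp.argsort(z, axis=-1) == jnp.argsort(x, axis=-1)).all())
```

The transform is strictly increasing, so it preserves the ordering of inputs, which is what a quantile depends on. It also puts inputs of any positive scale and offset on a common [0, 1] range before the transport problem is solved.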