- ott.tools.soft_sort.quantile(inputs, axis=-1, level=0.5, weight=0.05, **kwargs)
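Apply the soft quantile operator to the input tensor.
For instance, given a PRNG key rng:
    x = jax.random.uniform(rng, (1000,))
    q = quantile(x, level=0.5, weight=0.01)
Then q is computed as a mean over the 10 median points of x (a weight of 0.01 covers 1% of the 1000 inputs). There is therefore a trade-off between accuracy and gradient: a larger weight averages over more points, which spreads the gradient over more inputs but makes q a coarser estimate of the true quantile, while a smaller weight gives a sharper estimate whose gradient reaches fewer inputs.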
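To make the accuracy/gradient trade-off concrete, here is a plain NumPy sketch of the hard intuition behind the example above. It assumes, as a simplification, that with vanishing entropic regularization the operator reduces to averaging the sorted inputs whose ranks fall in a band of relative width weight centred on level. band_quantile is a hypothetical helper, not part of ott, and it ignores the smoothing that makes the real operator differentiable everywhere.

```python
import numpy as np


def band_quantile(x, level=0.5, weight=0.05):
    """Hypothetical helper, not part of ott: a hard analogue of the soft quantile.

    Assumes the zero-regularization limit, in which the output is the mean of
    the round(weight * n) sorted inputs whose ranks are centred on `level`.
    Returns the estimate and the number of points averaged.
    """
    x_sorted = np.sort(np.asarray(x, dtype=float).ravel())
    n = x_sorted.size
    k = max(1, int(round(weight * n)))      # number of points averaged
    start = int(round(level * n - k / 2))   # first rank inside the band
    start = min(max(start, 0), n - k)       # keep the band inside the array
    return x_sorted[start:start + k].mean(), k


# Skewed data, so that averaging more points visibly biases the estimate.
x = np.arange(1000, dtype=float) ** 2
true_median = np.median(x)

for w in (0.01, 0.1):
    q, k = band_quantile(x, level=0.5, weight=w)
    # In this hard analogue, dq/dx_i = 1/k inside the band and 0 outside it.
    print(f"weight={w}: averages {k} points, error={abs(q - true_median):.2f}")
```

With weight=0.01 the band holds 10 points, matching the example above, and lands 8.0 away from the true median of this skewed data; with weight=0.1 it averages 100 points and the error grows to 833.0, while the gradient is spread over ten times as many inputs.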
- Parameters:
  - inputs (Array) – a jnp.ndarray<float> of any shape.
  - axis (int) – the axis along which to apply the operator.
  - level (float) – the quantile level to compute, e.g. 0.5 for the median.
  - weight (float) – the weight of the quantile in the transport problem, i.e. roughly the fraction of inputs averaged to produce the output (0.01 of 1000 points gives 10 points in the example above).
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0,1] (sigmoid of whitened values by default) to solve the optimal transport problem; cost_fn, used in PointCloud, which defines the ground cost function to transport from inputs to the num_targets target values (squared Euclidean distance by default, see pointcloud.py for more details); and the epsilon value as well as other parameters that shape the Sinkhorn algorithm.
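The default squashing_fun is described only as a "sigmoid of whitened values". A minimal NumPy sketch of that description follows, assuming whitening means subtracting the mean and dividing by the standard deviation, with a small constant guarding against zero variance (an assumption of this sketch). squash_default is a hypothetical name, and ott's actual default may differ in such details.

```python
import numpy as np


def squash_default(z, eps=1e-10):
    """Hypothetical sketch of 'sigmoid of whitened values', not ott's code."""
    z = np.asarray(z, dtype=float)
    whitened = (z - z.mean()) / (z.std() + eps)  # zero mean, unit variance
    return 1.0 / (1.0 + np.exp(-whitened))       # logistic sigmoid, into (0, 1)


z = np.array([-50.0, 0.0, 3.0, 1000.0])
s = squash_default(z)
print(s)  # every value lies in (0, 1) and the input order is preserved
```

Because of the whitening step, the squashed values in this sketch do not change under a positive rescaling or shift of the inputs, so the transport problem sees the same [0,1] values regardless of the scale of inputs.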
- Return type: Array
- Returns: a jnp.ndarray with the same shape as inputs, except along the given axis, where the dimension is 1.
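This output shape matches the convention of NumPy's hard np.quantile called with keepdims=True, which can serve as a quick check of shape expectations. The snippet below uses NumPy only, not ott, and computes hard rather than soft quantiles.

```python
import numpy as np

x = np.random.default_rng(0).uniform(size=(3, 5, 7))

# Hard median along axis 1, keeping that axis with size 1.
q_mid = np.quantile(x, 0.5, axis=1, keepdims=True)

# Same along the last axis, mirroring the default axis=-1.
q_last = np.quantile(x, 0.5, axis=-1, keepdims=True)

print(q_mid.shape, q_last.shape)  # (3, 1, 7) (3, 5, 1)
```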