- ott.tools.soft_sort.quantile(inputs, q, axis=-1, weight=None, **kwargs)
Apply the soft quantiles operator to the input tensor.
rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))
x_quantiles = quantile(x, q=jnp.array([0.2, 0.8]))
x_quantiles will hold an approximation to the 20th and 80th percentiles of x, computed as a convex combination (a weighted mean, with weights summing to 1) of all values in x (and not, as for standard quantiles, the values x_sorted[20] and x_sorted[80], where x_sorted = jnp.sort(x)). These values offer a trade-off between accuracy (closeness to the true percentiles) and gradient quality (the Jacobian of x_quantiles with respect to x involves all values in x, not just those indexed at 20 and 80).
The non-differentiable version is given by
x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8]))
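For contrast, the gradient of the standard quantile is sparse. A minimal check with jax.grad (the random key and sample size are arbitrary choices): with jnp.quantile's default linear interpolation, only the one or two entries of x that bracket the requested level receive a nonzero gradient.

```python
import jax
import jax.numpy as jnp

x = jax.random.uniform(jax.random.PRNGKey(0), (100,))

# gradient of the hard 20th percentile with respect to every entry of x
hard_grad = jax.grad(lambda v: jnp.quantile(v, 0.2))(x)

# linear interpolation between two neighbouring order statistics:
# at most two entries get a nonzero gradient, and those weights sum to 1
num_nonzero = int((hard_grad != 0).sum())
```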
- Parameters:
  - inputs (Array) – an Array of any shape.
  - q – the quantile levels to compute, each in \([0,1]\); e.g. jnp.array([0.2, 0.8]) requests the 20th and 80th percentiles.
  - axis – the axis along which the operator is applied (default -1).
  - weight (Union[float, Array, None]) – the weight assigned to each quantile target value in the OT problem. This weight should be small, typically of the order of 1/n, where n is the size of inputs along axis. Note: since the size of q times weight must be strictly smaller than 1, in order to leave enough mass to set the other target values in the transport problem, the algorithm may enforce this by using a lower value when needed.
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in \([0,1]\) (sigmoid of whitened values by default) before solving the optimal transport problem; cost_fn of the PointCloud, which defines the ground 1D cost function used to transport the inputs to the target values; and the epsilon regularization parameter. Remaining kwargs are passed on to parameterize the Sinkhorn solver.
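The default squashing described for squashing_fun, a sigmoid applied to whitened values, can be written in a few lines of JAX. This is an illustrative reimplementation based only on that one-line description; the name default_squash is hypothetical, and OTT's actual whitening may differ in details such as how the standard deviation is computed or guarded against zero.

```python
import jax
import jax.numpy as jnp


def default_squash(x, axis=-1):
    """Hypothetical stand-in for the default squashing_fun."""
    mean = jnp.mean(x, axis=axis, keepdims=True)
    std = jnp.std(x, axis=axis, keepdims=True)
    # whitening centres and rescales each slice; the sigmoid then maps into (0, 1)
    # while preserving the order of the values, which is all the 1D OT problem needs
    return jax.nn.sigmoid((x - mean) / std)


x = 5.0 + 3.0 * jax.random.normal(jax.random.PRNGKey(1), (4, 50))
z = default_squash(x)
```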
- Returns:
  An Array with the same shape as inputs, except along axis, which has size q.shape[0] (one entry per requested quantile level) and holds the soft-quantile values.
- Return type:
  Array
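When checking results, it can help to compute hard quantiles with the same layout. jnp.quantile places the quantile dimension first, so moving it back to the position of axis gives an array whose shape matches the description above. The helper name hard_quantile_like is hypothetical and not part of OTT.

```python
import jax
import jax.numpy as jnp


def hard_quantile_like(inputs, q, axis=-1):
    """Hard quantiles, with the quantile dimension placed where `axis` was."""
    q = jnp.atleast_1d(jnp.asarray(q))
    # jnp.quantile returns shape (len(q),) + inputs.shape with `axis` removed
    out = jnp.quantile(inputs, q, axis=axis)
    # move the quantile dimension back to the position of `axis`
    return jnp.moveaxis(out, 0, axis)


x = jax.random.uniform(jax.random.PRNGKey(0), (8, 100))
q = jnp.array([0.2, 0.8])
last_axis = hard_quantile_like(x, q)            # shape (8, 2)
first_axis = hard_quantile_like(x.T, q, axis=0)  # shape (2, 8)
```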