ott.tools.soft_sort.quantize
- ott.tools.soft_sort.quantize(inputs, num_levels=10, axis=-1, **kwargs)
Soft-quantizes inputs using num_levels values along axis.
Quantization concentrates many values around a few levels, whose number is set by num_levels. The soft quantization operator proposed here performs this concentration in a differentiable way. First, the inputs values are soft-sorted, yielding num_levels values. Second, each value in inputs is replaced by a convex combination of those soft-sorted values, with weights given by the transportation matrix. As the regularization parameter epsilon of regularized optimal transport goes to 0, this operator recovers the expected behavior of quantization, namely each value in inputs is assigned a single level. With epsilon > 0, the behavior is similar but differentiable.
- Parameters:
  - inputs (Array) – the inputs as a jnp.ndarray[batch, dim].
  - num_levels (int) – number of levels available to quantize the signal.
  - axis (int) – axis along which quantization is carried out.
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are: squashing_fun, which redistributes the values in inputs to lie in [0, 1] (a sigmoid of whitened values by default) before solving the optimal transport problem; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_levels target values (squared Euclidean distance by default, see pointcloud.py for more details); epsilon; and other parameters that shape the sinkhorn algorithm.
- Return type:
  Array
- Returns:
  A jnp.ndarray of the same shape as inputs.
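The two steps of the operator can be written compactly as follows, for a 1-D input x with n entries and m = num_levels levels. This is a sketch under assumptions not stated in this docstring: uniform weights a_i = 1/n on the inputs and b_j = 1/m on the levels, target points y_1, ..., y_m in [0, 1], and a squashing function σ and ground cost c as configured through squashing_fun and cost_fn.

```latex
\begin{aligned}
P^{\varepsilon} &= \operatorname*{arg\,min}_{P \ge 0,\; P\mathbf{1}_m = a,\; P^{\top}\mathbf{1}_n = b}
  \sum_{i,j} P_{ij}\, c\big(\sigma(x)_i, y_j\big) + \varepsilon \sum_{i,j} P_{ij} \log P_{ij},\\
s_j &= \frac{1}{b_j} \sum_{i=1}^{n} P^{\varepsilon}_{ij}\, x_i, \qquad j = 1, \dots, m,\\
q_i &= \frac{1}{a_i} \sum_{j=1}^{m} P^{\varepsilon}_{ij}\, s_j, \qquad i = 1, \dots, n.
\end{aligned}
```

Here s holds the soft-sorted levels and q is the quantized output. Because each row of P^ε sums to a_i, every q_i is a convex combination of the levels s_j; as ε goes to 0, P^ε approaches an unregularized optimal transport plan, recovering hard quantization.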
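To make the two-step mechanism concrete, here is a minimal JAX sketch for a 1-D input. It is not OTT's implementation: the name soft_quantize_sketch is hypothetical, axis and batching are ignored, and it assumes uniform weights, num_levels targets evenly spaced in [0, 1], and a fixed number of log-domain Sinkhorn iterations in place of OTT's solver. The squashing (sigmoid of whitened values) and squared Euclidean cost follow the defaults stated above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_quantize_sketch(x, num_levels=10, epsilon=1e-2, num_iters=1000):
    """Hypothetical 1-D sketch of soft quantization via entropic OT (not OTT's code)."""
    n = x.shape[0]
    # Assumption: uniform weights on the n inputs (a) and on the levels (b), in log space.
    log_a = jnp.full((n,), -jnp.log(n))
    log_b = jnp.full((num_levels,), -jnp.log(num_levels))

    # Squash inputs into [0, 1] with a sigmoid of whitened values.
    z = jax.nn.sigmoid((x - x.mean()) / (x.std() + 1e-10))
    # Assumption: num_levels target points evenly spaced in [0, 1].
    y = jnp.linspace(0.0, 1.0, num_levels)
    # Squared Euclidean ground cost between squashed inputs and targets.
    cost = (z[:, None] - y[None, :]) ** 2

    def sinkhorn_step(_, potentials):
        # Log-domain Sinkhorn: alternately enforce the row marginal a and column marginal b.
        f, g = potentials
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros(n, dtype=x.dtype), jnp.zeros(num_levels, dtype=x.dtype))
    f, g = jax.lax.fori_loop(0, num_iters, sinkhorn_step, init)
    # Transportation matrix of shape (n, num_levels).
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)

    # Step 1: soft-sorted levels, each a weighted average of the inputs.
    levels = (plan.T @ x) / plan.sum(axis=0)
    # Step 2: each input becomes a convex combination of the levels.
    return (plan @ levels) / plan.sum(axis=1)
```

For example, with x = jnp.arange(8.0), num_levels=4 and a small epsilon such as 5e-3, neighbouring pairs of inputs should end up sharing approximately the same output value; a larger epsilon blends adjacent levels and gives a smoother, more informative gradient. Dividing by the plan's actual row and column sums, rather than by a and b, keeps both steps exact weighted averages even if Sinkhorn has not fully converged.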