quantize(inputs, num_levels=10, axis=-1, **kwargs)

Soft-quantizes an input using num_levels values along axis.

The quantization operator concentrates the values of a signal around a few predefined levels, num_levels of them here. The soft quantization operator proposed here does so through a soft concentration that is differentiable. The input values are first soft-sorted, producing num_levels representative values. In a second step, each input value is replaced by a convex combination of those soft-sorted values, with weights given by the transportation matrix. As the regularization parameter epsilon of regularized optimal transport goes to 0, this operator recovers the expected behavior of quantization, namely each value in inputs is assigned a single level. When epsilon>0 the behavior is similar but differentiable.
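The two-step construction above (squash, solve regularized OT, then average twice through the transportation matrix) can be sketched with a minimal, self-contained Sinkhorn loop. This is an illustrative assumption-laden sketch, not the library's implementation: the function name soft_quantize, the evenly spaced targets in [0, 1], and the epsilon and iteration defaults are all choices made for the example.

```python
import jax
import jax.numpy as jnp

def soft_quantize(x, num_levels=10, epsilon=1e-2, n_iter=200):
    """Sketch of differentiable quantization via entropy-regularized OT.

    Hypothetical illustration only; names and defaults are assumptions.
    """
    n = x.shape[0]
    # Squashing: sigmoid of whitened values, mapping inputs into [0, 1]
    # (mirrors the default squashing_fun described in the docstring).
    z = jax.nn.sigmoid((x - x.mean()) / (x.std() + 1e-12))
    # num_levels target locations, evenly spaced in [0, 1] (an assumption).
    y = jnp.linspace(0.0, 1.0, num_levels)
    # Ground cost: squared Euclidean distance between squashed inputs and targets.
    cost = (z[:, None] - y[None, :]) ** 2
    kernel = jnp.exp(-cost / epsilon)
    a = jnp.full(n, 1.0 / n)                    # uniform mass on inputs
    b = jnp.full(num_levels, 1.0 / num_levels)  # uniform mass on levels
    u, v = jnp.ones(n), jnp.ones(num_levels)
    for _ in range(n_iter):                     # Sinkhorn fixed-point updates
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    plan = u[:, None] * kernel * v[None, :]     # transportation matrix
    # Step 1: soft-sorted levels = barycenters of the inputs sent to each target.
    levels = (plan.T @ x) / plan.sum(axis=0)
    # Step 2: each output is a convex combination of those levels.
    return (plan @ levels) / plan.sum(axis=1)
```

With a small epsilon the outputs cluster tightly around num_levels values, while the whole map stays differentiable, so jax.grad can be taken through it.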

Parameters:

  • inputs (Array) – the inputs as a jnp.ndarray[batch, dim].

  • num_levels (int) – number of levels available to quantize the signal.

  • axis (int) – axis along which quantization is carried out.

  • kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0, 1] (sigmoid of whitened values by default) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting from inputs to the num_levels target values (squared Euclidean distance by default); and epsilon, along with other parameters of the Sinkhorn algorithm.

Returns:

A jnp.ndarray of the same size as inputs.