The tools package contains high-level functions that build on outputs produced by core functions. They can be used to compute Sinkhorn divergences [Séjourné et al., 2019], instantiate transport matrices, provide differentiable approximations to ranks and quantile functions [Cuturi et al., 2019], etc.

Segmented Sinkhorn

segment_sinkhorn.segment_sinkhorn(x, y[, ...])

Compute regularized OT cost between subsets of vectors in x and y.
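
As a rough illustration of comparing subsets, the sketch below groups points by integer segment ids and runs one log-domain Sinkhorn solve per segment, with uniform weights and squared Euclidean cost. The names `segment_costs`, `segment_ids_x`, `segment_ids_y`, `eps` and `num_iters` are hypothetical rather than the library's signature. For simplicity the sketch reports the linear transport cost \(\langle P, C\rangle\) of the entropic plan, which is not necessarily the exact regularized quantity the library returns.

```python
import math
from collections import defaultdict


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sq_euclidean(p, q):
    return sum((pi - qi) ** 2 for pi, qi in zip(p, q))


def entropic_transport_cost(x, y, eps=0.01, num_iters=200):
    """Log-domain Sinkhorn between uniform measures on x and y; returns <P, C>."""
    n, m = len(x), len(y)
    cost = [[sq_euclidean(xi, yj) for yj in y] for xi in x]
    log_a, log_b = math.log(1.0 / n), math.log(1.0 / m)
    f, g = [0.0] * n, [0.0] * m
    for _ in range(num_iters):
        # Alternate updates of the dual potentials f and g so that the plan's
        # row sums match a and its column sums match b.
        f = [eps * log_a - eps * logsumexp([(g[j] - cost[i][j]) / eps for j in range(m)])
             for i in range(n)]
        g = [eps * log_b - eps * logsumexp([(f[i] - cost[i][j]) / eps for i in range(n)])
             for j in range(m)]
    # Entropic plan P_ij = exp((f_i + g_j - C_ij) / eps); return sum_ij P_ij * C_ij.
    return sum(math.exp((f[i] + g[j] - cost[i][j]) / eps) * cost[i][j]
               for i in range(n) for j in range(m))


def segment_costs(x, segment_ids_x, y, segment_ids_y, eps=0.01, num_iters=200):
    """Hypothetical helper: one OT cost per segment id present in both x and y."""
    groups_x, groups_y = defaultdict(list), defaultdict(list)
    for point, sid in zip(x, segment_ids_x):
        groups_x[sid].append(point)
    for point, sid in zip(y, segment_ids_y):
        groups_y[sid].append(point)
    return {sid: entropic_transport_cost(groups_x[sid], groups_y[sid], eps, num_iters)
            for sid in sorted(groups_x) if sid in groups_y}


x = [(0.0,), (1.0,), (5.0,)]
y = [(0.0,), (1.0,), (7.0,)]
costs = segment_costs(x, [0, 0, 1], y, [0, 0, 1])
# Segment 0 pairs identical point sets (cost close to 0); segment 1 pairs 5.0 with 7.0 (cost 4).
```

Grouping first and solving each segment separately keeps every Sinkhorn problem small; the library presumably batches these solves instead of looping, but that is not shown here.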

Plotting

plot.Plot([fig, ax, threshold, scale, ...])

Plot an optimal transport map between two point clouds.

Sinkhorn Divergence


Compute Sinkhorn divergence defined by a geometry, weights, and parameters.

sinkhorn_divergence.segment_sinkhorn_divergence(x, y)

Compute Sinkhorn divergence between subsets of vectors in x and y.
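
Both entries above compute a debiased entropic cost. For balanced measures the Sinkhorn divergence is commonly written \(S_\varepsilon(\alpha, \beta) = \mathrm{OT}_\varepsilon(\alpha, \beta) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\alpha, \alpha) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\beta, \beta)\), which vanishes when \(\alpha = \beta\). The plain-Python sketch below evaluates this formula for uniformly weighted point clouds, using the Sinkhorn dual value \(\langle f, a\rangle + \langle g, b\rangle\) as \(\mathrm{OT}_\varepsilon\). The function names and defaults are hypothetical, and the library's exact conventions may differ.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def reg_ot_cost(x, y, eps=1.0, num_iters=500):
    """Sinkhorn dual value <f, a> + <g, b> between uniform point clouds x and y.

    At convergence this differs from the entropic primal objective by a constant
    that does not depend on x or y, so it cancels in the divergence below.
    """
    n, m = len(x), len(y)
    cost = [[sum((p - q) ** 2 for p, q in zip(xi, yj)) for yj in y] for xi in x]
    f, g = [0.0] * n, [0.0] * m
    for _ in range(num_iters):
        f = [eps * math.log(1.0 / n) - eps * logsumexp([(g[j] - cost[i][j]) / eps for j in range(m)])
             for i in range(n)]
        g = [eps * math.log(1.0 / m) - eps * logsumexp([(f[i] - cost[i][j]) / eps for i in range(n)])
             for j in range(m)]
    return sum(f) / n + sum(g) / m


def sinkhorn_divergence(x, y, eps=1.0, num_iters=500):
    """Debiased cost: OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2."""
    return (reg_ot_cost(x, y, eps, num_iters)
            - 0.5 * reg_ot_cost(x, x, eps, num_iters)
            - 0.5 * reg_ot_cost(y, y, eps, num_iters))


x = [(0.0,), (1.0,)]
y = [(2.0,), (3.0,)]  # x translated by 2
same = sinkhorn_divergence(x, x)     # vanishes for identical inputs
shifted = sinkhorn_divergence(x, y)  # approximately 4.0
```

With squared Euclidean cost, translating a cloud only adds terms of the form \(u_i + v_j\) to the cost matrix, which shift the dual potentials without changing the plan. The divergence between a uniformly weighted cloud and its translate therefore equals the squared length of the translation for every \(\varepsilon\), which is why the second value is close to 4.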

Soft Sorting Algorithms


Returns multivariate CDF and quantile maps, given input samples.

soft_sort.quantile(inputs, q[, axis, weight])

Apply the soft quantile operator to the input tensor.

soft_sort.quantile_normalization(inputs, targets)

Re-normalize inputs so that their quantiles match those of targets/weights.

soft_sort.quantize(inputs[, num_levels, axis])

Soft-quantize the input to num_levels values along axis.

soft_sort.ranks(inputs[, axis, num_targets, ...])

Apply the soft rank operator to the input tensor.

soft_sort.sort(inputs[, axis, topk, num_targets])

Apply the soft sort operator on a given axis of the input.

soft_sort.sort_with(inputs, criterion[, topk])

Sort a multidimensional array according to a real valued criterion.

soft_sort.topk_mask(inputs[, axis, k])

Soft \(\text{top-}k\) selection mask.
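
Following the optimal transport view of sorting in [Cuturi et al., 2019], soft ranks, soft sorted values and a soft \(\text{top-}k\) mask can all be read off an entropic transport plan from the inputs to a small set of increasing targets. The sketch below makes several simplifying assumptions that are not taken from the library: inputs are min-max rescaled to [0, 1] (so they must not all be equal), targets are an even grid on [0, 1] for sorting and the two points 0 and 1 for top-k, inputs carry uniform weights, and the helper names and defaults are hypothetical.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn_plan(cost, a, b, eps, num_iters):
    """Log-domain Sinkhorn; returns the entropic plan P as nested lists."""
    n, m = len(a), len(b)
    f, g = [0.0] * n, [0.0] * m
    for _ in range(num_iters):
        f = [eps * math.log(a[i]) - eps * logsumexp([(g[j] - cost[i][j]) / eps for j in range(m)])
             for i in range(n)]
        g = [eps * math.log(b[j]) - eps * logsumexp([(f[i] - cost[i][j]) / eps for i in range(n)])
             for j in range(m)]
    return [[math.exp((f[i] + g[j] - cost[i][j]) / eps) for j in range(m)] for i in range(n)]


def _rescale(inputs):
    lo, hi = min(inputs), max(inputs)
    return [(v - lo) / (hi - lo) for v in inputs]  # fails if all inputs are equal


def soft_sort_and_ranks(inputs, eps=1e-2, num_iters=200):
    """Soft sorted values (ascending) and soft 0-based ranks."""
    n = len(inputs)
    targets = [j / (n - 1) for j in range(n)]
    cost = [[(s - t) ** 2 for t in targets] for s in _rescale(inputs)]
    w = [1.0 / n] * n
    plan = sinkhorn_plan(cost, w, w, eps, num_iters)
    # Value at sorted position j: inputs averaged with the weights in column j of the plan.
    sorted_vals = [n * sum(plan[i][j] * inputs[i] for i in range(n)) for j in range(n)]
    # Rank of input i: expected index of the target it is sent to.
    ranks = [n * sum(plan[i][j] * j for j in range(n)) for i in range(n)]
    return sorted_vals, ranks


def soft_topk_mask(inputs, k, eps=1e-2, num_iters=200):
    """Soft membership in the top k, for 0 < k < len(inputs)."""
    n = len(inputs)
    cost = [[(s - t) ** 2 for t in (0.0, 1.0)] for s in _rescale(inputs)]
    # Target 1 receives exactly mass k / n, i.e. room for k of the n inputs.
    plan = sinkhorn_plan(cost, [1.0 / n] * n, [(n - k) / n, k / n], eps, num_iters)
    return [n * plan[i][1] for i in range(n)]


sorted_vals, ranks = soft_sort_and_ranks([3.0, 1.0, 2.0])  # close to [1, 2, 3] and [2, 0, 1]
mask = soft_topk_mask([4.0, 1.0, 2.0], k=1)                 # close to [1, 0, 0]
```

A smaller `eps` pushes the plan toward a permutation, giving nearly hard sorts, ranks and masks; a larger `eps` blurs them. Because the plan's column sums are fixed to the target weights, the soft ranks still sum to \(0 + 1 + \dots + (n - 1)\) and the soft mask still sums to \(k\) at any `eps`.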

Clustering

k_means.k_means(geom, k[, weights, init, ...])

K-means clustering using Lloyd's algorithm [Lloyd, 1982].
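
Lloyd's algorithm alternates two steps until the assignment stops changing: assign each point to its nearest centroid, then move each centroid to the mean of its assigned points. The plain-Python sketch below shows only that loop, for unweighted points, squared Euclidean distances and caller-supplied initial centroids. The library's entry point also takes a geometry, `k`, optional `weights` and an `init` argument, which this sketch does not model, and `lloyd_kmeans` is a hypothetical name.

```python
def sq_dist(p, q):
    return sum((pi - qi) ** 2 for pi, qi in zip(p, q))


def lloyd_kmeans(points, init_centroids, max_iters=100):
    """Returns (centroids, assignment) after Lloyd iterations from init_centroids."""
    centroids = [tuple(c) for c in init_centroids]
    assignment = None
    for _ in range(max_iters):
        # Assignment step: index of the closest centroid for every point.
        new_assignment = [min(range(len(centroids)), key=lambda c: sq_dist(p, centroids[c]))
                          for p in points]
        if new_assignment == assignment:
            break  # nothing changed, so the centroids would not move either
        assignment = new_assignment
        # Update step: each centroid moves to the mean of its assigned points.
        for c in range(len(centroids)):
            members = [p for p, label in zip(points, assignment) if label == c]
            if members:  # an empty cluster keeps its previous centroid
                centroids[c] = tuple(sum(coords) / len(members) for coords in zip(*members))
    return centroids, assignment


points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
centroids, assignment = lloyd_kmeans(points, init_centroids=[(0.0, 0.0), (1.0, 0.0)])
# assignment == [0, 0, 0, 1, 1, 1]; centroids are (1/3, 1/3) and (31/3, 31/3)
```

Both initial centroids start inside the same cluster here, yet two Lloyd rounds separate the groups; with less favorable data a poor start can end in a worse local optimum, which is why seeding strategies such as K-means++ matter.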

k_means.KMeansOutput(centroids, assignment, ...)

Output of the k_means() algorithm.

Gaussian mixture package

This package implements various tools to manipulate Gaussian mixtures with a slightly modified Wasserstein geometry: here a Gaussian mixture is no longer strictly regarded as a density on \(\mathbb{R}^d\), but instead as a point cloud in the space of Gaussians in \(\mathbb{R}^d\). This viewpoint provides a new approach to compare and fit Gaussian mixtures, as described for instance in [Delon and Desolneux, 2020] and references therein.
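
To make the "point cloud of Gaussians" viewpoint concrete: the squared 2-Wasserstein distance between two Gaussians with diagonal covariances has the closed form \(W_2^2 = \|m_1 - m_2\|^2 + \sum_i (\sigma_{1,i} - \sigma_{2,i})^2\), and a distance between mixtures is obtained by solving a discrete OT problem whose ground cost is this quantity between components, with the mixture weights as marginals, as in [Delon and Desolneux, 2020]. The sketch restricts itself to diagonal covariances and to mixtures with the same number of equally weighted components, where an optimal coupling can be taken to be a permutation, so brute force over permutations suffices. The function names are hypothetical.

```python
import itertools


def gaussian_w2_sq(mean1, std1, mean2, std2):
    """Squared W2 between N(mean1, diag(std1**2)) and N(mean2, diag(std2**2))."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mean1, mean2))
    cov_term = sum((s - t) ** 2 for s, t in zip(std1, std2))
    return mean_term + cov_term


def mixture_w2_sq(components1, components2):
    """Squared mixture distance between two mixtures of k equally weighted Gaussians.

    Each component is a (mean, std) pair. With uniform weights and equal sizes an
    optimal coupling is a permutation, so brute force is enough for small k.
    """
    k = len(components1)
    assert k == len(components2), "this sketch needs the same number of components"
    cost = [[gaussian_w2_sq(*c1, *c2) for c2 in components2] for c1 in components1]
    return min(sum(cost[i][perm[i]] for i in range(k)) / k
               for perm in itertools.permutations(range(k)))


mix_a = [((0.0,), (1.0,)), ((5.0,), (1.0,))]
mix_b = [((5.0,), (1.0,)), ((0.0,), (1.0,))]  # same components, listed in another order
mix_c = [((1.0,), (2.0,)), ((6.0,), (1.0,))]
d_ab = mixture_w2_sq(mix_a, mix_b)  # 0.0
d_ac = mixture_w2_sq(mix_a, mix_c)  # 1.5: component costs 2 and 1 under the best matching
```

Because components are compared as whole Gaussians, reordering them changes nothing, while moving or widening a component shows up directly in the cost.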

Gaussian Mixtures

gaussian.Gaussian(loc, scale)

Normal distribution.

gaussian_mixture.GaussianMixture(loc, ...)

Gaussian Mixture model.


Coupled pair of Gaussian mixture models.

fit_gmm.initialize(rng, points, ...[, ...])

Initialize a GMM via K-means++ with retries on failure.

fit_gmm.fit_model_em(gmm, points, ...[, ...])

Fit a GMM using the EM algorithm.

fit_gmm_pair.get_fit_model_em_fn(...[, ...])

Get a function that performs penalized EM.
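
The EM fit listed above can be illustrated with the textbook EM iteration for a one-dimensional mixture: the E-step computes each component's posterior responsibility for every point, and the M-step re-estimates weights, means and variances as responsibility-weighted averages. This is a sketch under simplifying assumptions (1-D data, caller-supplied starting parameters, a variance floor `min_var`, a fixed iteration count) with a hypothetical function name. It does not reproduce the library's K-means++ initialization with retries or the penalized EM used for pairs of mixtures.

```python
import math


def fit_gmm_1d_em(points, means, variances, weights, num_iters=50, min_var=1e-6):
    """Plain EM for a 1-D Gaussian mixture, started from the given parameters."""
    means, variances, weights = list(means), list(variances), list(weights)
    k, n = len(means), len(points)
    for _ in range(num_iters):
        # E-step: responsibilities resp[i][c] = p(component c | point i), in log space.
        resp = []
        for x in points:
            logs = [math.log(weights[c]) - 0.5 * math.log(2.0 * math.pi * variances[c])
                    - 0.5 * (x - means[c]) ** 2 / variances[c] for c in range(k)]
            top = max(logs)
            norm = top + math.log(sum(math.exp(v - top) for v in logs))
            resp.append([math.exp(v - norm) for v in logs])
        # M-step: responsibility-weighted maximum-likelihood updates.
        for c in range(k):
            mass = sum(r[c] for r in resp)
            weights[c] = mass / n
            means[c] = sum(r[c] * x for r, x in zip(resp, points)) / mass
            spread = sum(r[c] * (x - means[c]) ** 2 for r, x in zip(resp, points)) / mass
            variances[c] = max(spread, min_var)  # floor keeps a component from collapsing
    return means, variances, weights


points = [-1.0, -0.5, 0.0, 0.5, 1.0, 9.0, 9.5, 10.0, 10.5, 11.0]
means, variances, weights = fit_gmm_1d_em(points, means=[1.0, 8.0],
                                          variances=[1.0, 1.0], weights=[0.5, 0.5])
# means close to [0, 10], variances close to [0.5, 0.5], weights close to [0.5, 0.5]
```

Computing responsibilities in log space avoids underflow when a point is far from every component, which would otherwise make the normalizing sum zero.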