The tools package contains high-level functions that build on the outputs of core functions. They can be used to compute Sinkhorn divergences [Séjourné et al., 2019], instantiate transport matrices, and provide differentiable approximations to ranks and quantile functions [Cuturi et al., 2019], among other things.

Segmented Sinkhorn

segment_sinkhorn.segment_sinkhorn(x, y[, ...])

Compute the regularized OT cost between subsets of vectors in x and y.
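To illustrate what "one cost per subset" means, here is a minimal NumPy sketch. It is not OTT's implementation or API. Points are grouped by hypothetical integer arrays `segment_ids_x` and `segment_ids_y`, and an entropic OT cost is computed independently for each segment with a log-domain Sinkhorn loop. Several things are assumptions: uniform weights, squared Euclidean cost, a fixed `epsilon`, and reporting the dual objective <f, a> + <g, b> as the "regularized OT cost". OTT's own convention may differ by a constant.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def reg_ot_cost(x, y, epsilon=1.0, num_iters=1000):
    """Entropic OT cost between two uniformly weighted point clouds."""
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(num_iters):
        # Log-domain Sinkhorn updates of the dual potentials f and g.
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    # Dual objective <f, a> + <g, b>, used here as the regularized cost (a convention).
    return f @ a + g @ b


def segment_reg_ot_costs(x, y, segment_ids_x, segment_ids_y, num_segments):
    """Hypothetical helper: one regularized OT cost per segment id."""
    return np.array([
        reg_ot_cost(x[segment_ids_x == s], y[segment_ids_y == s])
        for s in range(num_segments)
    ])


# Two segments: in segment 1, the y-points are segment 0's shifted by t = (0.5, 0).
p = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
x = np.concatenate([p, p])
y = np.concatenate([p, p + np.array([0.5, 0.0])])
ids = np.repeat([0, 1], len(p))
costs = segment_reg_ot_costs(x, y, ids, ids, num_segments=2)
```

With squared Euclidean cost, translating one cloud by t leaves the entropic plan unchanged and adds |t|^2 to the dual objective. Here `costs[1] - costs[0]` should therefore come out very close to 0.25. The Python loop over segments is chosen for readability only.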

Plotting

plot.Plot([fig, ax, cost_threshold, scale, ...])

Plot an optimal transport map between two point clouds.

Sinkhorn Divergence


Compute the Sinkhorn divergence defined by a geometry, weights, and parameters.

sinkhorn_divergence.segment_sinkhorn_divergence(x, y)

Compute the Sinkhorn divergence between subsets of vectors in x and y.
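The Sinkhorn divergence debiases the entropic OT cost:

S(a, b) = OT(a, b) - (OT(a, a) + OT(b, b)) / 2

so that S(a, a) = 0. Below is a minimal NumPy sketch of that formula, not OTT's `sinkhorn_divergence`. It makes several assumptions: uniformly weighted point clouds, squared Euclidean cost, and a log-domain Sinkhorn loop whose dual objective <f, a> + <g, b> stands in for OT. Any constant offset in that convention cancels in the divergence.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def reg_ot_cost(x, y, epsilon=1.0, num_iters=1000):
    """Entropic OT cost (dual objective) between uniform point clouds."""
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(num_iters):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return f @ a + g @ b


def sinkhorn_divergence(x, y, epsilon=1.0):
    # Debiased cost: OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2.
    return (reg_ot_cost(x, y, epsilon)
            - 0.5 * reg_ot_cost(x, x, epsilon)
            - 0.5 * reg_ot_cost(y, y, epsilon))


pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
d_same = sinkhorn_divergence(pts, pts)
d_shift = sinkhorn_divergence(pts, pts + np.array([0.5, 0.0]))
```

Here `d_same` is exactly zero. Under squared Euclidean cost, shifting a cloud by t raises OT(x, x + t) by |t|^2 over OT(x, x), so `d_shift` is close to 0.25.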

Soft Sorting Algorithms

soft_sort.quantile(inputs[, axis, level, weight])

Apply the soft quantile operator to the input tensor.

soft_sort.quantile_normalization(inputs, targets)

Renormalize inputs so that their quantiles match those of targets/weights.

soft_sort.quantize(inputs[, num_levels, axis])

Soft-quantize inputs using num_levels values along axis.

soft_sort.ranks(inputs[, axis, num_targets])

Apply the soft rank operator to the input tensor.

soft_sort.sort(inputs[, axis, topk, num_targets])

Apply the soft sort operator along a given axis of the input.

soft_sort.sort_with(inputs, criterion[, topk])

Sort a multidimensional array according to a real-valued criterion.
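Soft ranks and soft sorts in the spirit of [Cuturi et al., 2019] can both be read off one entropic transport plan between the n inputs and n sorted anchor points. The NumPy sketch below shows this for a 1-D input and is not OTT's `soft_sort` API. It assumes evenly spaced anchors in [0, 1], min-max rescaling of the inputs (OTT squashes inputs differently), 0-based ranks, and uniform weights.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def _soft_sort_plan(values, epsilon, num_iters):
    """Entropic OT plan from the n inputs to n sorted anchors in [0, 1]."""
    n = len(values)
    # Assumption: min-max rescaling (inputs must not all be equal).
    x = (values - values.min()) / (values.max() - values.min())
    anchors = np.linspace(0.0, 1.0, n)
    cost = (x[:, None] - anchors[None, :]) ** 2
    log_w = np.full(n, -np.log(n))  # uniform weights on both sides
    f, g = np.zeros(n), np.zeros(n)
    for _ in range(num_iters):
        f = epsilon * (log_w - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_w - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def soft_ranks(inputs, epsilon=1e-2, num_iters=500):
    values = np.asarray(inputs, dtype=float)
    n = len(values)
    plan = _soft_sort_plan(values, epsilon, num_iters)
    # Row i of n * plan is (approximately) a distribution over sorted slots;
    # its expected slot index is a soft, 0-based rank of input i.
    return n * plan @ np.arange(n)


def soft_sort(inputs, epsilon=1e-2, num_iters=500):
    values = np.asarray(inputs, dtype=float)
    n = len(values)
    plan = _soft_sort_plan(values, epsilon, num_iters)
    # Column j of n * plan sums to 1; averaging the original inputs with
    # those weights gives the j-th soft sorted value.
    return n * plan.T @ values
```

For evenly spaced inputs such as `[2.0, 0.0, 3.0, 1.0]`, `soft_ranks` is very close to the hard ranks `[2, 0, 3, 1]` and `soft_sort` is very close to `[0, 1, 2, 3]`. A large `epsilon` pulls every sorted value toward the mean of the inputs. Intermediate values trade fidelity for smoother gradients.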

Clustering

k_means.k_means(geom, k[, weights, init, ...])

K-means clustering using Lloyd's algorithm [Lloyd, 1982].

k_means.KMeansOutput(centroids, assignment, ...)

Output of the k_means() algorithm.
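As a rough illustration of Lloyd's algorithm, the sketch below alternates an assignment step and a centroid-update step on raw points, starting from user-supplied centroids. It is not OTT's `k_means`, which takes a geometry and its own initialization options. The function name `lloyd_k_means` is hypothetical. The `(centroids, assignment, inertia)` return loosely mirrors the first fields listed for `KMeansOutput`, and `inertia` is included only for illustration.

```python
import numpy as np


def lloyd_k_means(points, init_centroids, max_iters=100, tol=1e-8):
    points = np.asarray(points, dtype=float)
    centroids = np.array(init_centroids, dtype=float)
    k = len(centroids)
    for _ in range(max_iters):
        # Assignment step: nearest centroid in squared Euclidean distance.
        sq_dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        assignment = sq_dists.argmin(axis=1)
        # Update step: each centroid moves to the mean of its points
        # (an empty cluster keeps its previous centroid).
        new_centroids = np.array([
            points[assignment == j].mean(axis=0) if np.any(assignment == j) else centroids[j]
            for j in range(k)
        ])
        converged = np.abs(new_centroids - centroids).max() < tol
        centroids = new_centroids
        if converged:
            break
    # Final assignment and inertia (sum of squared distances to assigned centroids).
    sq_dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    assignment = sq_dists.argmin(axis=1)
    inertia = sq_dists[np.arange(len(points)), assignment].sum()
    return centroids, assignment, inertia


pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
centroids, assignment, inertia = lloyd_k_means(pts, pts[[0, 1]])
```

Even though both initial centroids lie in the same group, two rounds of updates separate the clusters. The centroids end near (1/3, 1/3) and (31/3, 31/3), and the inertia is 8/3. Lloyd's algorithm only guarantees a local optimum, which is why seeding schemes such as k-means++ matter.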

Gaussian Mixtures

gaussian.Gaussian(loc, scale)

Normal distribution.

gaussian_mixture.GaussianMixture(loc, ...)

Gaussian Mixture model.


Coupled pair of Gaussian mixture models.

fit_gmm.initialize(rng, points, ...[, ...])

Initialize a GMM via K-means++ with retries on failure.
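The idea of K-means++ initialization with retries can be sketched in NumPy as follows. K-means++ seeding picks the initial means, each point is assigned to its nearest seed to get mixture weights and diagonal variances, and the procedure is retried when a cluster ends up too small. This is not OTT's `fit_gmm.initialize`. The names `kmeans_plus_plus` and `init_gmm_diag` are hypothetical, and so are the diagonal-covariance choice, the `min_points` notion of "failure", and the `var_floor`.

```python
import numpy as np


def kmeans_plus_plus(points, k, rng):
    """k-means++ seeding: returns k rows of `points` as initial means."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    # First seed: uniformly at random.
    seeds = [points[rng.integers(n)]]
    # Squared distance from each point to its closest seed so far.
    d2 = ((points - seeds[0]) ** 2).sum(axis=1)
    for _ in range(1, k):
        # Sample the next seed with probability proportional to d2; points that
        # are already seeds have d2 == 0 and cannot be picked again.
        idx = rng.choice(n, p=d2 / d2.sum())
        seeds.append(points[idx])
        d2 = np.minimum(d2, ((points - points[idx]) ** 2).sum(axis=1))
    return np.stack(seeds)


def init_gmm_diag(points, k, rng, max_tries=5, min_points=2, var_floor=1e-6):
    """Hypothetical GMM initializer: (means, diagonal variances, weights)."""
    points = np.asarray(points, dtype=float)
    for _ in range(max_tries):
        means = kmeans_plus_plus(points, k, rng)
        # Hard-assign each point to its nearest seed.
        labels = ((points[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1).argmin(axis=1)
        counts = np.bincount(labels, minlength=k)
        if counts.min() < min_points:
            continue  # treat tiny clusters as a failure and reseed
        weights = counts / len(points)
        variances = np.stack(
            [points[labels == j].var(axis=0) + var_floor for j in range(k)]
        )
        return means, variances, weights
    raise RuntimeError("GMM initialization failed after max_tries attempts")
```

Usage: `init_gmm_diag(points, 2, np.random.default_rng(0))` returns means of shape `(2, d)`, variances of the same shape, and weights that sum to 1.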

fit_gmm.fit_model_em(gmm, points, ...[, ...])

Fit a GMM using the EM algorithm.
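The EM iterations can be sketched in one dimension with NumPy. The E-step computes each component's responsibility for each point, and the M-step re-estimates weights, means, and variances in closed form. This simplification is not OTT's `fit_model_em`, which operates on a GaussianMixture object. The function name `fit_gmm_em_1d`, the fixed iteration count, and the `min_var` floor are assumptions.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def fit_gmm_em_1d(points, means, variances, weights, num_iters=100, min_var=1e-6):
    """EM for a 1-D Gaussian mixture; returns (means, variances, weights, loglik)."""
    x = np.asarray(points, dtype=float)
    mu = np.array(means, dtype=float)
    var = np.array(variances, dtype=float)
    w = np.array(weights, dtype=float)
    for _ in range(num_iters):
        # E-step: log w_k + log N(x_n | mu_k, var_k), shape (num_points, k).
        log_joint = np.log(w) - 0.5 * (np.log(2.0 * np.pi * var) + (x[:, None] - mu) ** 2 / var)
        log_px = _logsumexp(log_joint, axis=1)      # log p(x_n) under current params
        resp = np.exp(log_joint - log_px[:, None])  # responsibilities, rows sum to 1
        # M-step: closed-form updates weighted by the responsibilities.
        nk = resp.sum(axis=0)
        w = nk / len(x)
        mu = resp.T @ x / nk
        var = np.maximum((resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk, min_var)
    # Log-likelihood as evaluated in the last E-step (before the final M-step).
    return mu, var, w, log_px.sum()


mu, var, w, loglik = fit_gmm_em_1d([-6.0, -5.0, -4.0, 4.0, 5.0, 6.0],
                                   means=[-1.0, 1.0], variances=[1.0, 1.0],
                                   weights=[0.5, 0.5])
```

On these two well-separated groups, EM recovers means close to -5 and 5, variances close to 2/3, and equal weights. Each EM iteration does not decrease the log-likelihood, but EM can still stop at a poor local optimum if the initialization is bad.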

fit_gmm_pair.get_fit_model_em_fn(...[, ...])

Get a function that performs penalized EM.