ott.tools package

The tools package contains high-level functions that build on outputs produced by core functions. They can be used to compute Sinkhorn divergences [Séjourné et al., 2019], instantiate transport matrices, provide differentiable approximations to ranks and quantile functions [Cuturi et al., 2019], and more.

Segmented Sinkhorn

segment_sinkhorn.segment_sinkhorn(x, y[, ...])

Compute reg_ot_cost between subsets of vectors described in x and y.
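To make the idea concrete, here is a minimal NumPy sketch (not the OTT-JAX API) that splits x and y by integer segment ids and solves one entropic OT problem per segment with log-domain Sinkhorn iterations. The names `entropic_ot_cost` and `segment_reg_ot_cost`, the uniform weights, the squared Euclidean cost, and the choice to report the dual objective value as the regularized cost are all assumptions made for illustration.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def entropic_ot_cost(x, y, eps=1.0, num_iters=500):
    """Entropic OT cost between uniform point clouds x (n, d) and y (m, d)."""
    a = np.full(x.shape[0], 1.0 / x.shape[0])
    b = np.full(y.shape[0], 1.0 / y.shape[0])
    # Squared Euclidean cost matrix of shape (n, m).
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(num_iters):
        # Log-domain Sinkhorn: match row sums to a, then column sums to b.
        f = eps * np.log(a) - eps * _logsumexp((g[None, :] - cost) / eps, axis=1)
        g = eps * np.log(b) - eps * _logsumexp((f[:, None] - cost) / eps, axis=0)
    # After the g-update the plan has unit mass, so the dual objective is <f, a> + <g, b>.
    return float(f @ a + g @ b)


def segment_reg_ot_cost(x, y, segment_ids_x, segment_ids_y, num_segments, eps=1.0):
    """Hypothetical helper: one entropic OT cost per segment id in range(num_segments)."""
    return np.array([
        entropic_ot_cost(x[segment_ids_x == s], y[segment_ids_y == s], eps=eps)
        for s in range(num_segments)
    ])
```

A Python loop over segments keeps the sketch short; a JAX implementation would more naturally batch the per-segment problems (for example by padding them to a common size and vectorizing), but that is a design guess rather than a description of the library.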

Sinkhorn Divergence

sinkhorn_divergence.sinkhorn_divergence(...)

Compute the Sinkhorn divergence defined by a geometry, weights, and parameters.

sinkhorn_divergence.segment_sinkhorn_divergence(x, y)

Compute the Sinkhorn divergence between subsets of vectors given in x and y.
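A Sinkhorn divergence debiases the entropic OT cost by subtracting the two self-transport terms, S(x, y) = OT(x, y) - (OT(x, x) + OT(y, y)) / 2, so that S(x, x) vanishes. The NumPy sketch below assumes uniform weights and a squared Euclidean cost; `entropic_ot_cost`, `sinkhorn_divergence_sketch`, and `segment_divergence_sketch` are hypothetical names, and the segmented version simply applies the same computation to each segment.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def entropic_ot_cost(x, y, eps=1.0, num_iters=1000):
    """Entropic OT cost (dual value) between uniform clouds x (n, d) and y (m, d)."""
    a = np.full(x.shape[0], 1.0 / x.shape[0])
    b = np.full(y.shape[0], 1.0 / y.shape[0])
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(num_iters):
        f = eps * np.log(a) - eps * _logsumexp((g[None, :] - cost) / eps, axis=1)
        g = eps * np.log(b) - eps * _logsumexp((f[:, None] - cost) / eps, axis=0)
    return float(f @ a + g @ b)


def sinkhorn_divergence_sketch(x, y, eps=1.0):
    # Cross term minus the average of the two self-transport terms.
    return entropic_ot_cost(x, y, eps) - 0.5 * (
        entropic_ot_cost(x, x, eps) + entropic_ot_cost(y, y, eps))


def segment_divergence_sketch(x, y, ids_x, ids_y, num_segments, eps=1.0):
    # Hypothetical segmented variant: one divergence per segment id.
    return np.array([
        sinkhorn_divergence_sketch(x[ids_x == s], y[ids_y == s], eps)
        for s in range(num_segments)
    ])
```

With a squared Euclidean cost, translating a cloud by a vector t adds exactly ||t||² to the converged cross term (the remaining change is separable and cancels because both clouds have the same mean offset), while the self terms are unchanged. The divergence between x and x + t should therefore come out close to ||t||².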

Soft Sorting Algorithms

soft_sort.quantile(inputs[, axis, level, weight])

Apply the soft quantile operator on the input tensor.
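One way to obtain a soft quantile, in the spirit of [Cuturi et al., 2019], is to transport the n inputs onto three sorted targets whose middle target carries a small mass `weight` centered on `level`, then read off the average of the inputs that reach that middle target. The sketch below is plain NumPy; the min-max rescaling of inputs to [0, 1], the target positions, and the defaults are assumptions, and `soft_quantile_sketch` is a hypothetical name rather than the library function.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def soft_quantile_sketch(inputs, level=0.5, weight=0.05, eps=1e-2, num_iters=2000):
    """Soft quantile of a 1-D array via entropic OT onto three sorted targets."""
    x = np.asarray(inputs, dtype=float)
    n = x.shape[0]
    # Assumption: min-max rescaling so inputs and targets share the [0, 1] scale.
    z = (x - x.min()) / (x.max() - x.min())
    targets = np.array([0.0, 0.5, 1.0])
    a = np.full(n, 1.0 / n)
    # The middle target receives mass `weight`, centered on the requested `level`.
    b = np.array([level - weight / 2, weight, 1.0 - level - weight / 2])
    cost = (z[:, None] - targets[None, :]) ** 2
    f, g = np.zeros(n), np.zeros(3)
    for _ in range(num_iters):
        f = eps * np.log(a) - eps * _logsumexp((g[None, :] - cost) / eps, axis=1)
        g = eps * np.log(b) - eps * _logsumexp((f[:, None] - cost) / eps, axis=0)
    plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
    # Average of the input values transported to the middle target.
    return float(plan[:, 1] @ x / plan[:, 1].sum())
```

For the target weights to be valid, `level - weight / 2` and `1 - level - weight / 2` must both stay positive.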

soft_sort.quantile_normalization(inputs, targets)

Renormalize inputs so that its quantiles match those of targets/weights.

soft_sort.quantize(inputs[, num_levels, axis])

Soft-quantize an input using num_levels values along axis.

soft_sort.ranks(inputs[, axis, num_targets])

Apply the soft rank operator to the input tensor.

soft_sort.sort(inputs[, axis, topk, num_targets])

Apply the soft sort operator on a given axis of the input.
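Soft ranks and soft sorting can both be read off one entropic transport plan P between the n inputs and n sorted targets with uniform weights: dividing P^T x by the target weights gives a soft sorted vector, and dividing P applied to the target indices 0..n-1 by the input weights gives soft (0-based) ranks. As eps shrinks, P approaches a permutation and both outputs approach their hard counterparts. The NumPy sketch below uses assumed choices (min-max rescaling, evenly spaced targets, squared cost); `soft_sort_and_ranks_sketch` is a hypothetical name, and the library's `sort` and `ranks` may differ, for example in indexing convention or in how inputs are squashed.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def soft_sort_and_ranks_sketch(inputs, eps=1e-2, num_iters=2000):
    """Soft sorted values and soft 0-based ranks of a 1-D array."""
    x = np.asarray(inputs, dtype=float)
    n = x.shape[0]
    # Assumption: min-max rescaling onto the same [0, 1] range as the targets.
    z = (x - x.min()) / (x.max() - x.min())
    targets = np.linspace(0.0, 1.0, n)  # sorted targets, one per output slot
    a = np.full(n, 1.0 / n)  # input weights
    b = np.full(n, 1.0 / n)  # target weights
    cost = (z[:, None] - targets[None, :]) ** 2
    f, g = np.zeros(n), np.zeros(n)
    for _ in range(num_iters):
        f = eps * np.log(a) - eps * _logsumexp((g[None, :] - cost) / eps, axis=1)
        g = eps * np.log(b) - eps * _logsumexp((f[:, None] - cost) / eps, axis=0)
    plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
    # Soft sort: average input value transported to each sorted target slot.
    sorted_x = (plan.T @ x) / b
    # Soft ranks: average target index that each input is transported to.
    ranks = (plan @ np.arange(n)) / a
    return sorted_x, ranks
```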

soft_sort.sort_with(inputs, criterion[, topk])

Sort a multidimensional array according to a real valued criterion.

Clustering

k_means.k_means(geom, k[, weights, init, ...])

K-means clustering using Lloyd's algorithm [Lloyd, 1982].
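Lloyd's algorithm alternates two steps until the assignment stops changing: assign each point to its nearest centroid, then move each centroid to the (weighted) mean of its points. The sketch below works on a raw (n, d) NumPy array instead of an OTT geometry, initializes from random data points unless `init` is given (an assumption; the library offers its own initializers), and returns a (centroids, assignment) pair mirroring two of the KMeansOutput fields. `lloyd_k_means` is a hypothetical name.

```python
import numpy as np


def lloyd_k_means(points, k, weights=None, init=None, num_iters=100, seed=0):
    """Weighted Lloyd iterations; returns (centroids, assignment)."""
    x = np.asarray(points, dtype=float)
    n = x.shape[0]
    w = np.ones(n) if weights is None else np.asarray(weights, dtype=float)
    if init is None:
        # Assumption: k distinct random data points as initial centroids.
        rng = np.random.default_rng(seed)
        init = x[rng.choice(n, size=k, replace=False)]
    centroids = np.array(init, dtype=float)
    assignment = np.full(n, -1)
    for _ in range(num_iters):
        # Assignment step: nearest centroid in squared Euclidean distance.
        d2 = np.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        new_assignment = np.argmin(d2, axis=1)
        if np.array_equal(new_assignment, assignment):
            break  # converged: assignments no longer change
        assignment = new_assignment
        # Update step: weighted mean of assigned points; empty clusters keep their centroid.
        for j in range(k):
            mask = assignment == j
            if w[mask].sum() > 0:
                centroids[j] = np.average(x[mask], axis=0, weights=w[mask])
    return centroids, assignment
```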

k_means.KMeansOutput(centroids, assignment, ...)

Output of the k_means() algorithm.

ott.tools.gaussian_mixture package

Gaussian Mixtures

gaussian.Gaussian(loc, scale)

Pytree for a normal distribution.

gaussian_mixture.GaussianMixture(loc, ...)

Pytree for a Gaussian mixture model.

gaussian_mixture_pair.GaussianMixturePair(...)

Pytree for a coupled pair of Gaussian mixture models.

fit_gmm.initialize(key, points, ...[, ...])

Initialize a GMM via K-means++ with retries on failure.
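K-means++ seeding picks the first center uniformly at random and each further center with probability proportional to its squared distance to the nearest center already chosen. The NumPy sketch below uses those seeds as component means and derives mixture weights and covariances from the induced hard assignment. `kmeans_plus_plus` and `init_gmm_sketch` are hypothetical names; the sketch does not implement retries and instead raises an error when a component receives fewer than two points, which is only an assumed failure condition.

```python
import numpy as np


def kmeans_plus_plus(points, k, rng):
    """D^2 seeding: sample each new center proportionally to squared distance."""
    n = points.shape[0]
    centers = [points[rng.integers(n)]]
    for _ in range(1, k):
        c = np.array(centers)
        # Squared distance from each point to its nearest chosen center.
        d2 = np.min(np.sum((points[:, None, :] - c[None, :, :]) ** 2, axis=-1), axis=1)
        centers.append(points[rng.choice(n, p=d2 / d2.sum())])
    return np.array(centers)


def init_gmm_sketch(points, k, seed=0, reg=1e-6):
    """Hypothetical helper: (weights, means, covariances) from k-means++ seeds."""
    x = np.asarray(points, dtype=float)
    rng = np.random.default_rng(seed)
    means = kmeans_plus_plus(x, k, rng)
    labels = np.argmin(np.sum((x[:, None, :] - means[None, :, :]) ** 2, axis=-1), axis=1)
    counts = np.bincount(labels, minlength=k)
    if np.any(counts < 2):
        raise ValueError("a component received fewer than two points; try another seed")
    weights = counts / x.shape[0]
    dim = x.shape[1]
    # Per-component covariance of assigned points, with a small ridge for stability.
    covs = np.stack([np.cov(x[labels == j].T, bias=True) + reg * np.eye(dim)
                     for j in range(k)])
    return weights, means, covs
```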

fit_gmm.fit_model_em(gmm, points, ...[, ...])

Fit a GMM using the EM algorithm.
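EM for a Gaussian mixture alternates an E-step (posterior responsibilities of each component for each point) and an M-step (closed-form weighted updates of the mixture weights, means, and covariances); each iteration cannot decrease the data log-likelihood. The sketch below is unpenalized, full-covariance EM on plain NumPy arrays with a fixed iteration count; it does not reproduce the library's GMM pytree, any convergence checks, or the penalized variant from fit_gmm_pair, and `fit_gmm_em_sketch` is a hypothetical name.

```python
import numpy as np


def _log_gauss(x, mean, cov):
    """Log density of N(mean, cov) at each row of x."""
    d = x.shape[1]
    diff = x - mean
    _, logdet = np.linalg.slogdet(cov)
    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)


def fit_gmm_em_sketch(points, weights, means, covs, num_iters=50):
    """Plain EM; returns updated (weights, means, covs) and per-iteration log-likelihoods."""
    x = np.asarray(points, dtype=float)
    w, mu, cov = (np.array(v, dtype=float) for v in (weights, means, covs))
    k = w.shape[0]
    log_liks = []
    for _ in range(num_iters):
        # E-step: log p(x_i, component j), normalized into responsibilities.
        log_joint = np.stack(
            [np.log(w[j]) + _log_gauss(x, mu[j], cov[j]) for j in range(k)], axis=1)
        m = log_joint.max(axis=1, keepdims=True)
        log_norm = m[:, 0] + np.log(np.exp(log_joint - m).sum(axis=1))
        resp = np.exp(log_joint - log_norm[:, None])
        log_liks.append(float(log_norm.sum()))  # log-likelihood of current params
        # M-step: closed-form maximizers of the expected complete-data log-likelihood.
        nk = resp.sum(axis=0)
        w = nk / x.shape[0]
        mu = (resp.T @ x) / nk[:, None]
        for j in range(k):
            diff = x - mu[j]
            cov[j] = (resp[:, j, None] * diff).T @ diff / nk[j]
    return w, mu, cov, log_liks
```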

fit_gmm_pair.get_fit_model_em_fn(...[, ...])

Get a function that performs penalized EM.