ott.geometry#

This package implements several classes to define a geometry, arguably the most influential ingredient of an optimal transport problem. In full generality, a Geometry defines source points (input measure), target points (target measure) and a ground cost function (resp. a positive kernel function) that quantifies how expensive (resp. easy) it is to displace a unit of mass from any of the input points to any of the target points.
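To make the cost/kernel duality concrete, here is a minimal pure-Python sketch. It does not use the OTT API: the helper names, the squared Euclidean cost and the value of `epsilon` are illustrative assumptions. It builds a rectangular cost matrix between a few source and target points, together with the Gibbs kernel exp(-C / epsilon) commonly associated with such a cost in entropic optimal transport.

```python
import math

def sq_euclidean(x, y):
    """Squared Euclidean distance between two points given as sequences."""
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))

def cost_matrix(xs, ys, cost_fn=sq_euclidean):
    """Rectangular matrix C[i][j] = cost of moving mass from xs[i] to ys[j]."""
    return [[cost_fn(x, y) for y in ys] for x in xs]

def kernel_matrix(cost, epsilon):
    """Gibbs kernel K = exp(-C / epsilon): a high cost means a low affinity."""
    return [[math.exp(-c / epsilon) for c in row] for row in cost]

source = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)]  # 3 source points
target = [(0.0, 0.0), (1.0, 1.0)]               # 2 target points

C = cost_matrix(source, target)   # 3 x 2 cost matrix
K = kernel_matrix(C, epsilon=1.0)  # 3 x 2 positive kernel matrix
```

Entries of `K` lie in (0, 1], and the entry for the coinciding pair of points equals 1, which is the "easy to displace" reading of a kernel.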

The geometry package provides a few simple geometries. The simplest is the one in which input and target points coincide, so that the geometry between them reduces to a symmetric cost or kernel matrix. In the particular case where these points lie on a grid (in full generality a Cartesian product, e.g., a 2- or 3-dimensional grid), the Grid geometry will prove useful.
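One reason grids deserve special treatment is that a cost written as a sum of per-axis costs yields a kernel that factorizes across axes, so the full matrix over all grid points never has to be stored. The pure-Python sketch below checks this identity on a tiny 3 x 2 grid. It is an illustration under assumed names and values, not a description of how the Grid class is implemented.

```python
import itertools
import math

# Per-axis coordinates; the grid is their Cartesian product (3 x 2 = 6 points).
axes = [[0.0, 0.5, 1.0], [0.0, 1.0]]
points = list(itertools.product(*axes))

def axis_cost(a, b):
    """Cost along a single axis (squared difference, chosen for illustration)."""
    return (a - b) ** 2

def full_cost(p, q):
    """Separable cost: sum of the per-axis costs."""
    return sum(axis_cost(a, b) for a, b in zip(p, q))

epsilon = 0.5
p, q = points[1], points[4]  # (0.0, 1.0) and (1.0, 0.0)

# Kernel value computed from the full cost ...
k_full = math.exp(-full_cost(p, q) / epsilon)
# ... matches the product of the small per-axis kernel values.
k_factored = math.prod(
    math.exp(-axis_cost(a, b) / epsilon) for a, b in zip(p, q)
)
```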

For more general settings where input/target points do not coincide, one can alternatively instantiate a Geometry through a rectangular cost matrix.

However, in applications it is often preferable to define ground costs “symbolically”: rather than storing a cost matrix, one lists the points in the input/target point clouds and specifies a cost function between them directly. Such functions should follow the CostFn class description. We provide a few standard cost functions that are meaningful in an OT context, notably the (unbalanced, regularized) Bures distances between Gaussians [Janati et al., 2020]. That cost can be used, for instance, to compute a distance between Gaussian mixtures, as proposed in [Chen et al., 2019] and revisited in [Delon and Desolneux, 2020].
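The idea of a symbolic geometry can be sketched in a few lines of plain Python. The class `LazyPointCloud` and the `manhattan` cost below are hypothetical names introduced for illustration only; they do not reproduce the CostFn or PointCloud interfaces. The point is that the object stores the two point clouds and a callable, and evaluates costs on demand instead of holding an n x m matrix.

```python
class LazyPointCloud:
    """Hypothetical stand-in: stores points and a cost callable, not a matrix."""

    def __init__(self, x, y, cost_fn):
        self.x, self.y, self.cost_fn = x, y, cost_fn

    @property
    def shape(self):
        """Number of source and target points."""
        return (len(self.x), len(self.y))

    def cost_row(self, i):
        """Costs from source point i to every target point, computed on demand."""
        return [self.cost_fn(self.x[i], yj) for yj in self.y]

def manhattan(x, y):
    """Example ground cost: L1 distance between two points."""
    return sum(abs(a - b) for a, b in zip(x, y))

geom = LazyPointCloud(
    x=[(0.0, 0.0), (2.0, 1.0)],
    y=[(1.0, 1.0), (0.0, 3.0), (2.0, 2.0)],
    cost_fn=manhattan,
)
row = geom.cost_row(1)  # costs from (2.0, 1.0) to the three target points
```

Swapping `manhattan` for another callable changes the geometry without touching the stored points, which is the practical appeal of the symbolic description.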

To be useful with Sinkhorn solvers, geometries typically need to provide an epsilon regularization parameter. This value can either be set once and for all, or follow an annealing schedule implemented with the Epsilon scheduler.
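As a rough illustration of what annealing means here, the sketch below decays epsilon geometrically from a large initial value down to a target value. The function name, its parameters and the exact formula are assumptions made for this example; the Epsilon class may parameterize its schedule differently.

```python
def annealed_epsilon(iteration, target, init, decay):
    """Geometric decay from `init` towards `target`, never going below `target`."""
    return max(target, init * decay ** iteration)

# Start at 1.0, halve at every iteration, and stop decreasing at 0.01.
schedule = [
    annealed_epsilon(t, target=0.01, init=1.0, decay=0.5) for t in range(10)
]
```

With `decay=1.0` and `init` equal to `target`, the same function returns a constant, which corresponds to setting epsilon once and for all.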

Geometries#

geometry.Geometry([cost_matrix, ...])

Base class to define ground costs/kernels used in optimal transport.

pointcloud.PointCloud(x[, y, cost_fn, ...])

Defines geometry for 2 point clouds (possibly 1 vs itself).

grid.Grid([x, grid_size, cost_fns, num_a, ...])

Class describing the geometry of points taken in a Cartesian product.

graph.Graph(laplacian[, t, n_steps, ...])

Graph distance approximation using heat kernel [Crane et al., 2013, Heitz et al., 2021].

geodesic.Geodesic(scaled_laplacian, eigval, ...)

Graph distance approximation using heat kernel [Huguet et al., 2023].

low_rank.LRCGeometry(cost_1, cost_2[, bias, ...])

Geometry whose cost is defined by the product of two low-rank matrices.

low_rank.LRKGeometry(k1, k2[, epsilon])

Low-rank kernel geometry.

epsilon_scheduler.Epsilon([target, ...])

Scheduler class for the regularization parameter epsilon.

Cost Functions#

costs.CostFn()

Base class for all costs.

costs.SqPNorm(p)

Squared p-norm of the difference of two vectors.

costs.PNormP(p)

p-norm to the power p (and divided by p) of the difference of two vectors.

costs.SqEuclidean()

Squared Euclidean distance.

costs.Euclidean()

Euclidean distance.

costs.Cosine([ridge])

Cosine distance cost function.

costs.Arccos(n[, ridge])

Arc-cosine cost function [Cho and Saul, 2009].

costs.Bures(dimension[, sqrtm_kw])

Bures distance between a pair of (mean, covariance matrix).

costs.UnbalancedBures(dimension, *[, sigma, ...])

Unbalanced Bures distance between two triplets of (mass, mean, cov).

costs.ElasticL1([scaling_reg, matrix, ...])

Cost inspired by elastic net [Zou and Hastie, 2005] regularization.

costs.ElasticL2([scaling_reg, matrix, ...])

Cost with L2 regularization.

costs.ElasticSTVS([scaling_reg, matrix, ...])

Cost with soft thresholding operator with vanishing shrinkage (STVS) [Schreck et al., 2016] regularization.

costs.ElasticSqKOverlap(k, *args, **kwargs)

Cost with squared k-overlap norm regularization [Argyriou et al., 2012].

costs.SoftDTW(gamma[, ground_cost, debiased])

Soft dynamic time warping (DTW) cost [Cuturi and Blondel, 2017].

distrib_costs.UnivariateWasserstein([...])

1D Wasserstein cost for two 1D distributions.
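For reference, a few of the simpler costs listed above can be written out in plain Python following their one-line descriptions. This is a sketch under assumptions: the function names are illustrative, the `ridge` term of the cosine cost is omitted, and any scaling conventions specific to the library classes are not reproduced.

```python
import math

def sq_p_norm(x, y, p):
    """Squared p-norm of the difference x - y."""
    norm = sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)
    return norm ** 2

def p_norm_p(x, y, p):
    """p-norm of x - y raised to the power p, then divided by p."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) / p

def cosine_distance(x, y):
    """One minus the cosine similarity (no ridge term in this sketch)."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / (norm_x * norm_y)

x, y = (3.0, 0.0), (0.0, 4.0)  # difference is (3, -4)
values = {
    "sq_2_norm": sq_p_norm(x, y, p=2),      # squared Euclidean distance
    "1_norm_1": p_norm_p(x, y, p=1),        # plain L1 distance
    "2_norm_2": p_norm_p(x, y, p=2),        # half the squared Euclidean distance
    "cosine": cosine_distance(x, y),        # orthogonal vectors
}
```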

Utilities#

segment.segment_point_cloud(x[, a, ...])

Segment and pad as needed the entries of a point cloud.
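The segment-and-pad idea can be sketched as follows. The function `segment_and_pad` and all of its parameters are hypothetical and do not mirror the signature of `segment_point_cloud`. The sketch splits one flat point cloud into several smaller ones according to segment ids, then pads each to a common size with zero-weight points so that all segments share the same shape, which is what batched, compiled code generally requires.

```python
def segment_and_pad(points, segment_ids, num_segments, max_size, pad_value=0.0):
    """Split `points` by segment id and pad each segment to `max_size` points.

    Returns (padded_segments, weights). Real points in a segment get uniform
    weights; padding points get weight 0 so they carry no mass.
    """
    dim = len(points[0])
    segments = [[] for _ in range(num_segments)]
    for point, seg in zip(points, segment_ids):
        segments[seg].append(tuple(point))

    padded, weights = [], []
    for seg_points in segments:
        n = len(seg_points)
        if n > max_size:
            raise ValueError("segment larger than max_size")
        n_pad = max_size - n
        padded.append(seg_points + [(pad_value,) * dim] * n_pad)
        w = 1.0 / n if n else 0.0  # guard against empty segments
        weights.append([w] * n + [0.0] * n_pad)
    return padded, weights

points = [(0.0, 0.0), (1.0, 0.0), (2.0, 2.0), (3.0, 1.0), (4.0, 4.0)]
segment_ids = [0, 0, 1, 1, 1]
padded, weights = segment_and_pad(points, segment_ids, num_segments=2, max_size=3)
```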