ott.geometry package

OTT ground geometries: Classes and cost functions to instantiate them.

This package implements several classes to define a geometry, arguably the most influential ingredient of an optimal transport problem. In its full generality, a Geometry defines source points (the input measure), target points (the target measure), and a ground cost function (resp. a positive kernel function) that quantifies how expensive (resp. easy) it is to move a unit of mass from any input point to any target point.
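As an illustration of these ingredients, the sketch below builds a cost matrix C[i, j] = c(x_i, y_j) from two small point clouds and the corresponding kernel K = exp(-C / epsilon). It uses plain NumPy and hypothetical helper names (`sq_euclidean`, `cost_matrix`, `kernel_matrix`), not the ott.geometry API. The exponential (Gibbs) form of the kernel is the usual entropic-OT convention and is assumed here.

```python
import numpy as np

def sq_euclidean(x, y):
    """Squared Euclidean ground cost between two vectors (hypothetical helper)."""
    return float(np.sum((x - y) ** 2))

def cost_matrix(x, y, cost_fn):
    """C[i, j] = cost_fn(x[i], y[j]) for source points x (n, d), target points y (m, d)."""
    return np.array([[cost_fn(xi, yj) for yj in y] for xi in x])

def kernel_matrix(cost, epsilon):
    """Assumed Gibbs kernel: close to 1 where moving mass is cheap, near 0 where costly."""
    return np.exp(-cost / epsilon)

x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # 3 source points
y = np.array([[0.0, 0.0], [1.0, 1.0]])              # 2 target points
C = cost_matrix(x, y, sq_euclidean)                 # C == [[0, 2], [1, 1], [1, 1]]
K = kernel_matrix(C, epsilon=0.5)
```

The cost and the kernel carry the same information for a fixed epsilon: a low cost corresponds to a kernel value close to 1.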

The geometry package provides a few simple geometries. The simplest is the one in which input and target points coincide, so that the geometry between them reduces to a symmetric cost or kernel matrix. In the particular case where these points lie on a grid (in full generality, a cartesian product, e.g. a 2D or 3D grid), the Grid geometry will prove useful.
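To illustrate what a grid structure can buy: when the cost between two grid points is a sum of per-axis costs (squared Euclidean is assumed below), the kernel on the full grid is the Kronecker product of small per-axis kernels, so it can be applied to a vector without forming the full matrix. The NumPy sketch below checks this on a hypothetical 3 x 4 grid; the names are illustrative, and this is not the Grid class's implementation.

```python
import numpy as np

# Hypothetical 2-D grid: the cartesian product of two 1-D axes (3 x 4 = 12 points).
ax1 = np.linspace(0.0, 1.0, 3)
ax2 = np.linspace(0.0, 1.0, 4)
eps = 0.1

def axis_kernel(t, eps):
    """Gibbs kernel of squared distances along a single axis."""
    return np.exp(-((t[:, None] - t[None, :]) ** 2) / eps)

# Dense kernel on all grid points, in row-major order: index = i * len(ax2) + j.
pts = np.array([(s, t) for s in ax1 for t in ax2])
C_full = ((pts[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1)
K_full = np.exp(-C_full / eps)

# Squared Euclidean cost is a sum of per-axis costs, so the kernel is a Kronecker product.
K1, K2 = axis_kernel(ax1, eps), axis_kernel(ax2, eps)
K_kron = np.kron(K1, K2)

# Apply the 12 x 12 kernel to a vector using only the 3 x 3 and 4 x 4 axis kernels.
v = np.arange(12.0)
Kv = (K1 @ v.reshape(3, 4) @ K2.T).ravel()
```

For a grid with many points per axis, the per-axis products cost far less than a dense kernel-vector product over all grid points.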

In more general settings, where input and target points do not coincide, one can instead instantiate a Geometry from a rectangular cost matrix.
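A minimal sketch of a geometry defined directly by a rectangular n x m cost matrix, assuming the entropic kernel K = exp(-C / epsilon). The class name `CostMatrixGeometry` and its methods are hypothetical. They show that, with n != m, a Sinkhorn-type solver needs two distinct products: K v (one entry per source point) and K^T u (one entry per target point).

```python
import numpy as np

class CostMatrixGeometry:
    """Hypothetical minimal geometry built from an (n, m) cost matrix; n != m is allowed."""

    def __init__(self, cost_matrix, epsilon=0.1):
        self.cost_matrix = np.asarray(cost_matrix, dtype=float)
        self.epsilon = epsilon

    @property
    def shape(self):
        return self.cost_matrix.shape

    @property
    def kernel_matrix(self):
        return np.exp(-self.cost_matrix / self.epsilon)

    def apply_kernel(self, v, transpose=False):
        """K @ v (result of length n) or K.T @ v (result of length m)."""
        K = self.kernel_matrix
        return K.T @ v if transpose else K @ v

# 3 source points vs. 4 target points: the cost matrix is rectangular.
geom = CostMatrixGeometry([[0.0, 1.0, 2.0, 3.0],
                           [1.0, 0.0, 1.0, 2.0],
                           [2.0, 1.0, 0.0, 1.0]], epsilon=1.0)
Kv = geom.apply_kernel(np.ones(4))                   # one value per source point
KTu = geom.apply_kernel(np.ones(3), transpose=True)  # one value per target point
```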

In applications, however, it is often preferable to define ground costs “symbolically”: rather than a cost matrix, one lists the points of the input and target point clouds and specifies a cost function between them. Such functions should follow the CostFn class interface. We provide a few standard cost functions that are meaningful in an OT context, notably the (unbalanced, regularized) Bures distances between Gaussians [Janati et al., 2020]. That cost can be used, for instance, to compute a distance between Gaussian mixtures, as proposed in [Chen et al., 2019] and revisited in [Delon and Desolneux, 2020].
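The sketch below assumes a hypothetical `CostFn` interface with a `__call__(x, y)` method and an `all_pairs` helper. This is not necessarily the interface of ott's CostFn, and `SqEuclidean` and `Cosine` here are illustrative stand-ins. It shows two standard costs and one practical benefit of a symbolic description: the kernel-vector product K v can be computed one block of rows at a time from the point clouds, without storing the full n x m cost matrix.

```python
import numpy as np

class CostFn:
    """Hypothetical minimal interface: a cost between two vectors x and y."""

    def __call__(self, x, y):
        raise NotImplementedError

    def all_pairs(self, X, Y):
        """Cost matrix between point clouds X (n, d) and Y (m, d)."""
        return np.array([[self(x, y) for y in Y] for x in X])


class SqEuclidean(CostFn):
    def __call__(self, x, y):
        return float(np.sum((x - y) ** 2))


class Cosine(CostFn):
    def __call__(self, x, y):
        cos_sim = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
        return float(1.0 - cos_sim)


def apply_kernel_online(X, Y, cost_fn, v, epsilon, chunk=2):
    """K @ v with K = exp(-C / epsilon), computed one block of rows at a time,
    so the full n x m cost matrix is never stored."""
    blocks = []
    for start in range(0, len(X), chunk):
        C_block = cost_fn.all_pairs(X[start:start + chunk], Y)
        blocks.append(np.exp(-C_block / epsilon) @ v)
    return np.concatenate(blocks)


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))   # 5 source points in 3-D
Y = rng.normal(size=(4, 3))   # 4 target points in 3-D
v = np.ones(4)
Kv = apply_kernel_online(X, Y, SqEuclidean(), v, epsilon=1.0)
```

The block size trades memory for the number of passes over the data. Only `chunk` rows of the cost matrix exist at any time.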

To be usable with Sinkhorn solvers, Geometries typically need to provide an epsilon regularization parameter. That value can either be set once and for all, or be driven by an annealing Epsilon scheduler.
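The sketch below compares both options. It assumes a log-domain Sinkhorn loop written in NumPy and a hypothetical geometric-decay schedule epsilon_t = max(target, init * decay**t). The real Epsilon scheduler's parameters and formula may differ. Because the loop updates the dual potentials f and g rather than kernel scalings, epsilon can change from one iteration to the next without rescaling anything.

```python
import numpy as np

def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    zmax = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(zmax, axis=axis) + np.log(np.sum(np.exp(z - zmax), axis=axis))

def epsilon_at(it, target=0.1, init=1.0, decay=0.9):
    """Hypothetical annealing schedule: geometric decay from init, floored at target."""
    return max(target, init * decay ** it)

def sinkhorn(C, a, b, schedule, n_iters=1000):
    """Log-domain Sinkhorn on cost C with marginals a, b; epsilon = schedule(iteration)."""
    f = np.zeros(C.shape[0])
    g = np.zeros(C.shape[1])
    for it in range(n_iters):
        eps = schedule(it)
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / eps)  # transport plan at the final epsilon
    return P, eps

x = np.linspace(0.0, 1.0, 5)
y = np.linspace(0.0, 1.0, 4) ** 2
C = (x[:, None] - y[None, :]) ** 2     # rectangular 5 x 4 squared Euclidean cost
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)

P_fixed, _ = sinkhorn(C, a, b, schedule=lambda it: 0.1)    # epsilon set once and for all
P_annealed, eps_final = sinkhorn(C, a, b, schedule=epsilon_at)
```

A common motivation for annealing is that iterations with a large epsilon converge quickly. Decreasing epsilon toward the target then refines the solution rather than starting from scratch at the smallest value.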


geometry.Geometry([cost_matrix, ...])

Base class to define ground costs/kernels used in optimal transport.

pointcloud.PointCloud(x[, y, cost_fn, ...])

Defines a geometry between two point clouds (possibly one cloud vs. itself) using a CostFn.

grid.Grid([x, grid_size, cost_fns, num_a, ...])

Class describing the geometry of points taken in a cartesian product.

graph.Graph([graph, laplacian, t, n_steps, ...])

Graph distance approximation using heat kernel [Crane et al., 2013, Heitz et al., 2021].

low_rank.LRCGeometry(cost_1, cost_2[, bias, ...])

Low-rank Cost Geometry defined by two factors.

epsilon_scheduler.Epsilon([target, ...])

Scheduler class for the regularization parameter epsilon.
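To make the low-rank cost geometry concrete, here is a NumPy sketch that applies an n x m cost matrix of rank r to a vector in O((n + m) r) operations instead of O(nm). It assumes the cost is cost_1 @ cost_2.T plus a scalar `bias`; that reading of `bias` is an assumption. The Gibbs kernel exp(-C / epsilon) of a low-rank C is generally not low-rank, so the savings shown here concern products with the cost itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, r = 500, 400, 3
cost_1 = rng.random((n, r))   # left factor, shape (n, r)
cost_2 = rng.random((m, r))   # right factor, shape (m, r)
bias = 0.5                    # assumption: a scalar added to every cost entry

def apply_lr_cost(v):
    """Compute (cost_1 @ cost_2.T + bias) @ v without forming the n x m matrix."""
    return cost_1 @ (cost_2.T @ v) + bias * v.sum()

v = rng.random(m)
Cv_fast = apply_lr_cost(v)                   # O((n + m) r) work
Cv_dense = (cost_1 @ cost_2.T + bias) @ v    # O(n m) work, for comparison only
```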

Cost Functions


A generic cost function, taking two vectors as input.


Squared Euclidean distance CostFn.


Cosine distance CostFn.

costs.Bures(dimension, **kwargs)

Bures distance between a pair of (mean, cov matrix) raveled as vectors.

costs.UnbalancedBures(dimension[, gamma, sigma])

Regularized/unbalanced Bures distance between two triplets of (mass, mean, cov).
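For reference, the Bures (2-Wasserstein) distance between Gaussians N(m1, S1) and N(m2, S2) has the closed form ||m1 - m2||^2 + tr(S1 + S2 - 2 (S1^{1/2} S2 S1^{1/2})^{1/2}). The NumPy sketch below evaluates that squared distance. It also unpacks (mean, cov) pairs raveled as vectors, as described for costs.Bures, assuming a layout of d mean entries followed by the d x d covariance in row-major order. Whether costs.Bures uses exactly this layout or returns the squared value is not confirmed here, and the helper names are hypothetical.

```python
import numpy as np

def psd_sqrt(S):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def bures_sq(m1, S1, m2, S2):
    """Squared Bures-Wasserstein distance between N(m1, S1) and N(m2, S2)."""
    r1 = psd_sqrt(S1)
    cross = psd_sqrt(r1 @ S2 @ r1)
    return float(np.sum((m1 - m2) ** 2) + np.trace(S1 + S2 - 2.0 * cross))

def ravel_gaussian(mean, cov):
    """Hypothetical layout: d mean entries, then the d x d covariance (row-major)."""
    return np.concatenate([mean, cov.ravel()])

def bures_sq_from_vectors(u, w, dimension):
    """Unpack two raveled (mean, cov) vectors of length d + d*d and compare them."""
    d = dimension

    def unpack(z):
        return z[:d], z[d:].reshape(d, d)

    return bures_sq(*unpack(u), *unpack(w))

# 1-D check: means 0 and 1, variances 4 and 1 -> 1 + (2 - 1)**2 = 2.
u = ravel_gaussian(np.array([0.0]), np.array([[4.0]]))
w = ravel_gaussian(np.array([1.0]), np.array([[1.0]]))
dist = bures_sq_from_vectors(u, w, dimension=1)
```

In one dimension the trace term reduces to (sigma1 - sigma2)^2, which is what the check above uses.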