ott.geometry

This package implements several classes that define a geometry, arguably the most influential ingredient of an optimal transport problem. In full generality, a Geometry defines source points (input measure), target points (target measure), and a ground cost function (resp. a positive kernel function) that quantifies how expensive (resp. easy) it is to displace a unit of mass from any of the input points to any of the target points.

The geometry package provides a few simple geometries. The simplest is the one in which input and target points coincide, so that the geometry between them reduces to a symmetric cost or kernel matrix. In the particular case where these points lie on a grid (a Cartesian product in full generality, e.g., 2- or 3-dimensional grids), the Grid geometry will prove useful.
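To illustrate why a Cartesian product structure is worth exploiting, here is a minimal, library-independent sketch. It assumes a separable cost, i.e. one that is a sum of per-axis costs (a squared difference along each axis in this example). The names `axis_cost` and `grid_cost` are hypothetical and are not part of the package.

```python
import itertools


def axis_cost(coords):
    """Squared-difference cost matrix between the locations along one axis."""
    return [[(a - b) ** 2 for b in coords] for a in coords]


# Assumed example: a 3 x 2 grid described by its per-axis coordinates.
axes = [[0.0, 1.0, 2.0], [0.0, 0.5]]
axis_costs = [axis_cost(coords) for coords in axes]


def grid_cost(i, j):
    """Cost between grid nodes i and j (multi-indices), summed over axes."""
    return sum(c[a][b] for c, a, b in zip(axis_costs, i, j))


# All grid nodes, enumerated as multi-indices.
nodes = list(itertools.product(*(range(len(coords)) for coords in axes)))

# The full cost matrix can be rebuilt on demand from the small per-axis ones.
full = [[grid_cost(i, j) for j in nodes] for i in nodes]

stored = sum(len(c) ** 2 for c in axis_costs)  # entries kept per axis
dense = len(nodes) ** 2                        # entries of the full matrix
```

Under this assumption, only the small per-axis matrices need to be stored, and the gap with the dense matrix widens quickly as the grid grows.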

For more general settings where input/target points do not coincide, one can alternatively instantiate a Geometry through a rectangular cost matrix.

However, it is often preferable in applications to define ground costs “symbolically”, by listing the points in the input/target point clouds and specifying directly a cost function between them. Such functions should follow the CostFn class description. We provide a few standard cost functions that are meaningful in an OT context, notably the (unbalanced, regularized) Bures distances between Gaussians [Janati et al., 2020]. That cost can be used, for instance, to compute a distance between Gaussian mixtures, as proposed in [Chen et al., 2019] and revisited in [Delon and Desolneux, 2020].
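The difference between the two descriptions can be sketched in plain Python, independently of the package. In the sketch below, the rectangular cost matrix is derived from two point clouds and a cost function rather than being supplied directly. The helpers `sq_euclidean` and `cost_matrix` are hypothetical names chosen for illustration; they do not reproduce the CostFn interface.

```python
def sq_euclidean(x, y):
    """Squared Euclidean distance between two points given as sequences."""
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))


def cost_matrix(xs, ys, cost_fn):
    """Evaluate cost_fn on every (source, target) pair: an n x m matrix."""
    return [[cost_fn(x, y) for y in ys] for x in xs]


xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)]  # 3 source points
ys = [(1.0, 1.0), (0.0, 0.0)]              # 2 target points

# A 3 x 2 matrix: one row per source point, one column per target point.
C = cost_matrix(xs, ys, sq_euclidean)
```

Keeping the points and the cost function, rather than only the matrix, means the cost can be re-evaluated on demand, which is one reason the symbolic description is often preferred.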

To be useful with Sinkhorn solvers, geometries typically need to provide an epsilon regularization parameter. That value can either be set once and for all, or be driven by an annealing Epsilon scheduler.
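As a rough illustration of both options, the sketch below turns a cost matrix into a Gibbs kernel \(K = \exp(-C / \varepsilon)\) and computes an annealed sequence of epsilon values. The geometric decay rule is an assumption made for this example and is not necessarily the rule implemented by the Epsilon class; the parameter names `target`, `init` and `decay` simply mirror the signature listed below, and both function names are hypothetical.

```python
import math


def annealed_epsilon(it, target=0.05, init=10.0, decay=0.5):
    """Assumed schedule: start at init * target, multiply by `decay` at
    every iteration, and never go below `target`."""
    return max(init * decay ** it, 1.0) * target


def gibbs_kernel(cost, eps):
    """Entrywise kernel exp(-cost / eps): low cost gives a value close to 1."""
    return [[math.exp(-c / eps) for c in row] for row in cost]


# Epsilon values for the first six iterations: large first, then `target`.
schedule = [annealed_epsilon(it) for it in range(6)]

cost = [[0.0, 1.0], [1.0, 0.0]]
K_early = gibbs_kernel(cost, schedule[0])   # large epsilon: flatter kernel
K_late = gibbs_kernel(cost, schedule[-1])   # small epsilon: sharper kernel
```

Starting with a large epsilon and shrinking it gradually is a common way to ease convergence before reaching the desired regularization level.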

Geometries

geometry.Geometry([cost_matrix, ...])

Base class to define ground costs/kernels used in optimal transport.

pointcloud.PointCloud(x[, y, cost_fn, ...])

Defines geometry for 2 point clouds (possibly 1 vs itself).

grid.Grid([x, grid_size, cost_fns, num_a, ...])

Class describing the geometry of points taken in a Cartesian product.

graph.Graph(laplacian[, t, n_steps, ...])

Graph distance approximation using heat kernel [Crane et al., 2013, Heitz et al., 2021].

geodesic.Geodesic(scaled_laplacian, eigval, ...)

Graph distance approximation using heat kernel [Huguet et al., 2023].

low_rank.LRCGeometry(cost_1, cost_2[, bias, ...])

Geometry whose cost is defined by the product of two low-rank matrices.

low_rank.LRKGeometry(k1, k2[, epsilon])

Low-rank kernel geometry.

semidiscrete_pointcloud.SemidiscretePointCloud(...)

Semidiscrete point cloud geometry.

epsilon_scheduler.Epsilon(target[, init, decay])

Scheduler class for the regularization parameter epsilon.

epsilon_scheduler.DEFAULT_EPSILON_SCALE

Scaling applied to a statistic (mean/std) of the cost to compute the default epsilon.

Cost Functions

costs.CostFn()

Base class for all costs.

costs.TICost()

Base class for translation invariant (TI) costs.

costs.SqPNorm(p)

Squared p-norm of the difference of two vectors.

costs.PNormP(p)

\(p\)-norm to the power \(p\) and divided by \(p\).

costs.SqEuclidean()

Squared Euclidean distance.

costs.NegDotProduct()

Negative Dot-product cost.

costs.RegTICost(regularizer[, lam, rho])

Regularized translation-invariant cost.

costs.Euclidean()

Euclidean distance.

costs.EuclideanP(p)

\(p\)-power of Euclidean norm.

costs.Cosine([ridge])

Cosine distance cost function.

costs.Arccos(n[, ridge])

Arc-cosine cost function [Cho and Saul, 2009].

costs.Bures(dimension[, sqrtm_kw])

Bures distance between a pair of (mean, covariance matrix).

costs.UnbalancedBures(dimension, *[, sigma, ...])

Unbalanced Bures distance between two triplets of (mass, mean, cov).

costs.SoftDTW(gamma[, ground_cost, debiased])

Soft dynamic time warping (DTW) cost [Cuturi and Blondel, 2017].

distrib_costs.UnivariateWasserstein(solve_fn)

1D Wasserstein cost for two 1D distributions.

Regularizers

regularizers.ProximalOperator()

Proximal operator base class.

regularizers.PostComposition(f[, alpha, b])

Postcomposition operator \(\alpha f\left(x\right) + b\).

regularizers.Regularization(f[, a, rho])

Regularization operator \(f\left(x\right) + \frac{\rho}{2}\|x - a\|_2^2\).

regularizers.Orthogonal(f, A[, b, nu])

Orthogonal operator \(f\left( Ax \right) + b\).

regularizers.Quadratic([A, b, ...])

Quadratic operator \(\frac{1}{2} \left<x, Q x\right> + b\).

regularizers.L1()

L1-norm regularizer \(\ell_1\).

regularizers.SqL2([A])

Squared L2-norm regularizer \(\ell_2^2\).

regularizers.STVS([gamma])

Soft thresholding operator with vanishing shrinkage regularizer [Schreck et al., 2016].

regularizers.SqKOverlap(k)

Squared k-overlap norm regularizer [Argyriou et al., 2012].
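The regularizers above derive from a proximal operator base class. For intuition, here is a standalone sketch of the best-known case: the proximal operator of \(\tau \|x\|_1\) is entrywise soft thresholding. The function name `soft_threshold` is hypothetical, and the sketch does not use or reproduce the classes listed above.

```python
def soft_threshold(v, tau):
    """Proximal operator of tau * ||.||_1, applied entry by entry:
    shrink each entry toward zero by tau, and zero out the small ones."""
    out = []
    for x in v:
        if x > tau:
            out.append(x - tau)
        elif x < -tau:
            out.append(x + tau)
        else:
            out.append(0.0)
    return out


shrunk = soft_threshold([3.0, -0.2, -1.5, 0.5], tau=0.5)
```

Entries with magnitude at most `tau` are set exactly to zero, which is what makes the \(\ell_1\) regularizer sparsity-inducing.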

Utilities

segment.segment_point_cloud(x[, a, ...])

Segment the entries of a point cloud, padding as needed.
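The idea of segmenting and padding can be sketched as follows. This is a plain-Python illustration under assumed conventions (uniform weights within a segment, zero weight on padded entries), not the behavior or signature of `segment_point_cloud`; the helper `segment_and_pad` and all its parameters are hypothetical.

```python
def segment_and_pad(points, segment_ids, num_segments, max_size, pad=0.0):
    """Split `points` by segment id and pad every segment to `max_size`.

    Returns the padded segments and matching weights, in which padded
    entries carry zero mass so that they can be ignored downstream.
    """
    dim = len(points[0])
    segments = [[] for _ in range(num_segments)]
    for point, seg_id in zip(points, segment_ids):
        segments[seg_id].append(list(point))

    padded, weights = [], []
    for seg in segments:
        if len(seg) > max_size:
            raise ValueError("segment larger than max_size")
        n_pad = max_size - len(seg)
        # Uniform weights on real points (assumption), zero on padding.
        w = [1.0 / len(seg)] * len(seg) if seg else []
        padded.append(seg + [[pad] * dim for _ in range(n_pad)])
        weights.append(w + [0.0] * n_pad)
    return padded, weights


points = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
segs, ws = segment_and_pad(points, [0, 1, 0], num_segments=2, max_size=3)
```

Padding every segment to the same length gives all segments an identical shape, which is convenient when the same computation must be applied to each of them in a batch.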