ott.core package

OTT core libraries: the engine behind most computations happening in OTT.

The core package contains definitions of various OT problems, from the simplest one, the linear OT problem, to more advanced ones such as the quadratic problem or the barycenter problem, which involves multiple measures. We follow with the classic Sinkhorn routine (essentially a wrapper for the Sinkhorn solver class) [Cuturi, 2013, Séjourné et al., 2019]. We also provide an analogous low-rank Sinkhorn solver [Scetbon et al., 2021] to handle very large instances. Both are used within our Wasserstein barycenter solvers [Benamou et al., 2015, Janati et al., 2020], as well as our Gromov-Wasserstein solver [Mémoli, 2011, Scetbon et al., 2022]. We also provide an implementation of input convex neural networks [Amos et al., 2017], a neural network architecture that can be used to estimate OT maps [Makkuva et al., 2020].

OT Problems

linear_problems.LinearProblem(geom[, a, b, ...])

Linear regularized OT problem.

quad_problems.QuadraticProblem(geom_xx, geom_yy)

Quadratic regularized OT problem.

bar_problems.BarycenterProblem(y[, b, ...])

Wasserstein barycenter problem [Cuturi and Doucet, 2014].

bar_problems.GWBarycenterProblem([y, b, ...])

(Fused) Gromov-Wasserstein barycenter problem [Peyré et al., 2016, Titouan et al., 2019].
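For reference, the balanced, entropy-regularized forms of the first two problems are written below, with C the cost matrix of geom, C^x and C^y those of geom_xx and geom_yy, marginal weights a and b, and regularization ε > 0. The quadratic objective is shown for the square loss only, and the unbalanced variants (which relax the marginal constraints) are omitted; this is the textbook formulation, not a transcription of OTT's internals.

```latex
% Linear problem (LinearProblem), balanced, entropic regularization:
\min_{P \in U(a,b)} \; \langle C, P \rangle - \varepsilon H(P),
\qquad U(a,b) = \{ P \in \mathbb{R}_{+}^{n \times m} : P \mathbf{1}_m = a,\; P^\top \mathbf{1}_n = b \},
\qquad H(P) = -\sum_{i,j} P_{ij} \left( \log P_{ij} - 1 \right).

% Quadratic problem (QuadraticProblem), square loss:
\min_{P \in U(a,b)} \; \sum_{i,k=1}^{n} \sum_{j,l=1}^{m} \left( C^x_{ik} - C^y_{jl} \right)^2 P_{ij} P_{kl} \; - \; \varepsilon H(P).
```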

Sinkhorn

sinkhorn.sinkhorn(geom[, a, b, tau_a, ...])

Solve regularized OT problem using Sinkhorn iterations.

sinkhorn.Sinkhorn([lse_mode, threshold, ...])

A Sinkhorn solver for linear reg-OT problems.

sinkhorn.SinkhornOutput([f, g, errors, ...])

Implements the problems.Transport interface for a Sinkhorn solution.
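To make the iterations concrete, here is a minimal NumPy sketch of balanced Sinkhorn written on the dual potentials in the log domain (the idea behind an option such as lse_mode). The functions logsumexp and sinkhorn_log are hypothetical helpers, the stopping rule on the row-marginal error is an assumption, and this is not OTT's JAX implementation.

```python
import numpy as np


def logsumexp(m, axis):
    """Numerically stable log(sum(exp(m))) along `axis`."""
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def sinkhorn_log(cost, a, b, epsilon=0.1, threshold=1e-6, max_iterations=2000):
    """Hypothetical balanced entropic OT solver alternating dual-potential updates.

    The coupling is P_ij = exp((f_i + g_j - C_ij) / epsilon).
    """
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(max_iterations):
        # f-update: makes the row sums of P equal to a.
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        # g-update: makes the column sums of P equal to b.
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        coupling = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
        # Columns are exact right after the g-update, so monitor the row marginals.
        if np.sum(np.abs(coupling.sum(axis=1) - a)) < threshold:
            break
    return f, g, coupling


rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(7, 2))
cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
cost = cost / cost.max()  # rescale so that epsilon is relative to the cost range
a, b = np.full(5, 1 / 5), np.full(7, 1 / 7)
f, g, coupling = sinkhorn_log(cost, a, b, epsilon=0.1)
```

The returned f and g play the role of the dual potentials that appear as the f and g fields of SinkhornOutput; the coupling is recovered from them entrywise.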

Sinkhorn Dual Initializers

initializers.DefaultInitializer()

Default initialization of Sinkhorn dual potentials/primal scalings.

initializers.GaussianInitializer()

Gaussian initializer [Thornton and Cuturi, 2022].

initializers.SortingInitializer([...])

Sorting initializer [Thornton and Cuturi, 2022].

Low-Rank Sinkhorn

sinkhorn_lr.LRSinkhorn(rank[, gamma, ...])

A Low-Rank Sinkhorn solver for linear reg-OT problems.

sinkhorn_lr.LRSinkhornOutput(q, r, g, costs, ...)

Implements the problems.Transport interface for a low-rank (LR) Sinkhorn solution.
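Low-rank Sinkhorn [Scetbon et al., 2021] restricts the coupling to the factorized form P = Q diag(1/g) Rᵀ, where Q has marginals (a, g), R has marginals (b, g), and g is a weight vector over the rank components; the LRSinkhornOutput signature above stores q, r and g. The NumPy sketch below only builds such factors from random positive matrices by rescaling them to the required marginals (match_marginals and random_lr_factors are hypothetical helpers; the actual solver optimizes Q, R and g), and checks that the resulting coupling has marginals a and b while the factors take only (n + m + 1)·rank numbers.

```python
import numpy as np


def match_marginals(k, row, col, num_iterations=200):
    """Rescale a positive matrix so its row sums are `row` and column sums `col`."""
    u, v = np.ones_like(row), np.ones_like(col)
    for _ in range(num_iterations):
        u = row / (k @ v)
        v = col / (k.T @ u)
    return u[:, None] * k * v[None, :]


def random_lr_factors(a, b, rank, seed=0):
    """Hypothetical helper: random Q, R, g with P = Q diag(1/g) R^T in U(a, b)."""
    rng = np.random.default_rng(seed)
    g = np.full(rank, 1.0 / rank)  # inner marginal over the rank components
    q = match_marginals(rng.uniform(0.5, 1.5, size=(a.size, rank)), a, g)  # in U(a, g)
    r = match_marginals(rng.uniform(0.5, 1.5, size=(b.size, rank)), b, g)  # in U(b, g)
    return q, r, g


a, b = np.full(6, 1 / 6), np.full(4, 1 / 4)
q, r, g = random_lr_factors(a, b, rank=2)
# Materialized here only to check it; the factors are what a solver would store.
coupling = q @ np.diag(1.0 / g) @ r.T
```

The marginals follow from the factor constraints: P·1 = Q diag(1/g) Rᵀ1 = Q diag(1/g) g = Q·1 = a, and symmetrically for b.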

Low-Rank Sinkhorn Initializers

initializers_lr.RandomInitializer(rank, **kwargs)

Low-rank Sinkhorn factorization using random factors.

initializers_lr.Rank2Initializer(rank, **kwargs)

Low-rank Sinkhorn factorization using rank-2 factors [Scetbon et al., 2021].

initializers_lr.KMeansInitializer(rank[, ...])

K-means initializer for low-rank Sinkhorn [Scetbon and Cuturi, 2022].

initializers_lr.GeneralizedKMeansInitializer(rank)

Generalized k-means initializer [Scetbon and Cuturi, 2022].

Quadratic Initializers

quad_initializers.QuadraticInitializer(**kwargs)

Initialize a linear problem locally around the naive initializer ab' (the product coupling of the marginals a and b).

quad_initializers.LRQuadraticInitializer(...)

Wrapper around low-rank Sinkhorn initializers.
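For the square loss, the linear problem obtained by linearizing the quadratic objective at a coupling P uses the cost L(C^x, C^y) ⊗ P [Peyré et al., 2016], which decomposes as shown below. Evaluating it at the product coupling ab' (with a and b each summing to one, so that its marginals are exactly a and b) gives the second line; we assume, from the description of QuadraticInitializer above, that this is the kind of initial linear problem it builds.

```latex
\big(\mathcal{L}(C^x, C^y) \otimes P\big)_{ij}
  = \sum_{k,l} \big(C^x_{ik} - C^y_{jl}\big)^2 P_{kl}
  = \big((C^x)^{\odot 2} p\big)_i + \big((C^y)^{\odot 2} q\big)_j - 2\,\big(C^x P (C^y)^\top\big)_{ij},
\qquad p = P \mathbf{1}_m,\; q = P^\top \mathbf{1}_n .

\mathcal{L}(C^x, C^y) \otimes a b^\top
  = (C^x)^{\odot 2} a \, \mathbf{1}_m^\top
  + \mathbf{1}_n \, b^\top \big((C^y)^{\odot 2}\big)^\top
  - 2\, C^x a \, b^\top (C^y)^\top .
```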

Barycenters (Entropic and LR)

discrete_barycenter.discrete_barycenter(geom, a)

Compute a discrete barycenter [Janati et al., 2020].

continuous_barycenter.WassersteinBarycenter([...])

A continuous Wasserstein barycenter solver, built on a generic template.

continuous_barycenter.BarycenterState([...])

Holds the state of the Wasserstein barycenter solver.

gw_barycenter.GromovWassersteinBarycenter([...])

Gromov-Wasserstein barycenter solver for the GWBarycenterProblem.

gw_barycenter.GWBarycenterState([cost, x, ...])

Holds the state of the GWBarycenterProblem.
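The fixed-support entropic barycenter can be computed with iterative Bregman projections [Benamou et al., 2015]. Below is a minimal NumPy sketch assuming all measures live on one shared support (a single cost matrix) and using the plain, non-debiased scheme, whose result gets blurrier as ε grows. The function ibp_barycenter and the bump test data are hypothetical; this is not OTT's discrete_barycenter, which cites [Janati et al., 2020].

```python
import numpy as np


def ibp_barycenter(cost, measures, weights, epsilon=0.02, num_iterations=2000):
    """Hypothetical fixed-support entropic barycenter via IBP.

    cost:     (n, n) cost matrix on the shared support.
    measures: (K, n) histograms a_k, each summing to one.
    weights:  (K,) barycentric weights summing to one.
    """
    kernel = np.exp(-cost / epsilon)
    v = np.ones_like(measures)
    for _ in range(num_iterations):
        # Match the first marginals: P_k 1 = a_k, with P_k = diag(u_k) K diag(v_k).
        u = measures / (v @ kernel.T)
        ktu = u @ kernel  # row k holds K^T u_k
        # Barycenter: weighted geometric mean of the current second marginals.
        bary = np.exp(weights @ np.log(ktu))
        # Match the shared second marginal: P_k^T 1 = bary.
        v = bary[None, :] / ktu
    return bary


grid = np.linspace(0.0, 1.0, 51)
cost = (grid[:, None] - grid[None, :]) ** 2


def bump(center, width=0.05):
    """Discretized Gaussian bump on `grid`, normalized to a histogram."""
    h = np.exp(-((grid - center) ** 2) / (2 * width**2))
    return h / h.sum()


measures = np.stack([bump(0.25), bump(0.75)])
bary = ibp_barycenter(cost, measures, np.array([0.5, 0.5]))
```

With equal weights, the barycenter of bumps centered at 0.25 and 0.75 concentrates around 0.5 rather than averaging the two histograms pointwise, which is the behavior that distinguishes a Wasserstein barycenter.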

Gromov-Wasserstein (Entropic and LR)

gromov_wasserstein.gromov_wasserstein(...[, ...])

Solve a Gromov-Wasserstein problem.

gromov_wasserstein.GromovWasserstein(*args)

Gromov-Wasserstein solver.

gromov_wasserstein.GWOutput([costs, ...])

Holds the output of the Gromov-Wasserstein solver.
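A minimal NumPy sketch of the entropic Gromov-Wasserstein scheme of [Peyré et al., 2016]: starting from the product coupling ab', each outer step linearizes the square-loss objective at the current coupling and solves an entropic linear problem with Sinkhorn. All helper names are hypothetical, the fixed iteration counts are assumptions, and OTT's GromovWasserstein solver (whose section title indicates it also covers a low-rank variant) is not reproduced here.

```python
import numpy as np


def sinkhorn_coupling(cost, a, b, epsilon, num_iterations=1000):
    """Scaling-form Sinkhorn; returns the entropic coupling for `cost`."""
    # Shifting the cost by a constant rescales the kernel but not the coupling.
    kernel = np.exp(-(cost - cost.min()) / epsilon)
    v = np.ones_like(b)
    for _ in range(num_iterations):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


def gw_linear_cost(cx, cy, coupling):
    """(L x P)_ij = sum_kl (Cx_ik - Cy_jl)^2 P_kl for the square loss."""
    p, q = coupling.sum(axis=1), coupling.sum(axis=0)
    return ((cx**2) @ p)[:, None] + ((cy**2) @ q)[None, :] - 2.0 * cx @ coupling @ cy.T


def entropic_gw(cx, cy, a, b, epsilon=0.1, outer_iterations=20):
    """Repeatedly linearize the GW objective and solve an entropic linear OT."""
    coupling = np.outer(a, b)  # naive product initialization ab'
    for _ in range(outer_iterations):
        coupling = sinkhorn_coupling(gw_linear_cost(cx, cy, coupling), a, b, epsilon)
    # GW objective sum_ijkl (Cx_ik - Cy_jl)^2 P_ij P_kl at the final coupling.
    return coupling, float(np.sum(gw_linear_cost(cx, cy, coupling) * coupling))


def distances(z):
    """Pairwise Euclidean distance matrix of a point cloud."""
    return np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)


rng = np.random.default_rng(0)
x = 0.5 * rng.uniform(size=(8, 2))
angle = 0.7
rotation = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
y = (x @ rotation.T + 1.0)[rng.permutation(8)]  # rotated, shifted, shuffled copy
a = b = np.full(8, 1 / 8)
coupling, gw_iso = entropic_gw(distances(x), distances(y), a, b)
_, gw_scaled = entropic_gw(distances(x), distances(3.0 * x), a, b)
```

Because GW only compares intra-domain distances, the rotated and shuffled copy is cheap to match, while the rescaled copy (all distances tripled) is not, whatever coupling is found.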

Neural Potentials

icnn.ICNN(dim_hidden[, init_std, init_fn, ...])

Input convex neural network (ICNN) architecture with initialization.

neuraldual.NeuralDualSolver(input_dim[, ...])

Solver for the ICNN-based Kantorovich dual.

neuraldual.NeuralDual(state_f, state_g)

Neural Kantorovich dual.
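The convexity of an ICNN comes from two constraints [Amos et al., 2017]: weights applied to hidden activations are non-negative, and activations are convex and non-decreasing. The NumPy sketch below, a hypothetical TinyICNN class rather than OTT's icnn.ICNN, enforces the first constraint by taking absolute values at initialization and uses softplus as the activation; a training loop would also need to keep those weights non-negative (e.g. by clipping), which is not shown.

```python
import numpy as np


def softplus(z):
    """Convex, non-decreasing activation log(1 + exp(z)), computed stably."""
    return np.logaddexp(0.0, z)


class TinyICNN:
    """Hypothetical two-hidden-layer ICNN.

    z1 = softplus(W0 x + b0)
    z2 = softplus(Wz1 z1 + W1 x + b1), with Wz1 >= 0
    f(x) = wz2 . z2 + w2 . x,          with wz2 >= 0
    """

    def __init__(self, dim, hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.w0 = rng.normal(size=(hidden, dim))
        self.b0 = rng.normal(size=hidden)
        # Weights acting on hidden states are made non-negative (here via abs).
        self.wz1 = np.abs(rng.normal(size=(hidden, hidden)))
        self.w1 = rng.normal(size=(hidden, dim))
        self.b1 = rng.normal(size=hidden)
        self.wz2 = np.abs(rng.normal(size=hidden))
        self.w2 = rng.normal(size=dim)

    def __call__(self, x):
        z1 = softplus(self.w0 @ x + self.b0)
        z2 = softplus(self.wz1 @ z1 + self.w1 @ x + self.b1)
        return float(self.wz2 @ z2 + self.w2 @ x)


potential = TinyICNN(dim=3, hidden=8)
```

In the ICNN-based dual of [Makkuva et al., 2020], the gradient of such a convex potential serves as the estimated transport map.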

Padding Utilities

segment.segment_point_cloud(x[, a, ...])

Segment the entries of a point cloud, padding them as needed.
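Segmenting and padding lets point clouds of different sizes share one fixed array shape, which is a likely motivation in a JAX library since compiled functions are specialized to array shapes. The NumPy sketch below is a hypothetical segment_and_pad helper, not the segment_point_cloud API: it groups rows by a segment id, pads each group to max_size, and gives padded entries zero weight so they carry no mass.

```python
import numpy as np


def segment_and_pad(x, segment_ids, num_segments, max_size, padding_value=0.0):
    """Hypothetical helper: split rows of `x` by segment id and pad to `max_size`.

    Returns points of shape (num_segments, max_size, dim) and weights of shape
    (num_segments, max_size), uniform over real points and zero over padding.
    """
    dim = x.shape[1]
    points = np.full((num_segments, max_size, dim), padding_value, dtype=x.dtype)
    weights = np.zeros((num_segments, max_size))
    for s in range(num_segments):
        members = x[segment_ids == s]
        if len(members) > max_size:
            raise ValueError(f"segment {s} has {len(members)} points > {max_size}")
        points[s, : len(members)] = members
        if len(members) > 0:
            weights[s, : len(members)] = 1.0 / len(members)
    return points, weights


x = np.arange(10.0).reshape(5, 2)
ids = np.array([0, 1, 0, 1, 1])
pts, w = segment_and_pad(x, ids, num_segments=2, max_size=3)
```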