Grid Geometry

The Grid geometry was designed for the many applications that use the so-called Eulerian description of probability measures. In such applications, probability measures are seen as histograms supported on a \(d\)-dimensional Cartesian grid, and not as point clouds.

A Grid geometry instantiates a Cartesian product of \(d\) lists of values, each list of index \(i\) being of size \(n_i\). That Cartesian product has a total number of \(N:=\prod_i n_i\) possible locations in \(\mathbb{R}^d\).
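As a minimal illustration of the counting above, the sketch below enumerates a small Cartesian product with the Python standard library. The coordinate lists in `axes` are hypothetical values chosen for this example, and this is not meant to reflect how Grid stores its points internally.

```python
import itertools
import math

# Hypothetical per-axis coordinate lists: d = 3, with sizes n_1 = 2, n_2 = 3, n_3 = 2.
axes = [[0.0, 1.0], [0.0, 0.5, 1.0], [0.2, 0.8]]

# Every grid location is one choice of coordinate per axis.
points = list(itertools.product(*axes))

# The total number of locations is the product of the per-axis sizes.
n_total = math.prod(len(axis) for axis in axes)

print(len(points), n_total)
```

Both numbers agree because each point of the grid corresponds to exactly one tuple of per-axis indices.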

A Grid geometry also assumes that the ground cost between points in the grid is separable: for two points \(x, y\) in that grid, the cost must be of the form \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\) where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\). As a result, a \(d\)-dimensional Grid expects a tuple of up to \(d\) CostFn objects, each describing a cost between two real values.
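The separability assumption can be sketched in a few lines of plain Python. The helpers `sq_diff`, `abs_diff` and `separable_cost` are hypothetical names used only for this illustration; they are ordinary functions, not CostFn objects.

```python
def sq_diff(u, v):
    """Squared difference between two real numbers."""
    return (u - v) ** 2


def abs_diff(u, v):
    """Absolute difference between two real numbers."""
    return abs(u - v)


def separable_cost(x, y, cost_fns):
    """Sum of per-dimension costs: cost(x, y) = sum_i cost_i(x_i, y_i)."""
    return sum(c(xi, yi) for c, xi, yi in zip(cost_fns, x, y))


x = (0.0, 1.0, 2.0)
y = (1.0, 3.0, 2.5)

# Using the squared difference on every axis recovers the squared
# Euclidean distance between x and y.
sq_euclidean = separable_cost(x, y, [sq_diff, sq_diff, sq_diff])

# A different cost can be used on each axis.
mixed = separable_cost(x, y, [sq_diff, abs_diff, sq_diff])

print(sq_euclidean, mixed)
```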

The advantage of using Grid over PointCloud is that fundamental operations, such as applying the \(N\times N\) cost matrix of all pairwise costs between the \(N\) points in the grid, or its kernel, can be carried out in \(O(N^{(1+1/d)})\) operations, with a similar memory footprint, rather than by naively instantiating those matrices with \(N^2\) entries.
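To make the complexity claim more concrete, here is a NumPy sketch of why separability helps. For a separable cost, the kernel \(e^{-\text{cost}/\varepsilon}\) factorizes across dimensions, so it can be applied one axis at a time using small \(n_i \times n_i\) matrices. The function `apply_kernel_separable` and all variable names are hypothetical, and the sketch assumes a squared-Euclidean cost on each axis; it illustrates the idea and is not OTT's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
grid_size = (3, 4, 5)
epsilon = 0.1

# Hypothetical irregular locations along each axis.
xs = [rng.uniform(size=n) for n in grid_size]

# One small (n_i x n_i) kernel per axis, for the squared difference cost.
kernels = [np.exp(-((x[:, None] - x[None, :]) ** 2) / epsilon) for x in xs]


def apply_kernel_separable(kernels, vec, grid_size):
    """Apply the full N x N kernel to `vec` one axis at a time."""
    t = vec.reshape(grid_size)
    for axis, k in enumerate(kernels):
        # Contract the axis-th dimension of t with the small kernel,
        # then move the resulting axis back to its original position.
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t.ravel()


# Dense reference: list all N points and build the N x N kernel explicitly.
points = np.stack(
    [m.ravel() for m in np.meshgrid(*xs, indexing="ij")], axis=1
)
sq_dists = ((points[:, None, :] - points[None, :, :]) ** 2).sum(-1)
dense_kernel = np.exp(-sq_dists / epsilon)

v = rng.uniform(size=points.shape[0])
fast = apply_kernel_separable(kernels, v, grid_size)
slow = dense_kernel @ v

print(np.allclose(fast, slow))
```

In this sketch, the axis-wise version costs on the order of \(N \sum_i n_i\) multiply-adds and never forms the \(N \times N\) matrix, whereas the dense reference needs \(N^2\) entries.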

import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, grid, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

Create Grid with the x argument

In this example, the argument x is a list of \(d=3\) vectors \(x_1, x_2, x_3\), of varying sizes \(\{n_1, n_2, n_3\}\), that describe the locations of the grid along each dimension. The resulting grid is the Cartesian product of these vectors (each seen as a list of values), namely \(\{u\in x_1\}\times \{u\in x_2\} \times \{u\in x_3\}\). Assuming each vector is formed with distinct coordinates, that Cartesian product holds \(N = n_1 n_2 n_3 = 5 \times 6 \times 7 = 210\) distinct points in the example below. a and b are two histograms on that grid, namely probability vectors of size \(N\). Note that, to showcase the versatility of the Grid API, the grid here is irregularly spaced, since locations along each dimension are random.

rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)

grid_size = (5, 6, 7)
x = [
    jax.random.uniform(keys[0], (grid_size[0],)),
    jax.random.uniform(keys[1], (grid_size[1],)),
    jax.random.uniform(keys[2], (grid_size[2],)),
]

We now have all the ingredients to create a geom object that describes that grid. Since we do not specify a cost function for each dimension (denoted \(\text{cost}_i\) in the formula above), the instantiation defaults to SqEuclidean (between real numbers) for each dimension. This is mathematically equivalent to running computations with a PointCloud object instantiated with the usual squared Euclidean distance between vectors in \(\mathbb{R}^3\). We will get back to that approach later in this tutorial.

geom = grid.Grid(x=x)

We can now generate two histograms a and b on that grid. Each has total size \(N\) and is flattened into a vector, so that probability weights share a unified vector API. They will, however, be reshaped within computations as tensors of shape grid_size.

a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)  # Normalize to have unit total mass.
b = b.ravel() / jnp.sum(b)  # Same normalization for b.
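The correspondence between a grid-shaped tensor of weights and its flattened vector can be checked with a small NumPy sketch. The array `weights` and the multi-index `(2, 3, 4)` are hypothetical values chosen for this illustration; the sketch relies on the default row-major ordering of `ravel`.

```python
import numpy as np

grid_size = (5, 6, 7)

# Hypothetical weights: each entry stores its own position in the flat vector.
weights = np.arange(np.prod(grid_size), dtype=float).reshape(grid_size)
flat = weights.ravel()

# Position in the flat vector of the grid location with multi-index (2, 3, 4).
flat_index = np.ravel_multi_index((2, 3, 4), grid_size)

# Reshaping the flat vector recovers the original tensor.
restored = flat.reshape(grid_size)

print(flat_index, flat[flat_index] == weights[2, 3, 4])
```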

We now solve the OT problem between weights a and b by running a Sinkhorn solver, which outputs the regularized optimal transport cost. The example below illustrates how ott delegates low-level geometric computations to the geom object; these are never carried out by the Sinkhorn solver itself.

prob = linear_problem.LinearProblem(geom, a=a, b=b)
solver = sinkhorn.Sinkhorn()
out = solver(prob)

print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.08210793137550354

Create Grid with the grid_size argument

When only the grid_size tuple is specified, the grid is assumed to be regular, and locations along axis \(i\) are of the form \(j/(n_i-1)\) for \(0\leq j\leq n_i-1\). This results in a regular grid in the 3-D hypercube \([0, 1]^3\). As expected, even when keeping the same histograms a and b, the OT cost is different, since the points have moved. Note that epsilon is also set explicitly to 0.1 here, which affects the regularized cost as well.

geom = grid.Grid(grid_size=grid_size, epsilon=0.1)

# We recycle the same probability vectors
prob = linear_problem.LinearProblem(geom, a=a, b=b)

out = solver(prob)

print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.28149110078811646
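The regular locations \(j/(n_i-1)\) described above can be written out explicitly. The sketch below uses plain Python and a hypothetical variable name `locations`; the `max(1, n - 1)` guard, which avoids a division by zero for an axis of size 1, is an assumption borrowed from the PointCloud construction later in this tutorial.

```python
grid_size = (5, 6, 7)

# For each axis of size n, the locations are j / (n - 1) for j = 0, ..., n - 1.
locations = [
    [j / max(1, n - 1) for j in range(n)] for n in grid_size
]

for axis, values in enumerate(locations):
    print(axis, len(values), values[0], values[-1])
```

Every axis starts at 0 and ends at 1, so the resulting grid spans the unit hypercube.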

Different cost_fn for each dimension

In the examples above, we have assumed that the cost function \(\text{cost}_i\) was the squared Euclidean distance. To illustrate how a different cost function can be chosen for each dimension, we implement an exotic custom cost function between real numbers.

@jax.tree_util.register_pytree_node_class
class MyCost(costs.CostFn):
    """An unusual cost function."""

    def norm(self, x):
        return jnp.sum(x**3 + jnp.cos(x) ** 2, axis=-1)

    def pairwise(self, x, y):
        return -jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2

Using the same grid size, we redefine Grid with these new cost functions, and recompute a regularized optimal transport cost.

cost_fns = [MyCost(), costs.SqEuclidean(), MyCost()]  # 1 for each dimension.
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = solver(prob)

print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 1.2038968801498413

Compare runtime between using Grid and PointCloud

Why use a Grid geometry instead of a PointCloud geometry defined with \(N\) points? In addition to convenience, the main advantage of Grid geometries is computational.

Indeed, the Sinkhorn algorithm applies a kernel operator, derived directly from the geometry, at each of its steps. A Grid geometry applies that kernel in \(O(N^{(1+1/d)})\) operations, whereas a PointCloud requires \(O(N^2)\) operations, where \(N\) is the total number of points in the grid and \(d\) its dimension. The two approaches compute the same quantities, up to floating-point error; the former is simply more efficient than the latter.

You can see this for yourself in the example below. We instantiate two geometries that are mathematically equivalent (describing the same points), and show that running Sinkhorn iterations with a Grid is roughly 180 times faster than with a naive PointCloud (about 2 s versus about 6 min in the run shown below).

grid_size = (37, 29, 43)

rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

print("Total size of grid: ", jnp.prod(jnp.array(grid_size)))
Total size of grid:  46139
# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size)
prob_grid = linear_problem.LinearProblem(geometry_grid, a=a, b=b)

%timeit solver(prob_grid).reg_ot_cost.block_until_ready()
out_grid = solver(prob_grid)
print(
    f"Regularized optimal transport cost using Grid = {out_grid.reg_ot_cost}\n"
)
2.03 s ± 17.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Regularized optimal transport cost using Grid = 0.10972004383802414
# List all 3D points in the Cartesian product.
x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]
xyz = jnp.stack(
    [
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]
).transpose()
# Instantiates PointCloud with `batch_size` argument.
# Computations require being run in batches, otherwise memory would
# overflow. This is achieved by setting `batch_size` to 1024.
geometry_pointcloud = pointcloud.PointCloud(xyz, xyz, batch_size=1024)
prob_pointcloud = linear_problem.LinearProblem(geometry_pointcloud, a=a, b=b)
%timeit solver(prob_pointcloud).reg_ot_cost.block_until_ready()
out_pointcloud = solver(prob_pointcloud)
print(
    f"Regularized optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}"
)
6min 5s ± 5.8 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Regularized optimal transport cost using PointCloud = 0.10972030460834503
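A rough operation count helps put this timing gap in perspective. The sketch below assumes the kernel is applied one axis at a time, at a cost of about \(N \cdot n_i\) multiply-adds for axis \(i\), against \(N^2\) for a dense matrix-vector product. These counts are back-of-the-envelope estimates under that assumption, not measurements, and actual speedups also depend on hardware, batching and compilation.

```python
import math

grid_size = (37, 29, 43)

# Total number of points in the grid.
n_total = math.prod(grid_size)

# Assumed cost of one axis-wise kernel application: N * n_i for each axis i.
separable_ops = n_total * sum(grid_size)

# Cost of one dense N x N matrix-vector product.
dense_ops = n_total**2

print(f"N = {n_total}")
print(f"axis-wise estimate = {separable_ops}")
print(f"dense estimate     = {dense_ops}")
print(f"ratio              = {dense_ops / separable_ops:.1f}")
```

When all axes have the same size \(n\), the axis-wise estimate becomes \(d \, N \, n = d \, N^{1+1/d}\), which matches the complexity quoted earlier.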