Grid Geometry

In this tutorial, we cover how to instantiate and use the Grid geometry.

Grid is a geometry that is useful when the probability measures are supported on a \(d\)-dimensional Cartesian grid, i.e., a Cartesian product of \(d\) lists of values, each list \(i\) being of size \(n_i\). The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, as in \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\) where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\).
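
The following NumPy sketch makes separability concrete. It uses a toy \(3 \times 4\) grid with made-up coordinates, not taken from this tutorial. It builds the full cost matrix between grid points by summing per-axis matrices \(\text{cost}_i(x_i, y_i) = (x_i - y_i)^2\), then checks the result against the squared Euclidean cost computed directly from the explicit list of points.

```python
import numpy as np

# Toy 2-D grid (made-up coordinates): one coordinate vector per axis.
x1 = np.array([0.0, 0.5, 1.0])         # n_1 = 3
x2 = np.array([0.0, 0.25, 0.75, 1.0])  # n_2 = 4

# Per-axis cost matrices cost_i(x_i, y_i) = (x_i - y_i)^2.
c1 = (x1[:, None] - x1[None, :]) ** 2  # shape (3, 3)
c2 = (x2[:, None] - x2[None, :]) ** 2  # shape (4, 4)

# Cost between grid points (i, j) and (k, l) is c1[i, k] + c2[j, l].
sep_cost = c1[:, None, :, None] + c2[None, :, None, :]  # shape (3, 4, 3, 4)
sep_cost = sep_cost.reshape(12, 12)  # rows/cols flattened in row-major order

# Same matrix computed directly from the explicit list of 12 points.
pts = np.stack(np.meshgrid(x1, x2, indexing="ij"), axis=-1).reshape(-1, 2)
dense_cost = np.sum((pts[:, None, :] - pts[None, :, :]) ** 2, axis=-1)

print(np.allclose(sep_cost, dense_cost))  # True
```

Only the two small per-axis matrices are needed. The full \(12 \times 12\) matrix is formed here purely for the comparison.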

For such cases, Grid has a computational advantage over PointCloud. Applying the kernel is the main operation in each Sinkhorn iteration, and with Grid it costs \(O(N^{(1+1/d)})\) instead of \(O(N^2)\), where \(N = n_1 \cdots n_d\) is the total number of points in the grid.
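
The saving comes from the structure of the Gibbs kernel \(K = e^{-C/\varepsilon}\). With a separable cost, \(K\) factorizes into one small \(n_i \times n_i\) kernel per axis. The product \(K v\) can therefore be computed by reshaping \(v\) to the grid shape and contracting each axis in turn. Each contraction costs \(O(N n_i)\), so the total is \(O(d\,N^{1+1/d})\) when every \(n_i \approx N^{1/d}\). The NumPy sketch below illustrates this on a toy \(4 \times 5 \times 6\) grid with a squared Euclidean cost. Note that `apply_separable` is a hypothetical helper written for this illustration, not OTT's internal implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 0.1
shape = (4, 5, 6)  # toy grid, evenly spaced in [0, 1]^3
axes = [np.linspace(0.0, 1.0, n) for n in shape]
N = int(np.prod(shape))

# One small Gibbs kernel K_i = exp(-cost_i / eps) per axis, of size n_i x n_i.
kernels = [np.exp(-((x[:, None] - x[None, :]) ** 2) / eps) for x in axes]


def apply_separable(kernels, v, shape):
    """Hypothetical helper: compute K @ v without forming the N x N kernel."""
    t = v.reshape(shape)
    for axis, k in enumerate(kernels):
        # Contract `axis` of t with k (O(N * n_axis) work), then put it back.
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t.ravel()


# Reference: the dense N x N kernel built from the explicit point cloud.
pts = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
dense_k = np.exp(-np.sum((pts[:, None] - pts[None]) ** 2, axis=-1) / eps)

v = rng.random(N)
print(np.allclose(apply_separable(kernels, v, shape), dense_k @ v))  # True
```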

import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, grid, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

Create Grid with the x argument

In this example, the argument x is a list of \(3\) vectors of sizes \(n_1, n_2, n_3\), giving the coordinates of the grid points along each axis. The grid is the Cartesian product of these vectors. a and b are two histograms on this grid of size \(5 \times 6 \times 7\). Since the coordinates are drawn uniformly at random from \([0, 1)\), the grid lies in the unit cube \([0, 1]^3\).
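
The grid points and the flattened histograms must be indexed consistently. The NumPy sketch below assumes row-major (C-order) flattening, which matches the `a.ravel()` calls in this example. It builds the explicit Cartesian product of three random coordinate vectors, then maps a flat index back to per-axis indices with `np.unravel_index`.

```python
import numpy as np

rng = np.random.default_rng(0)
grid_size = (5, 6, 7)
# Three coordinate vectors, one per axis (random stand-ins for `x`).
x = [rng.uniform(size=n) for n in grid_size]

# Explicit Cartesian product, one row per grid point, in row-major order
# (assumed to be the same order as `a.ravel()` for shape `grid_size`).
points = np.stack(np.meshgrid(*x, indexing="ij"), axis=-1).reshape(-1, 3)
print(points.shape)  # (210, 3)

# Flat index p of a raveled histogram maps back to grid indices (i, j, k).
p = 100
i, j, k = np.unravel_index(p, grid_size)
print(i, j, k)  # 2 2 2
print(np.allclose(points[p], [x[0][i], x[1][j], x[2][k]]))  # True
```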

rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)

grid_size = (5, 6, 7)
x = [
    jax.random.uniform(keys[0], (grid_size[0],)),
    jax.random.uniform(keys[1], (grid_size[1],)),
    jax.random.uniform(keys[2], (grid_size[2],)),
]
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

Instantiate Grid and calculate the regularized optimal transport cost.

geom = grid.Grid(x=x, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)

solver = sinkhorn.Sinkhorn()
out = solver(prob)

print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.20520979166030884
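
For intuition about what the solver does with this geometry, here is a minimal NumPy sketch of plain Sinkhorn scaling iterations in which every kernel application uses the per-axis factorization. It is not OTT's implementation, which by default works in the log domain and stops at a convergence threshold. The coordinates and histograms are random stand-ins, and the fixed count of 2000 iterations is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 0.1
grid_size = (5, 6, 7)
# Random stand-ins for the coordinates and histograms of the example.
x = [rng.uniform(size=n) for n in grid_size]
a = rng.uniform(size=grid_size)
a = a / a.sum()
b = rng.uniform(size=grid_size)
b = b / b.sum()

# Per-axis Gibbs kernels for the squared Euclidean cost.
kernels = [np.exp(-((xs[:, None] - xs[None, :]) ** 2) / eps) for xs in x]


def apply_kernel(t):
    """Apply the full separable kernel to a tensor of shape grid_size."""
    for axis, k in enumerate(kernels):
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t


# Scaling updates u <- a / (K v), v <- b / (K^T u); the per-axis kernels
# are symmetric, so K^T = K.
u = np.ones(grid_size)
v = np.ones(grid_size)
for _ in range(2000):
    u = a / apply_kernel(v)
    v = b / apply_kernel(u)

# Row marginal of the plan diag(u) K diag(v); it should approach a.
row_marginal = u * apply_kernel(v)
print(np.abs(row_marginal - a).sum())
```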

Create Grid with the grid_size argument

In this example, only the number of points per axis is given, and along each axis the grid points are regularly spaced in \([0, 1]\). a and b are two histograms on a grid of size \(5 \times 6 \times 7\) that lies in the unit cube \([0, 1]^3\).
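
The sketch below shows what "regularly spaced" means here. It assumes that each axis holds \(n_i\) evenly spaced points with both endpoints 0 and 1 included, as `np.linspace` produces. This is the same normalization, index / (n_i - 1), that is used for the PointCloud in the runtime comparison later in this tutorial.

```python
import numpy as np

grid_size = (5, 6, 7)
# Assumption: n_i evenly spaced points in [0, 1], both endpoints included.
x = [np.linspace(0.0, 1.0, n) for n in grid_size]

# The spacing along axis i is then 1 / (n_i - 1).
spacings = [xs[1] - xs[0] for xs in x]
print(np.allclose(spacings, [1 / 4, 1 / 5, 1 / 6]))  # True
print(x[0])  # [0.   0.25 0.5  0.75 1.  ]
```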

rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)

grid_size = (5, 6, 7)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

Instantiate Grid and calculate the regularized optimal transport cost.

geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)

out = solver(prob)

print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.281633585691452

Vary the cost function in each dimension

Instead of the squared Euclidean distance, we use a squared Mahalanobis distance with a diagonal covariance matrix. Because the covariance matrix is diagonal, this cost is still separable. The example shows how to choose a different cost function for each dimension.

rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)

grid_size = (5, 6)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

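We use the diagonal \(2 \times 2\) covariance matrix \(\Sigma = \mathrm{diag}(1/2, 1)\). The squared Mahalanobis distance is then \((x - y)^\top \Sigma^{-1} (x - y) = 2 (x_1 - y_1)^2 + (x_2 - y_2)^2\). The first coordinate uses the squared Euclidean cost multiplied by 2, and the second uses the plain squared Euclidean cost. We therefore create an additional cost function for the first coordinate.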
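
As an illustration only, the following NumPy check on two random points confirms this identity.

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.uniform(size=2), rng.uniform(size=2)

sigma = np.diag([0.5, 1.0])  # diagonal covariance matrix
diff = x - y
mahalanobis_sq = diff @ np.linalg.inv(sigma) @ diff

# With a diagonal covariance, the cost splits into per-coordinate terms:
# 2 * (x_1 - y_1)^2 + 1 * (x_2 - y_2)^2.
separable = 2 * diff[0] ** 2 + diff[1] ** 2
print(np.isclose(mahalanobis_sq, separable))  # True
```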

@jax.tree_util.register_pytree_node_class
class SqEuclideanTimes2(costs.CostFn):
    """Squared Euclidean distance multiplied by 2."""

    def norm(self, x):
        # 2 * ||x||^2
        return jnp.sum(x**2, axis=-1) * 2

    def pairwise(self, x, y):
        # -4 * <x, y>; together with norm(x) + norm(y) this gives 2 * ||x - y||^2
        return -2 * jnp.sum(x * y) * 2


cost_fns = [SqEuclideanTimes2(), costs.SqEuclidean()]
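
The NumPy sketch below reimplements the two methods to check the arithmetic on a pair of scalars. It assumes the cost is assembled as norm(x) + norm(y) + pairwise(x, y), which is our reading of how `norm` and `pairwise` combine in OTT's `CostFn`. Under that assumption, the result is \(2 (x - y)^2\).

```python
import numpy as np


def norm(x):
    # Mirrors SqEuclideanTimes2.norm: 2 * ||x||^2.
    return np.sum(x**2, axis=-1) * 2


def pairwise(x, y):
    # Mirrors SqEuclideanTimes2.pairwise: -4 * <x, y>.
    return -2 * np.sum(x * y) * 2


x, y = np.array([0.3]), np.array([0.8])
# Assumed assembly of the full cost from the two methods.
cost = norm(x) + norm(y) + pairwise(x, y)
print(np.isclose(cost, 2 * (0.3 - 0.8) ** 2))  # True
```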

Instantiate Grid and calculate the regularized optimal transport cost.

geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = solver(prob)

print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.22414201498031616

Compare the runtime of Grid and PointCloud

The squared Euclidean distance is an example of a separable cost, for which Grid can be used instead of PointCloud. In that case, Grid offers a computational advantage for regularized optimal transport. Applying the kernel in each Sinkhorn iteration costs \(O(N^{(1+1/d)})\) instead of the naive \(O(N^2)\), where \(N\) is the total number of points in the grid and \(d\) is its dimension.
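
For the \(50 \times 50 \times 50\) grid used here, a rough count of operations per kernel application is shown below. The count ignores constants and memory traffic, so it only indicates the expected gap and does not predict runtimes.

```python
import math

grid_size = (50, 50, 50)
N = math.prod(grid_size)  # 125,000 grid points

dense_ops = N**2  # one dense N x N matrix-vector product
grid_ops = N * sum(grid_size)  # d contractions, each O(N * n_i)
print(dense_ops)  # 15625000000
print(grid_ops)  # 18750000
print(dense_ops // grid_ops)  # 833
```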

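In this example, we run Sinkhorn with both geometries on the same grid points and histograms. The runtime is smaller with Grid than with PointCloud. PointCloud is given batch_size=1024 so that the \(N \times N\) cost matrix, with \(N = 125{,}000\), is computed in batches rather than stored in memory.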

epsilon = 0.1
grid_size = (50, 50, 50)

rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

# Instantiate Grid
geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
prob_grid = linear_problem.LinearProblem(geometry_grid, a=a, b=b)

# Build the same grid points explicitly, normalized to [0, 1]^3,
# in the same row-major order as the raveled histograms
x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]
xyz = jnp.stack(
    [
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]
).transpose()
# Instantiate PointCloud with the `batch_size` argument, so that the
# N x N cost matrix is computed in batches instead of being stored
geometry_pointcloud = pointcloud.PointCloud(
    xyz, xyz, epsilon=epsilon, batch_size=1024
)
prob_pointcloud = linear_problem.LinearProblem(geometry_pointcloud, a=a, b=b)
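
The following sketch assumes that Grid with grid_size places its points at np.linspace(0, 1, n_i) along each axis. On a smaller \(4 \times 5 \times 6\) grid, it checks that the mgrid-based construction above yields the same points in the same row-major order, so that both geometries describe the same problem.

```python
import numpy as np

grid_size = (4, 5, 6)  # smaller than the benchmark, same construction
x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]
xyz = np.stack(
    [
        x.ravel() / max(1, grid_size[0] - 1),
        y.ravel() / max(1, grid_size[1] - 1),
        z.ravel() / max(1, grid_size[2] - 1),
    ]
).T

# Assumed Grid points: Cartesian product of evenly spaced coordinates in
# [0, 1], flattened in row-major order.
axes = [np.linspace(0.0, 1.0, n) for n in grid_size]
expected = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
print(np.allclose(xyz, expected))  # True
```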

%timeit solver(prob_grid).reg_ot_cost.block_until_ready()
out_grid = solver(prob_grid)
print(
    f"Regularized optimal transport cost using Grid = {out_grid.reg_ot_cost}\n"
)

%timeit solver(prob_pointcloud).reg_ot_cost.block_until_ready()
out_pointcloud = solver(prob_pointcloud)
print(
    f"Regularized optimal transport cost using Pointcloud = {out_pointcloud.reg_ot_cost}"
)
593 ms ± 2.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Regularized optimal transport cost using Grid = 0.24500826001167297