# Grid geometry

In this tutorial, we cover how to instantiate and use Grid.

Grid is a geometry suited to probability measures supported on a $$d$$-dimensional Cartesian grid, i.e. the Cartesian product of $$d$$ lists of values, where list $$i$$ has size $$n_i$$. The transportation cost between grid points is assumed to be separable, i.e. a sum of coordinate-wise cost functions: $$\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)$$, where $$\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}$$.
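To make separability concrete, the NumPy sketch below (independent of OTT; `separable_cost` and `sq` are hypothetical names) builds a small 2-D grid, computes the cost between every pair of grid points as a sum of per-axis costs, and checks that it matches the squared Euclidean distance computed on the explicit points.

```python
import numpy as np


def separable_cost(axes, cost_fns):
    """Hypothetical helper: N x N cost between all grid points, computed as a
    sum of coordinate-wise costs. Points are ordered like a C-order ravel()."""
    sizes = [len(coords) for coords in axes]
    n_points = int(np.prod(sizes))
    total = np.zeros((n_points, n_points))
    # Multi-index of every grid point along each axis, shape (N, d).
    idx = np.stack(np.unravel_index(np.arange(n_points), sizes), axis=1)
    for i, (coords, cost_i) in enumerate(zip(axes, cost_fns)):
        # Small n_i x n_i cost matrix along axis i, lifted to all point pairs.
        c_i = cost_i(coords[:, None], coords[None, :])
        total += c_i[idx[:, i][:, None], idx[:, i][None, :]]
    return total


def sq(u, v):
    # Coordinate-wise squared difference.
    return (u - v) ** 2


axes = [np.linspace(0.0, 1.0, 3), np.array([0.0, 0.2, 0.7, 1.0])]
C_sep = separable_cost(axes, [sq, sq])

# Explicit grid points in the same order, and their full squared Euclidean cost.
points = np.stack([g.ravel() for g in np.meshgrid(*axes, indexing="ij")], axis=1)
C_full = np.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)
print(np.allclose(C_sep, C_full))  # True
```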

The advantage of using Grid over PointCloud in such cases is that applying the kernel, the main operation in each Sinkhorn iteration, costs $$O(N^{(1+1/d)})$$ instead of $$O(N^2)$$, where $$N = \prod_{i=1}^d n_i$$ is the total number of points in the grid.
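The speed-up comes from the structure of the kernel $$K = e^{-C/\varepsilon}$$ (entrywise): since the exponential of a sum is a product of exponentials, the kernel of a separable cost is a Kronecker product $$K = K_1 \otimes \dots \otimes K_d$$ of small per-axis kernels. Applying $$K$$ to a vector reshaped as an $$n_1 \times \dots \times n_d$$ tensor then amounts to contracting each axis with its $$n_i \times n_i$$ kernel, for a cost of $$O(N \sum_i n_i)$$, i.e. $$O(d\,N^{1+1/d})$$ when all $$n_i = N^{1/d}$$. The NumPy sketch below is our own illustration, not OTT's internal code (`apply_separable_kernel` is a hypothetical name); it checks the axis-by-axis product against the dense $$N \times N$$ kernel.

```python
import numpy as np


def apply_separable_kernel(kernels, v):
    """Hypothetical helper: computes K @ v for K = kron(K_1, ..., K_d)
    without forming K, by contracting one grid axis at a time."""
    t = v.reshape([k.shape[0] for k in kernels])
    for axis, k in enumerate(kernels):
        # Contract axis `axis` of t with k (cost n_axis * N), then move the
        # new axis back to its original position.
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t.ravel()


rng = np.random.default_rng(0)
epsilon = 0.1
axes = [np.linspace(0.0, 1.0, n) for n in (5, 6, 7)]
# Per-axis Gibbs kernels exp(-(x_i - y_i)^2 / epsilon).
kernels = [np.exp(-((c[:, None] - c[None, :]) ** 2) / epsilon) for c in axes]

# The dense N x N kernel is the Kronecker product of the per-axis kernels.
K_dense = kernels[0]
for k in kernels[1:]:
    K_dense = np.kron(K_dense, k)

v = rng.random(5 * 6 * 7)
print(np.allclose(apply_separable_kernel(kernels, v), K_dense @ v))  # True
```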

```python
import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
```

```python
import jax
import jax.numpy as jnp
import numpy as np

from ott.core import sinkhorn
from ott.geometry import costs
from ott.geometry import grid
from ott.geometry import pointcloud
```


## Using Grid with the argument `x`

In this example, the argument `x` is a list of $$3$$ vectors of sizes $$n_1 = 5$$, $$n_2 = 6$$ and $$n_3 = 7$$, giving the coordinates of the grid along each axis; the grid is the Cartesian product of these vectors. `a` and `b` are two histograms on this $$5 \times 6 \times 7$$ grid, which lies in the unit cube $$[0, 1]^3$$ because the coordinates are drawn uniformly at random in $$[0, 1)$$.
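As an illustration of what the Cartesian product means here, the following NumPy sketch (independent of OTT) enumerates the points of a $$5 \times 6 \times 7$$ grid from three per-axis coordinate vectors. The row ordering, with the last axis varying fastest, is our choice for the illustration; it is the same C order that `ravel()` uses.

```python
import itertools

import numpy as np

rng = np.random.default_rng(0)
grid_size = (5, 6, 7)
# One vector of coordinates per axis, playing the role of the list `x`.
x = [rng.random(n) for n in grid_size]

# Every grid point is one choice of coordinate per axis: 5 * 6 * 7 points.
points = np.array(list(itertools.product(*x)))
print(points.shape)  # (210, 3)

# The same points via meshgrid; rows follow C order (last axis fastest).
mesh = np.meshgrid(*x, indexing="ij")
points_mesh = np.stack([m.ravel() for m in mesh], axis=1)
print(np.allclose(points, points_mesh))  # True
```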

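```python
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)

grid_size = (5, 6, 7)
# One vector of coordinates per axis; the grid is their Cartesian product.
x = [
    jax.random.uniform(keys[0], (grid_size[0],)),
    jax.random.uniform(keys[1], (grid_size[1],)),
    jax.random.uniform(keys[2], (grid_size[2],)),
]
# Two random histograms on the grid, flattened and normalized to sum to 1.
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```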


Next, we instantiate Grid and compute the regularized optimal transport cost.

```python
geom = grid.Grid(x=x, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f"Regularised optimal transport cost = {out.reg_ot_cost}")
```

Regularised optimal transport cost = 0.30520981550216675
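For intuition about what Sinkhorn's algorithm does with a grid, here is a plain NumPy sketch of its classical scaling form on a $$5 \times 6 \times 7$$ grid, where every kernel application is done axis by axis. It is our own simplified illustration, not OTT's implementation: it uses regularly spaced coordinates rather than the random `x` above, runs a fixed number of iterations instead of testing a convergence threshold, and only checks the marginals of the transport plan; it does not reproduce `reg_ot_cost`.

```python
import numpy as np


def apply_kernel(kernels, v):
    """Computes K @ v for K = kron(K_1, ..., K_d), one grid axis at a time."""
    t = v.reshape([k.shape[0] for k in kernels])
    for axis, k in enumerate(kernels):
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t.ravel()


rng = np.random.default_rng(0)
grid_size, epsilon = (5, 6, 7), 0.1

# Two random histograms on the grid, flattened and normalized to sum to 1.
a = rng.random(grid_size).ravel()
a /= a.sum()
b = rng.random(grid_size).ravel()
b /= b.sum()

# Regularly spaced coordinates and per-axis Gibbs kernels (squared Euclidean).
axes = [np.linspace(0.0, 1.0, n) for n in grid_size]
kernels = [np.exp(-((c[:, None] - c[None, :]) ** 2) / epsilon) for c in axes]

# Scaling form: the plan is diag(u) K diag(v). Each per-axis kernel is
# symmetric, so K is symmetric and the same routine also applies K^T.
u, v = np.ones_like(a), np.ones_like(b)
for _ in range(1000):
    u = a / apply_kernel(kernels, v)
    v = b / apply_kernel(kernels, u)

# Marginals of the plan: u * (K v) should approach a; v * (K^T u) equals b.
row_marginal = u * apply_kernel(kernels, v)
col_marginal = v * apply_kernel(kernels, u)
print(np.abs(row_marginal - a).sum(), np.abs(col_marginal - b).sum())
```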


## Using Grid with the argument `grid_size`

In this example, the grid is specified only by its size: along each axis, the points are regularly spaced in $$[0, 1]$$. `a` and `b` are two histograms on a $$5 \times 6 \times 7$$ grid that lies in the unit cube $$[0, 1]^3$$.
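Concretely, "regularly spaced in $$[0, 1]$$" can be read as placing the $$n_i$$ points $$k/(n_i - 1)$$, $$k = 0, \dots, n_i - 1$$, along axis $$i$$, which are the coordinates that the PointCloud comparison at the end of this tutorial builds explicitly. The sketch below assumes that convention and computes the per-axis coordinates with `np.linspace`; it does not call OTT.

```python
import numpy as np

grid_size = (5, 6, 7)
# Assumed convention: n_i equally spaced points covering [0, 1] on axis i.
x_default = [np.linspace(0.0, 1.0, n) for n in grid_size]

# Same values as the explicit construction k / max(1, n - 1).
for n, coords in zip(grid_size, x_default):
    assert np.allclose(coords, np.arange(n) / max(1, n - 1))

print(x_default[0].tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]
```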

```python
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)

grid_size = (5, 6, 7)
# Two random histograms on the grid, flattened and normalized to sum to 1.
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```


Next, we instantiate Grid from `grid_size` alone and compute the regularized optimal transport cost.

```python
geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f"Regularised optimal transport cost = {out.reg_ot_cost}")
```

Regularised optimal transport cost = 0.3816334307193756


## Varying the cost function in each dimension

Instead of the squared Euclidean distance, we now use a squared Mahalanobis distance with a diagonal covariance matrix. Because the covariance is diagonal, this cost is still separable, which illustrates that a different cost function can be chosen for each dimension.

```python
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)

grid_size = (5, 6)
# Two random histograms on the 5 x 6 grid, flattened and normalized.
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```


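As covariance matrix for the Mahalanobis distance, we use the diagonal $$2 \times 2$$ matrix $$\Sigma = \text{diag}(1/2, 1)$$. Since $$\Sigma^{-1} = \text{diag}(2, 1)$$, the squared Mahalanobis distance is $$(x - y)^\top \Sigma^{-1} (x - y) = 2(x_1 - y_1)^2 + (x_2 - y_2)^2$$: the squared Euclidean cost times 2 on the first axis and the plain squared Euclidean cost on the second. `costs.Euclidean` covers the second axis, so we create an additional `costs.CostFn` for the first.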
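The following NumPy check (independent of OTT) confirms on a random pair of points that, with $$\Sigma = \text{diag}(1/2, 1)$$, the squared Mahalanobis distance equals $$2(x_1 - y_1)^2 + (x_2 - y_2)^2$$, a sum of one cost per coordinate.

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.random(2), rng.random(2)

# Diagonal covariance and its inverse, diag(2, 1).
cov = np.diag([0.5, 1.0])
cov_inv = np.linalg.inv(cov)

d = x - y
mahalanobis_sq = d @ cov_inv @ d
# Separable form: each coordinate weighted by 1 / sigma_i^2.
separable = 2.0 * d[0] ** 2 + 1.0 * d[1] ** 2
print(np.isclose(mahalanobis_sq, separable))  # True
```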

```python
@jax.tree_util.register_pytree_node_class
class EuclideanTimes2(costs.CostFn):
    """The cost function corresponding to the squared Euclidean distance times 2."""

    def norm(self, x):
        return jnp.sum(x**2, axis=-1) * 2

    def pairwise(self, x, y):
        return -2 * jnp.sum(x * y) * 2


# First axis: squared Euclidean times 2; second axis: squared Euclidean.
cost_fns = [EuclideanTimes2(), costs.Euclidean()]
```
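The class above only defines `norm` and `pairwise`. Assuming, as for the built-in squared Euclidean cost, that the cost between two points is assembled as `norm(x) + norm(y) + pairwise(x, y)`, the NumPy re-implementation of those two formulas below (plain functions with hypothetical names, not the OTT class) confirms that they yield $$2(x - y)^2$$ on a 1-D coordinate: $$2x^2 + 2y^2 - 4xy = 2(x - y)^2$$.

```python
import numpy as np


def norm_times2(x):
    # Same formula as EuclideanTimes2.norm.
    return np.sum(x**2, axis=-1) * 2


def pairwise_times2(x, y):
    # Same formula as EuclideanTimes2.pairwise.
    return -2 * np.sum(x * y) * 2


rng = np.random.default_rng(0)
x, y = rng.random(1), rng.random(1)  # 1-D coordinates along one grid axis

# Assumed assembly rule: cost(x, y) = norm(x) + norm(y) + pairwise(x, y).
cost = norm_times2(x) + norm_times2(y) + pairwise_times2(x, y)
print(np.isclose(cost, 2 * np.sum((x - y) ** 2)))  # True
```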


Next, we instantiate Grid with one cost function per dimension, passed through `cost_fns`, and compute the regularized optimal transport cost.

```python
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f"Regularised optimal transport cost = {out.reg_ot_cost}")
```

Regularised optimal transport cost = 0.3241420388221741


## Comparing runtimes of Grid and PointCloud

The squared Euclidean distance is an example of a separable cost, for which Grid can be used instead of PointCloud. In that case, Grid has a computational advantage in regularized optimal transport: applying the kernel in each Sinkhorn iteration costs $$O(N^{(1+1/d)})$$ instead of the naive $$O(N^2)$$, where $$N$$ is the total number of points in the grid and $$d$$ its dimension. Running Sinkhorn on the same grid and histograms with both geometries gives the same regularized optimal transport cost, while Grid is more than 300 times faster than PointCloud (35.5 ms versus 11.4 s per run on a GPU).
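To put the two complexities in numbers for the $$50 \times 50 \times 50$$ grid used below, the following sketch counts multiply-adds for a single kernel application: $$N^2$$ for a dense $$N \times N$$ kernel, versus $$\sum_i n_i N$$ when the kernel is applied one axis at a time. These counts ignore constants, memory traffic and parallelism, so they only indicate the order of magnitude of the gap; the speed-up measured in this run is about 320×.

```python
import math

grid_size = (50, 50, 50)
N = math.prod(grid_size)  # 125,000 grid points

# Dense kernel-vector product: one multiply-add per entry of the N x N kernel.
dense_ops = N**2
# Separable kernel: contracting axis i with its n_i x n_i kernel costs n_i * N.
grid_ops = sum(n * N for n in grid_size)

print(N, dense_ops, grid_ops)  # 125000 15625000000 18750000
print(f"ratio ~ {dense_ops / grid_ops:.0f}")
```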

```python
epsilon = 0.1
grid_size = (50, 50, 50)

rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)

# Explicit coordinates of the same grid points, regularly spaced in [0, 1]^3
x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]
xyz = jnp.stack(
    [
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]
).transpose()
# Instantiates PointCloud with batch_size argument (cost computed in batches)
geometry_pointcloud = pointcloud.PointCloud(
    xyz, xyz, epsilon=epsilon, batch_size=1024
)

# Runs on GPU
%timeit sinkhorn.sinkhorn(geometry_grid, a=a, b=b).reg_ot_cost.block_until_ready()
out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b)
print(
    f"Regularised optimal transport cost using Grid = {out_grid.reg_ot_cost}\n"
)

%timeit sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b).reg_ot_cost.block_until_ready()
out_pointcloud = sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b)
print(
    f"Regularised optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}"
)
```

1 loops, best of 3: 35.5 ms per loop
Regularised optimal transport cost using Grid = 0.34500643610954285

1 loops, best of 3: 11.4 s per loop
Regularised optimal transport cost using PointCloud = 0.34500643610954285