# Grid geometry

In this tutorial, we cover how to instantiate and use `Grid`.

`Grid` is a geometry that is useful when the probability measures are supported on a \(d\)-dimensional Cartesian grid, i.e. a Cartesian product of \(d\) lists of values, each list \(i\) being of size \(n_i\). The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, as in \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\) where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\).
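To make the separability assumption concrete, here is a minimal pure-Python sketch that does not use OTT. The grid locations, the helper `sq_dist`, and the function `separable_cost` are hypothetical names chosen for illustration. It shows that \(d\) small per-dimension cost matrices are enough to recover the cost between any two grid points.

```python
import itertools

# Hypothetical grid locations along each of d = 2 dimensions (n_1 = 3, n_2 = 2).
x = [[0.0, 0.5, 1.0], [0.0, 1.0]]


def sq_dist(s, t):
    """Coordinate-wise cost: squared difference between two scalars."""
    return (s - t) ** 2


# One cost function per dimension; here the same one is reused for both.
cost_fns = [sq_dist, sq_dist]


def separable_cost(p, q):
    """cost(p, q) = sum_i cost_i(p_i, q_i)."""
    return sum(c(pi, qi) for c, pi, qi in zip(cost_fns, p, q))


# Grid points are the Cartesian product of the per-dimension lists.
points = list(itertools.product(*x))

# d small matrices of shape (n_i, n_i) describe every pairwise cost.
per_dim = [[[c(s, t) for t in xi] for s in xi] for c, xi in zip(cost_fns, x)]

p, q = points[1], points[4]      # (0.0, 1.0) and (1.0, 0.0)
print(separable_cost(p, q))      # 2.0
print(per_dim[0][0][2] + per_dim[1][1][0])  # 2.0, same value from the small matrices
```

The full pairwise cost matrix would have \(6 \times 6\) entries here, while the two per-dimension matrices hold \(3 \times 3 + 2 \times 2\) entries in total.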

The advantage of using `Grid` over `PointCloud` for such cases is that the computational cost of applying the kernel is \(O(N^{(1+1/d)})\) instead of \(O(N^2)\), where \(N\) is the total number of points in the grid.
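As a rough illustration of what these orders mean, the sketch below counts multiply-adds for a single kernel application on a 50 x 50 x 50 grid, the shape used in the runtime comparison of this tutorial. The counting model (one \(n_i \times n_i\) product per axis versus one dense \(N \times N\) product) is a simplifying assumption, not a measurement of OTT.

```python
import math

grid_size = (50, 50, 50)
d = len(grid_size)
N = math.prod(grid_size)  # total number of grid points

# Dense kernel: one N x N matrix-vector product.
dense_ops = N ** 2

# Separable kernel: for each axis i, an n_i x n_i matrix applied to N / n_i fibers.
separable_ops = sum(N * n for n in grid_size)

print(N, dense_ops, separable_ops)  # 125000 15625000000 18750000
```

For a cubic grid each axis has \(n_i = N^{1/d}\) points, so the separable count is \(d \cdot N^{1+1/d}\), which matches the stated order up to the factor \(d\).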

```
import jax
import jax.numpy as jnp
import numpy as np
from ott.core import sinkhorn
from ott.geometry import costs
from ott.geometry import grid
from ott.geometry import pointcloud
```

## Uses `Grid` with the argument `x`

In this example, the argument `x` is a list of \(3\) vectors, of varying sizes \(\{n_1, n_2, n_3\}\), that describe the locations of the grid along each dimension. The resulting grid is the Cartesian product of these vectors. `a` and `b` are two histograms on a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube \([0, 1]^3\).
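The following pure-Python sketch spells out the Cartesian product for a 5 x 6 x 7 grid. The location values and the helper `flat_index` are hypothetical. It assumes that entry `m` of a flattened histogram corresponds to the grid point at row-major (C-order) position `m`, which is the order `ravel()` produces.

```python
import itertools

# Hypothetical locations; in the tutorial these are drawn at random.
x = [
    [0.1, 0.4, 0.5, 0.7, 0.9],                    # n_1 = 5
    [0.0, 0.2, 0.3, 0.6, 0.8, 1.0],               # n_2 = 6
    [0.05, 0.15, 0.35, 0.45, 0.65, 0.85, 0.95],   # n_3 = 7
]
grid_size = tuple(len(xi) for xi in x)

# itertools.product varies the last coordinate fastest, i.e. row-major order.
points = list(itertools.product(*x))


def flat_index(i, j, k):
    """Row-major position of grid cell (i, j, k)."""
    return (i * grid_size[1] + j) * grid_size[2] + k


print(grid_size, len(points))          # (5, 6, 7) 210
print(points[flat_index(2, 3, 4)])     # (0.5, 0.6, 0.65)
```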

```
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)
grid_size = (5, 6, 7)
x = [jax.random.uniform(keys[0], (grid_size[0],)),
     jax.random.uniform(keys[1], (grid_size[1],)),
     jax.random.uniform(keys[2], (grid_size[2],))]
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```

Instantiates `Grid` and calculates the regularized optimal transport cost.

```
geom = grid.Grid(x=x, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
```

```
Regularised optimal transport cost = 0.30520981550216675
```

## Uses `Grid` with the argument `grid_size`

In this example, the grid is described as points regularly sampled in \([0, 1]\) along each dimension. `a` and `b` are two histograms on a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube \([0, 1]^3\).
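A small sketch of what "regularly sampled" can look like per dimension. The helper `regular_locations` is hypothetical, and the placement rule \(i / (n - 1)\) is an assumption: it mirrors the coordinates that the `PointCloud` comparison at the end of this tutorial builds by hand, rather than a documented guarantee of `Grid`.

```python
grid_size = (5, 6, 7)


def regular_locations(n):
    """n evenly spaced values in [0, 1], i.e. i / (n - 1); a single point sits at 0."""
    return [i / max(1, n - 1) for i in range(n)]


# One list of locations per grid dimension.
x = [regular_locations(n) for n in grid_size]

print(x[0])  # [0.0, 0.25, 0.5, 0.75, 1.0]
```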

```
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6, 7)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```

Instantiates `Grid` and calculates the regularized optimal transport cost.

```
geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
```

```
Regularised optimal transport cost = 0.3816334307193756
```

## Varies the cost function in each dimension

Instead of the squared Euclidean distance, we will use a squared Mahalanobis distance, where the covariance matrix is diagonal. This example illustrates the possibility of choosing a cost function for each dimension.

```
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
```

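As the covariance matrix of the Mahalanobis distance, we use the diagonal 2x2 matrix with \([1/2, 1]\) as its diagonal. Dividing each squared coordinate difference by its variance multiplies the squared Euclidean cost of the first coordinate by 2 and leaves the second unchanged, so we create an additional `costs.CostFn` for the first dimension.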

```
@jax.tree_util.register_pytree_node_class
class EuclideanTimes2(costs.CostFn):
    """The cost function corresponding to the squared Euclidean distance times 2."""

    def norm(self, x):
        # Squared norm term, scaled by 2.
        return jnp.sum(x ** 2, axis=-1) * 2

    def pairwise(self, x, y):
        # Cross term -2 <x, y>, scaled by 2.
        return -2 * jnp.sum(x * y) * 2


# One cost function per grid dimension.
cost_fns = [EuclideanTimes2(), costs.Euclidean()]
```
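As a sanity check on the two terms defined above, the following pure-Python sketch verifies on one pair of 2-dimensional points that they add up to the squared Mahalanobis distance with variances \([1/2, 1]\). The scalar helpers are hypothetical stand-ins for the class methods, and the composition `norm(x) + norm(y) + pairwise(x, y)` is an assumption about how the pieces are combined into a cost.

```python
import math


def norm_times2(s):
    return 2 * s ** 2          # mirrors EuclideanTimes2.norm for a scalar


def pairwise_times2(s, t):
    return -2 * s * t * 2      # mirrors EuclideanTimes2.pairwise for a scalar


def cost_times2(s, t):
    # Assumed composition: norm(s) + norm(t) + pairwise(s, t) = 2 * (s - t) ** 2.
    return norm_times2(s) + norm_times2(t) + pairwise_times2(s, t)


variances = (0.5, 1.0)         # diagonal of the covariance matrix
p, q = (0.25, 0.8), (1.0, 0.2)

# Squared Mahalanobis distance with a diagonal covariance matrix.
mahalanobis_sq = sum((pi - qi) ** 2 / v for pi, qi, v in zip(p, q, variances))

# Separable form: doubled cost on the first axis, plain squared difference on the second.
separable = cost_times2(p[0], q[0]) + (p[1] - q[1]) ** 2

print(math.isclose(mahalanobis_sq, separable))  # True
```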

Instantiates `Grid` and calculates the regularized optimal transport cost.

```
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
```

```
Regularised optimal transport cost = 0.3241420388221741
```

## Compares runtime between using `Grid` and `PointCloud`

The squared Euclidean distance is an example of a separable cost for which it is possible to use `Grid` instead of `PointCloud`. In this case, using `Grid` over `PointCloud` as the geometry for regularized optimal transport presents a computational advantage, as the cost of applying a kernel in Sinkhorn steps is of the order of \(O(N^{(1+1/d)})\) instead of the naive \(O(N^2)\) complexity, where \(N\) is the total number of points in the grid and \(d\) its dimension. In this example, we can see that for the same grid size and points, the runtime of Sinkhorn with `Grid` is smaller than with `PointCloud` (35.5 ms versus 11.4 s in the run recorded here).
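The saving comes from never forming the \(N \times N\) kernel. The NumPy sketch below illustrates the idea on a small 3 x 4 x 5 grid: it applies one small per-axis kernel at a time and checks the result against the dense Kronecker-product kernel. It is an illustration of the separable trick under assumed regular locations in \([0, 1]\), not OTT's implementation, and `apply_separable_kernel` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
grid_size = (3, 4, 5)   # small enough to build the dense kernel for comparison
epsilon = 0.1

# Regularly spaced locations in [0, 1] along each axis (assumed layout).
x = [np.linspace(0.0, 1.0, n) for n in grid_size]

# Per-axis Gibbs kernels exp(-cost_i / epsilon) for the squared Euclidean cost.
kernels = [np.exp(-((xi[:, None] - xi[None, :]) ** 2) / epsilon) for xi in x]


def apply_separable_kernel(u):
    """Apply K = K_1 (x) K_2 (x) K_3 to a grid-shaped array, one axis at a time."""
    for axis, k in enumerate(kernels):
        # Contract K_i with the chosen axis, then move the result back in place.
        u = np.moveaxis(np.tensordot(k, u, axes=([1], [axis])), 0, axis)
    return u


u = rng.uniform(size=grid_size)

# Reference: the dense N x N kernel, which the separable approach avoids building.
dense = np.kron(np.kron(kernels[0], kernels[1]), kernels[2])
reference = (dense @ u.ravel()).reshape(grid_size)

print(np.allclose(apply_separable_kernel(u), reference))  # True
```

The dense kernel has \(60 \times 60\) entries here, while the three per-axis kernels hold \(9 + 16 + 25\) entries; the gap widens quickly as the grid grows.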

```
epsilon = 0.1
grid_size = (50, 50, 50)
rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
xyz = jnp.stack([
    jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
    jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
    jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
]).transpose()
# Instantiates PointCloud with argument 'online=True'
geometry_pointcloud = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon, online=True)
# Runs on GPU
%timeit sinkhorn.sinkhorn(geometry_grid, a=a, b=b).reg_ot_cost.block_until_ready()
out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b)
print(f'Regularised optimal transport cost using Grid = {out_grid.reg_ot_cost}\n')
%timeit sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b).reg_ot_cost.block_until_ready()
out_pointcloud = sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b)
print(f'Regularised optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}')
```

```
1 loops, best of 3: 35.5 ms per loop
Regularised optimal transport cost using Grid = 0.34500643610954285
1 loops, best of 3: 11.4 s per loop
Regularised optimal transport cost using PointCloud = 0.34500643610954285
```