Grid geometry
In this tutorial, we cover how to instantiate and use Grid.
Grid is a geometry that is useful when the probability measures are supported on a \(d\)-dimensional Cartesian grid, i.e. a Cartesian product of \(d\) lists of values, list \(i\) being of size \(n_i\). The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\), where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\).
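To make the separability assumption concrete, here is a minimal sketch in plain JAX (it does not use OTT) on a small hypothetical 2-D grid with the squared Euclidean cost. It checks that the \(N \times N\) cost matrix between all grid points equals the sum of the per-axis cost matrices, broadcast over the other axis. The coordinate values are arbitrary illustrations.

```python
import jax.numpy as jnp

# Hypothetical 2-D grid: one coordinate vector per axis.
x1 = jnp.array([0.0, 0.3, 1.0])        # n_1 = 3
x2 = jnp.array([0.1, 0.5, 0.7, 0.9])   # n_2 = 4

# Per-axis cost matrices for the squared Euclidean cost: C_i[k, l] = (x_i[k] - x_i[l])**2.
c1 = (x1[:, None] - x1[None, :]) ** 2  # shape (3, 3)
c2 = (x2[:, None] - x2[None, :]) ** 2  # shape (4, 4)

# Separable cost between grid points (k1, k2) and (l1, l2):
# C[(k1, k2), (l1, l2)] = C_1[k1, l1] + C_2[k2, l2].
c_sep = c1[:, None, :, None] + c2[None, :, None, :]  # shape (3, 4, 3, 4)
c_sep = c_sep.reshape(12, 12)                         # row-major flattening of both point indices

# Dense reference: materialize all N = 12 points and compute pairwise costs directly.
g1, g2 = jnp.meshgrid(x1, x2, indexing="ij")
points = jnp.stack([g1.ravel(), g2.ravel()], axis=1)  # shape (12, 2)
c_dense = jnp.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)

print(jnp.allclose(c_sep, c_dense))
```

The dense matrix has \(N^2\) entries, whereas the separable form only needs the \(d\) small per-axis matrices.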
The advantage of using Grid over PointCloud in such cases is computational: applying the kernel in a Sinkhorn iteration costs \(O(N^{1+1/d})\) instead of \(O(N^2)\), where \(N = n_1 n_2 \cdots n_d\) is the total number of points in the grid.
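The complexity gain follows from separability: the kernel \(e^{-\text{cost}/\varepsilon}\) on the grid is the Kronecker product of the per-axis kernels, so it can be applied one axis at a time. The sketch below illustrates this technique in plain JAX; `apply_separable_kernel` is a hypothetical helper, not OTT's API and not necessarily how Grid implements it. Contracting axis \(i\) costs \(O(N n_i)\), i.e. \(O(d\,N^{1+1/d})\) in total when all \(n_i\) are equal, and the result is checked against the dense \(N \times N\) kernel.

```python
import jax
import jax.numpy as jnp

def apply_separable_kernel(kernels, v):
    """Returns (K @ v.ravel()).reshape(v.shape), with K = kron(kernels[0], ..., kernels[-1]).

    kernels[i] has shape (n_i, n_i) and v has shape (n_1, ..., n_d).
    Each step contracts one axis of v, costing O(N * n_i) instead of O(N**2).
    """
    for i, k in enumerate(kernels):
        # Contract axis i of v with the second index of k, then move the new axis back to position i.
        v = jnp.moveaxis(jnp.tensordot(k, v, axes=([1], [i])), 0, i)
    return v

epsilon = 0.1
grid_size = (5, 6, 7)
axes = [jnp.linspace(0.0, 1.0, n) for n in grid_size]
# Per-axis kernels exp(-C_i / epsilon) for the squared Euclidean cost.
kernels = [jnp.exp(-((t[:, None] - t[None, :]) ** 2) / epsilon) for t in axes]

v = jax.random.uniform(jax.random.PRNGKey(0), grid_size)
fast = apply_separable_kernel(kernels, v).ravel()

# Dense reference: the full N x N kernel (N = 210) is the Kronecker product of the per-axis kernels.
dense_kernel = jnp.kron(jnp.kron(kernels[0], kernels[1]), kernels[2])
slow = dense_kernel @ v.ravel()

print(jnp.allclose(fast, slow, rtol=1e-4))
```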
import jax
import jax.numpy as jnp
import numpy as np
from ott.core import sinkhorn
from ott.geometry import costs
from ott.geometry import grid
from ott.geometry import pointcloud
Uses Grid with the argument x
In this example, the argument x is a list of \(3\) vectors of sizes \(n_1, n_2, n_3\), giving the coordinates of the grid along each axis. The grid is the Cartesian product of these vectors. a and b are two histograms on a grid of size 5 x 6 x 7; since the coordinates are drawn uniformly in \([0, 1]\), the grid lies in the 3-dimensional unit hypercube \([0, 1]^3\).
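The Cartesian product can also be materialized explicitly, which is what a PointCloud-based approach would need. The sketch below (plain JAX and NumPy, reusing the same random coordinates as the next cell) builds the \(N \times 3\) array of grid points. It assumes, as the `ravel()` calls in the next cell suggest, that histogram entries are flattened in row-major (C) order, so that grid cell \((i, j, k)\) has flat index \(i\,n_2 n_3 + j\,n_3 + k\).

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same random coordinates as in the next cell.
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)
grid_size = (5, 6, 7)
x = [jax.random.uniform(keys[i], (grid_size[i],)) for i in range(3)]

# Cartesian product of the coordinate vectors; indexing="ij" keeps axis i as axis i.
mesh = jnp.meshgrid(*x, indexing="ij")
# Flatten in row-major (C) order, the same order as .ravel() on an array of shape grid_size.
points = jnp.stack([m.ravel() for m in mesh], axis=1)  # shape (210, 3)

# Flat index of grid cell (2, 3, 4): 2 * 6 * 7 + 3 * 7 + 4 = 109.
flat = np.ravel_multi_index((2, 3, 4), grid_size)
print(points[flat])  # the coordinates (x[0][2], x[1][3], x[2][4])
```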
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)
grid_size = (5, 6, 7)
# One coordinate vector per axis, with entries drawn uniformly in [0, 1].
x = [jax.random.uniform(keys[0], (grid_size[0],)),
     jax.random.uniform(keys[1], (grid_size[1],)),
     jax.random.uniform(keys[2], (grid_size[2],))]
# Random histograms on the grid, flattened and normalized to sum to 1.
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
Instantiates Grid and calculates the regularized optimal transport cost.
geom = grid.Grid(x=x, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
Regularised optimal transport cost = 0.30520981550216675
Uses Grid with the argument grid_size
In this example, the grid is specified only by its size: along each axis, the points are regularly spaced in \([0, 1]\). a and b are two histograms on a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube \([0, 1]^3\).
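As a sketch of what "regularly spaced" could mean here, assume (this is not stated in the text) that the \(n_i\) points along axis \(i\) are equally spaced in \([0, 1]\) with both endpoints included. This matches the normalization \(\text{index} / (n_i - 1)\) used to build the PointCloud in the runtime comparison at the end of this tutorial, where Grid and PointCloud return the same cost.

```python
import jax.numpy as jnp

grid_size = (5, 6, 7)
# Assumption: along axis i, n_i equally spaced points in [0, 1], endpoints included.
x_regular = [jnp.linspace(0.0, 1.0, n) for n in grid_size]

# The same points written as index / (n_i - 1); jnp.maximum guards against n_i = 1.
x_from_index = [jnp.arange(n) / jnp.maximum(1, n - 1) for n in grid_size]

for u, w in zip(x_regular, x_from_index):
    print(jnp.allclose(u, w))
```

If this assumption holds, `grid.Grid(x=x_regular, epsilon=0.1)` should describe the same geometry as `grid.Grid(grid_size=grid_size, epsilon=0.1)`.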
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6, 7)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
Instantiates Grid and calculates the regularized optimal transport cost.
geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
Regularised optimal transport cost = 0.3816334307193756
Varies the cost function in each dimension
Instead of the squared Euclidean distance, we now use a squared Mahalanobis distance with a diagonal covariance matrix. Since the covariance is diagonal, this cost remains separable, and the example illustrates how to choose a different cost function for each dimension.
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
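For the squared Mahalanobis distance \((x-y)^\top \Sigma^{-1} (x-y)\), we use as covariance matrix \(\Sigma\) the diagonal 2x2 matrix with diagonal \([1/2, 1]\). Its inverse has diagonal \([2, 1]\), so the cost is \(2(x_1-y_1)^2 + (x_2-y_2)^2\): twice the squared Euclidean distance along the first axis, and the squared Euclidean distance along the second. We therefore create an additional costs.CostFn for the first axis.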
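A quick numerical check, in plain JAX with arbitrary illustrative points, that the squared Mahalanobis distance with \(\Sigma = \operatorname{diag}(1/2, 1)\) reduces to the separable cost \(2(x_1-y_1)^2 + (x_2-y_2)^2\), since \(\Sigma^{-1} = \operatorname{diag}(2, 1)\).

```python
import jax.numpy as jnp

# Diagonal covariance Sigma = diag(1/2, 1); its inverse diag(2, 1) weights each coordinate.
sigma = jnp.diag(jnp.array([0.5, 1.0]))
sigma_inv = jnp.linalg.inv(sigma)

def mahalanobis_sq(p, q):
    """Squared Mahalanobis distance (p - q)^T Sigma^{-1} (p - q)."""
    diff = p - q
    return diff @ sigma_inv @ diff

# Arbitrary illustrative points in the plane.
p = jnp.array([0.2, 0.9])
q = jnp.array([0.7, 0.4])

# Separable form: weight 2 on the first axis, weight 1 on the second.
separable = 2.0 * (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2
print(mahalanobis_sq(p, q), separable)
```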
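@jax.tree_util.register_pytree_node_class
class EuclideanTimes2(costs.CostFn):
    """The cost function corresponding to the squared Euclidean distance times 2."""
    def norm(self, x):
        # 2 * ||x||^2
        return jnp.sum(x ** 2, axis=-1) * 2
    def pairwise(self, x, y):
        # 2 * (-2 <x, y>); together with norm(x) + norm(y) this gives 2 * ||x - y||^2.
        return - 2 * jnp.sum(x * y) * 2
# First axis: squared Euclidean times 2; second axis: squared Euclidean.
cost_fns = [EuclideanTimes2(), costs.Euclidean()]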
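The class above only defines a norm term and a pairwise term. Assuming that costs.CostFn combines them as \(\text{pairwise}(x, y) + \text{norm}(x) + \text{norm}(y)\) (an assumption about OTT's internals, not shown in this tutorial), the sketch below mirrors the two methods as plain functions and checks that the result is \(2\|x - y\|^2\). Single-coordinate vectors are used because, in a separable Grid cost, each cost function acts on one coordinate.

```python
import jax.numpy as jnp

# Plain-function mirrors of EuclideanTimes2.norm and EuclideanTimes2.pairwise.
def norm_times2(x):
    return jnp.sum(x ** 2, axis=-1) * 2

def pairwise_times2(x, y):
    return - 2 * jnp.sum(x * y) * 2

def cost_times2(x, y):
    # Assumed combination rule of costs.CostFn: pairwise term plus both norms.
    return pairwise_times2(x, y) + norm_times2(x) + norm_times2(y)

# One coordinate per vector, as each cost function sees in a separable grid cost.
x1 = jnp.array([0.2])
y1 = jnp.array([0.7])
print(cost_times2(x1, y1), 2 * jnp.sum((x1 - y1) ** 2))
```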
Instantiates Grid and calculates the regularized optimal transport cost.
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
print(f'Regularised optimal transport cost = {out.reg_ot_cost}')
Regularised optimal transport cost = 0.3241420388221741
Compares runtime between using Grid and PointCloud
The squared Euclidean distance is an example of a separable cost for which it is possible to use Grid instead of PointCloud. In this case, using Grid rather than PointCloud as the geometry for regularized optimal transport presents a computational advantage: the cost of applying the kernel in each Sinkhorn step is of the order of \(O(N^{1+1/d})\) instead of the naive \(O(N^2)\), where \(N\) is the total number of points in the grid and \(d\) its dimension. In the example below, with the same grid size and points, Sinkhorn runs in about 35 ms with Grid versus about 11 s with PointCloud (more than two orders of magnitude faster), and both geometries return the same regularized optimal transport cost.
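A back-of-the-envelope operation count for the 50 x 50 x 50 grid used below, ignoring constant factors and memory traffic. The dense count assumes one multiply-add per kernel entry; the separable count assumes the kernel is applied one axis at a time, with \(N n_i\) multiply-adds for axis \(i\).

```python
import math

grid_size = (50, 50, 50)
N = math.prod(grid_size)  # 125_000 grid points

# Dense application: one multiply-add per entry of the N x N kernel matrix.
dense_ops = N ** 2  # 15_625_000_000
# Axis-by-axis application: contracting axis i costs N * n_i multiply-adds.
separable_ops = sum(N * n for n in grid_size)  # 3 * 125_000 * 50 = 18_750_000

print(f"dense: {dense_ops:.3e}, separable: {separable_ops:.3e}, "
      f"ratio: {dense_ops / separable_ops:.0f}")
```

The measured speedup in this section (11.4 s versus 35.5 ms, roughly 320x) is of the same order as this ratio of about 833, but need not match it: actual runtimes also depend on memory access patterns, hardware, and implementation details.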
epsilon = 0.1
grid_size = (50, 50, 50)
rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
# Explicit coordinates of all N = 50**3 grid points, regularly spaced in [0, 1]^3
x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
xyz = jnp.stack([
    jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
    jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
    jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
]).transpose()
# Instantiates PointCloud with argument 'online=True'
geometry_pointcloud = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon, online=True)
# Timings below were obtained on a GPU
%timeit sinkhorn.sinkhorn(geometry_grid, a=a, b=b).reg_ot_cost.block_until_ready()
out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b)
print(f'Regularised optimal transport cost using Grid = {out_grid.reg_ot_cost}\n')
%timeit sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b).reg_ot_cost.block_until_ready()
out_pointcloud = sinkhorn.sinkhorn(geometry_pointcloud, a=a, b=b)
print(f'Regularised optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}')
1 loops, best of 3: 35.5 ms per loop
Regularised optimal transport cost using Grid = 0.34500643610954285
1 loops, best of 3: 11.4 s per loop
Regularised optimal transport cost using PointCloud = 0.34500643610954285
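Outside IPython, the %timeit magic is unavailable. Below is a minimal sketch of an equivalent best-of-3 timing in plain Python; `benchmark` is a hypothetical helper, and the jitted matrix-vector product is only a stand-in workload. Because JAX dispatches work asynchronously, each call is followed by jax.block_until_ready, playing the same role as the .block_until_ready() calls in the cell above, and a warm-up call keeps compilation out of the measurement.

```python
import time

import jax
import jax.numpy as jnp

def benchmark(fn, *args, repeats=3):
    """Returns the best wall-clock time in seconds over `repeats` calls of fn(*args)."""
    jax.block_until_ready(fn(*args))  # warm-up call: triggers compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for asynchronous dispatch to finish
        best = min(best, time.perf_counter() - start)
    return best

# Illustrative stand-in workload: a jitted matrix-vector product.
mat = jnp.ones((1000, 1000))
vec = jnp.ones((1000,))
matvec = jax.jit(lambda m, v: m @ v)
t = benchmark(matvec, mat, vec)
print(f"best of 3: {t * 1e3:.3f} ms")
```

To time the Sinkhorn calls of this tutorial the same way, one could pass a zero-argument function such as `lambda: sinkhorn.sinkhorn(geometry_grid, a=a, b=b).reg_ot_cost`.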