Grid Geometry
In this tutorial, we will cover how to instantiate and use the Grid geometry.
Grid is a geometry that is useful when the probability measures are supported on a \(d\)-dimensional Cartesian grid, i.e., a Cartesian product of \(d\) lists of values, each list \(i\) being of size \(n_i\). The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, as in \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\) where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\).
The advantage of using Grid over PointCloud for such cases is that the computational cost of applying the kernel is \(O(N^{(1+1/d)})\) instead of \(O(N^2)\), where \(N\) is the total number of points in the grid.
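To make the complexity claim concrete, here is a minimal NumPy sketch (not OTT's implementation) of why separability helps. With a separable cost, the Gibbs kernel \(\exp(-\text{cost}/\varepsilon)\) factorizes into one small kernel per axis, so it can be applied one axis at a time instead of through a dense \(N \times N\) matrix. The grid locations, the squared Euclidean per-axis cost, and the helper name `apply_separable_kernel` are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
grid_size = (3, 4, 5)
epsilon = 0.1

# Hypothetical grid locations along each axis (illustrative values only).
xs = [np.sort(rng.uniform(size=n)) for n in grid_size]

# One small kernel per axis: K_i[j, k] = exp(-(x_i[j] - x_i[k])**2 / epsilon).
kernels = [np.exp(-((x[:, None] - x[None, :]) ** 2) / epsilon) for x in xs]


def apply_separable_kernel(kernels, vec, grid_size):
    """Apply the Kronecker product of `kernels` to `vec`, one axis at a time."""
    t = vec.reshape(grid_size)
    for axis, k in enumerate(kernels):
        # Contract the kernel with the current axis, then restore the axis order.
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t.ravel()


v = rng.uniform(size=int(np.prod(grid_size)))
fast = apply_separable_kernel(kernels, v, grid_size)

# Dense reference: the full N x N kernel is the Kronecker product of the factors.
dense = np.kron(np.kron(kernels[0], kernels[1]), kernels[2])
slow = dense @ v

max_abs_diff = float(np.max(np.abs(fast - slow)))
print(f"max |fast - slow| = {max_abs_diff:.2e}")
```

Each axis contraction costs about \(N \cdot n_i\) multiply-adds, so the total is \(N \sum_i n_i\), which is of order \(N^{1+1/d}\) when all \(n_i\) are equal to \(N^{1/d}\).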
import sys
if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp
import numpy as np
from ott.geometry import costs, grid, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
Create Grid with the x argument
In this example, the argument x is a list of \(3\) vectors, of varying sizes \(\{n_1, n_2, n_3\}\), that describe the locations of the grid along each dimension. The resulting grid is the Cartesian product of these vectors. a and b are two histograms on a grid of size \(5 \times 6 \times 7\) that lies in the 3-dimensional hypercube \([0, 1]^3\).
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 5)
grid_size = (5, 6, 7)
x = [
    jax.random.uniform(keys[0], (grid_size[0],)),
    jax.random.uniform(keys[1], (grid_size[1],)),
    jax.random.uniform(keys[2], (grid_size[2],)),
]
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
Instantiate Grid and calculate the regularized optimal transport cost.
geom = grid.Grid(x=x, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
solver = sinkhorn.Sinkhorn()
out = solver(prob)
print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.20520979166030884
Create Grid with the grid_size argument
In this example, the grid is described as points regularly sampled in \([0, 1]\) along each dimension. a and b are two histograms on a grid of size \(5 \times 6 \times 7\) that lies in the 3-dimensional hypercube \([0, 1]^3\).
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6, 7)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
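For intuition, the points of such a regular grid can be written out explicitly. The NumPy sketch below assumes that axis \(i\) holds the \(n_i\) values \(j / (n_i - 1)\), which is the same convention used for the PointCloud construction later in this tutorial; the exact spacing used inside Grid is not confirmed here.

```python
import numpy as np

grid_size = (5, 6, 7)

# Assumption: n regularly spaced values in [0, 1] along each axis.
axes = [np.linspace(0.0, 1.0, n) for n in grid_size]

# indexing="ij" keeps the axis order, so a C-order ravel of each coordinate
# array lists the points in the same order as `a.ravel()` lists the weights.
mesh = np.meshgrid(*axes, indexing="ij")
points = np.stack([m.ravel() for m in mesh], axis=-1)

n_points = points.shape[0]
print(points.shape)
```

The explicit array has \(5 \cdot 6 \cdot 7 = 210\) rows, one per grid point, whereas the grid description only needs the \(5 + 6 + 7 = 18\) axis values.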
Instantiate Grid and calculate the regularized optimal transport cost.
geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = solver(prob)
print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.281633585691452
Vary the cost function in each dimension
Instead of the squared Euclidean distance, we will use a squared Mahalanobis distance, where the covariance matrix is diagonal. This example illustrates the possibility of choosing a cost function for each dimension.
rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
grid_size = (5, 6)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
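We want to use the diagonal \(2 \times 2\) matrix with diagonal \([1/2, 1]\) as the covariance matrix of the Mahalanobis distance. Because the squared Mahalanobis distance weights each coordinate by the inverse of its variance, the cost along the first dimension is the squared Euclidean distance times 2, while the second dimension keeps the plain squared Euclidean distance. We therefore create an additional cost function.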
@jax.tree_util.register_pytree_node_class
class SqEuclideanTimes2(costs.CostFn):
    """The cost function corresponding to the squared Euclidean distance times 2."""

    def norm(self, x):
        # Norm term of the cost: 2 * ||x||^2.
        return jnp.sum(x**2, axis=-1) * 2

    def pairwise(self, x, y):
        # Cross term of the cost: 2 * (-2 <x, y>).
        return -2 * jnp.sum(x * y) * 2


cost_fns = [SqEuclideanTimes2(), costs.SqEuclidean()]
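As a sanity check, the short NumPy sketch below verifies on one pair of points that the squared Mahalanobis distance with covariance \(\text{diag}(1/2, 1)\) splits into the two per-dimension costs used above. The sample points are arbitrary, and the last step assumes that the cost is assembled as `norm(x) + norm(y) + pairwise(x, y)`, which is how the two methods of SqEuclideanTimes2 combine to \(2 (x - y)^2\).

```python
import numpy as np

# The inverse of the diagonal covariance diag(1/2, 1) is diag(2, 1).
cov = np.diag([0.5, 1.0])
precision = np.linalg.inv(cov)

# Two arbitrary points in the plane (illustrative values only).
x = np.array([0.2, 0.7])
y = np.array([0.9, 0.1])

# Full squared Mahalanobis distance (x - y)^T Sigma^{-1} (x - y).
diff = x - y
mahalanobis_sq = float(diff @ precision @ diff)

# Separable form: 2 * (x_1 - y_1)^2 on the first axis, (x_2 - y_2)^2 on the second.
cost_axis_1 = float(2.0 * (x[0] - y[0]) ** 2)
cost_axis_2 = float((x[1] - y[1]) ** 2)
separable = cost_axis_1 + cost_axis_2


def norm(u):
    """Scalar analogue of SqEuclideanTimes2.norm."""
    return 2.0 * u**2


def pairwise(u, v):
    """Scalar analogue of SqEuclideanTimes2.pairwise."""
    return -2.0 * u * v * 2.0


# Assumed assembly of the first-axis cost from its norm and cross terms.
decomposed = float(norm(x[0]) + norm(y[0]) + pairwise(x[0], y[0]))

print(mahalanobis_sq, separable, decomposed)
```

The first two printed values agree, and the third matches the first-axis cost alone, which is what allows one cost function per dimension to reproduce the diagonal Mahalanobis distance.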
Instantiate Grid and calculate the regularized optimal transport cost.
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = solver(prob)
print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.22414201498031616
Compare runtime between using Grid and PointCloud
The squared Euclidean distance is an example of a separable cost, for which it is possible to use Grid instead of PointCloud. In this case, using Grid over PointCloud as the geometry for regularized optimal transport presents a computational advantage: the cost of applying a kernel in the Sinkhorn steps is of the order of \(O(N^{(1+1/d)})\) instead of the naive \(O(N^2)\), where \(N\) is the total number of points in the grid and \(d\) its dimension.
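To put numbers on this, the sketch below counts rough multiply-add operations for a single kernel application on a \(50 \times 50 \times 50\) grid. These are back-of-the-envelope counts that ignore constants, batching and hardware effects, so they are not a prediction of measured runtimes.

```python
import math

grid_size = (50, 50, 50)
n_total = math.prod(grid_size)  # N, the total number of grid points

# Dense kernel: one N x N matrix-vector product per application.
dense_ops = n_total**2

# Separable kernel: one contraction per axis, each costing about N * n_i.
grid_ops = sum(n_total * n for n in grid_size)

ratio = dense_ops / grid_ops
print(f"N = {n_total}, dense ~ {dense_ops:.3e}, grid ~ {grid_ops:.3e}, ratio ~ {ratio:.0f}")
```

Under these assumptions the separable application needs several hundred times fewer operations than the dense one for this grid size.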
In this example, we can see that for the same grid size and points, the computational runtime of Sinkhorn with Grid is smaller than with PointCloud.
epsilon = 0.1
grid_size = (50, 50, 50)
rng = jax.random.PRNGKey(2)
keys = jax.random.split(rng, 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
prob_grid = linear_problem.LinearProblem(geometry_grid, a=a, b=b)
x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]
xyz = jnp.stack(
    [
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]
).transpose()
# Instantiates PointCloud with `batch_size` argument
geometry_pointcloud = pointcloud.PointCloud(
    xyz, xyz, epsilon=epsilon, batch_size=1024
)
prob_pointcloud = linear_problem.LinearProblem(geometry_pointcloud, a=a, b=b)
%timeit solver(prob_grid).reg_ot_cost.block_until_ready()
out_grid = solver(prob_grid)
print(
    f"Regularized optimal transport cost using Grid = {out_grid.reg_ot_cost}\n"
)
%timeit solver(prob_pointcloud).reg_ot_cost.block_until_ready()
out_pointcloud = solver(prob_pointcloud)
print(
    f"Regularized optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}"
)
593 ms ± 2.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Regularized optimal transport cost using Grid = 0.24500826001167297