Grid Geometry#
The Grid geometry was designed with the many applications that use the so-called Eulerian description of probability measures in mind. In such applications, probability measures are seen as histograms supported on a \(d\)-dimensional Cartesian grid, not as point clouds.
A Grid geometry instantiates a Cartesian product of \(d\) lists of values, where the list of index \(i\) has size \(n_i\). That Cartesian product has a total of \(N:=\prod_i n_i\) possible locations in \(\mathbb{R}^d\).
A Grid geometry also assumes that the ground cost between points in the grid is separable: for two points \(x, y\) in that grid, the cost must be of the form \(\text{cost}(x,y) = \sum_{i=1}^d \text{cost}_i(x_i, y_i)\), where \(\text{cost}_i: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\). As a result, a \(d\)-dimensional Grid expects a tuple of up to \(d\) CostFn objects, each describing a cost between two real values.
The advantage of using Grid over PointCloud is that fundamental operations, such as applying the \(N\times N\) square cost matrix of all pairwise distances between the \(N\) points in the grid, or its kernel, can be carried out efficiently in \(O(N^{(1+1/d)})\) operations, with a similar memory footprint, rather than by naively instantiating those matrices as \(N^2\) blocks.
import jax
import jax.numpy as jnp
import numpy as np
from ott.geometry import costs, grid, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
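To make the complexity claim above concrete, here is a small, self-contained sketch of how a separable Gibbs kernel can be applied to a histogram axis by axis, one small \(n_i \times n_i\) kernel at a time, instead of materializing the full \(N\times N\) matrix. This is only an illustration of the idea under the assumptions above (squared Euclidean cost per axis), not the library's internal code, and all names in it (ks, out_sep, out_full, ...) are local to the sketch.
# Illustration only: apply a separable Gibbs kernel axis by axis.
rng_keys = jax.random.split(jax.random.key(1), 4)
sizes = (5, 6, 7)
locs = [jax.random.uniform(rng_keys[i], (n,)) for i, n in enumerate(sizes)]
eps = 0.1
# One small Gibbs kernel per axis, for the squared Euclidean cost.
ks = [jnp.exp(-((u[:, None] - u[None, :]) ** 2) / eps) for u in locs]
hist = jax.random.uniform(rng_keys[3], sizes)
# Efficient route: apply each n_i x n_i kernel along its own axis,
# roughly O(N^(1 + 1/d)) operations in total.
out_sep = hist
for axis, k in enumerate(ks):
    out_sep = jnp.moveaxis(
        jnp.tensordot(k, jnp.moveaxis(out_sep, axis, 0), axes=1), 0, axis
    )
# Naive route: materialize the full N x N kernel (a Kronecker product of the
# per-axis kernels) and apply it in O(N^2) operations.
N = 5 * 6 * 7
k_full = jnp.einsum("ad,be,cf->abcdef", *ks).reshape(N, N)
out_full = (k_full @ hist.ravel()).reshape(sizes)
assert jnp.allclose(out_sep, out_full, atol=1e-5)
The per-axis loop is the kind of operation that lets grids scale; the Kronecker construction is what a dense geometry would have to pay for at every Sinkhorn iteration.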
Create Grid with the x argument#
In this example, the argument x is a list of \(d=3\) vectors \(x_1, x_2, x_3\), of varying sizes \(\{n_1, n_2, n_3\}\), that describe the locations of the grid along each dimension. The resulting grid is the Cartesian product of these vectors (each seen as a list of values), namely \(\{u\in x_1\}\times \{u\in x_2\} \times \{u\in x_3\}\). Assuming each vector is formed with distinct coordinates, that Cartesian product holds \(N = n_1 n_2 n_3 = 5 \times 6 \times 7 = 210\) distinct points in the example below. a and b are two histograms on that grid, namely probability vectors of size \(N\). Note that, to showcase the versatility of the Grid API, the grid here is irregularly spaced, since the locations along each dimension are random.
keys = jax.random.split(jax.random.key(0), 5)
grid_size = (5, 6, 7)
x = [
jax.random.uniform(keys[0], (grid_size[0],)),
jax.random.uniform(keys[1], (grid_size[1],)),
jax.random.uniform(keys[2], (grid_size[2],)),
]
We now have all the ingredients to create a geom object that will describe that grid. Since we do not specify any cost function for each dimension, denoted \(\text{cost}_i\) in the formula above, the instantiation will default to SqEuclidean (between real numbers) for each dimension. Naturally, this is mathematically equivalent to running computations with a PointCloud object instantiated with the usual squared Euclidean distance between vectors in \(\mathbb{R}^3\). We will get back to that approach later in this tutorial.
geom = grid.Grid(x=x)
We can now generate two histograms a and b on that grid. These have total size equal to \(N\), and are unfolded so as to offer a unified vector API for probability weights. They will, however, be reshaped within computations as tensors of grid_size shape.
a = jax.random.uniform(keys[3], grid_size)
b = jax.random.uniform(keys[4], grid_size)
a = a.ravel() / jnp.sum(a) # Normalize to have unit total mass.
b = b.ravel() / jnp.sum(b)  # Same normalization for b.
We now solve the OT problem between the weights a and b by running a Sinkhorn solver, which outputs the regularized optimal transport cost. The example below illustrates how ott delegates low-level geometric computations to the geom object, never to the Sinkhorn solver itself.
prob = linear_problem.LinearProblem(geom, a=a, b=b)
solver = sinkhorn.Sinkhorn()
out = solver(prob)
print(f"Regularized OT cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.08210793137550354
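The returned object holds more than the cost. As a side note, it also exposes convergence diagnostics; the attribute names below (converged, errors) are those of recent ott releases, so treat them as an assumption if your version differs.
# Convergence diagnostics (attribute names assumed from recent ott versions).
print("Converged:", out.converged)
print("Last recorded marginal error:", out.errors[out.errors > -1][-1])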
Create Grid with the grid_size argument#
When only the grid_size tuple is specified, the grid is assumed to be regular, and the locations along axis \(i\) are of the form \(j/(n_i-1)\) for \(0\leq j\leq n_i-1\). This therefore results in a simple regular grid in the 3-D hypercube \([0, 1]^3\). As expected, even when keeping the same histograms a and b, the OT cost is different, since the grid locations have changed.
geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
# We recycle the same probability vectors
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = solver(prob)
print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 0.28149110078811646
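Since the regular grid places its points at \(j/(n_i-1)\) along each axis, the same geometry can be rebuilt explicitly with the x argument. The snippet below is a quick consistency check, not part of the original run: the printed value should match the cost above, up to numerical precision.
# Rebuild the same regular grid explicitly from its per-axis locations.
x_regular = [jnp.linspace(0.0, 1.0, n) for n in grid_size]
geom_explicit = grid.Grid(x=x_regular, epsilon=0.1)
out_explicit = solver(linear_problem.LinearProblem(geom_explicit, a=a, b=b))
print(f"Regularized optimal transport cost = {out_explicit.reg_ot_cost}")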
Different cost_fn for each dimension#
In the examples above, we have assumed that the cost function \(\text{cost}_i\) was the squared Euclidean distance. To illustrate how a different cost function can be chosen for each dimension, we implement an exotic custom cost function between real numbers.
@jax.tree_util.register_pytree_node_class
class MyCost(costs.CostFn):
    """An unusual cost function."""

    def norm(self, x: jnp.ndarray) -> jnp.ndarray:
        return jnp.sum(x**3 + jnp.cos(x) ** 2, axis=-1)

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        return (
            self.norm(x)
            + self.norm(y)
            - jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2
        )
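Since each per-dimension cost compares single coordinates, a quick sanity check (illustrative only, not part of the original tutorial) is to evaluate MyCost on two one-dimensional points:
# Each cost_fn compares two real values, here passed as 1-d arrays of size 1.
my_cost = MyCost()
print(my_cost(jnp.array([0.5]), jnp.array([0.3])))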
Using the same grid size, we redefine the Grid geometry with these new cost functions, and recompute a regularized optimal transport cost.
cost_fns = [MyCost(), costs.SqEuclidean(), MyCost()] # 1 for each dimension.
geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = solver(prob)
print(f"Regularized optimal transport cost = {out.reg_ot_cost}")
Regularized optimal transport cost = 1.2038968801498413
Compare runtime between using Grid and PointCloud#
Why use a Grid geometry instead of a PointCloud geometry defined with \(N\) points? In addition to convenience, the main advantage of Grid geometries is computational.
Indeed, at each of its steps the Sinkhorn algorithm applies a kernel operator derived directly from the geometry. Grid geometries apply that kernel in \(O(N^{(1+1/d)})\) operations, whereas PointCloud geometries require \(O(N^2)\) operations, where \(N\) is the total number of points in the grid and \(d\) its dimension. The two approaches are numerically equivalent; the former is simply more efficient than the latter.
You can see this for yourself in the example below. We instantiate two geometries that are mathematically equivalent (describing the same points), and show that running Sinkhorn iterations with a Grid is about 180 times faster than with a naive PointCloud.
grid_size = (37, 29, 43)
keys = jax.random.split(jax.random.key(2), 2)
a = jax.random.uniform(keys[0], grid_size)
b = jax.random.uniform(keys[1], grid_size)
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
print("Total size of grid: ", jnp.product(jnp.array(grid_size)))
Total size of grid: 46139
# Instantiates Grid
geometry_grid = grid.Grid(grid_size=grid_size)
prob_grid = linear_problem.LinearProblem(geometry_grid, a=a, b=b)
%timeit solver(prob_grid).reg_ot_cost.block_until_ready()
out_grid = solver(prob_grid)
print(
f"Regularized optimal transport cost using Grid = {out_grid.reg_ot_cost}\n"
)
2.03 s ± 17.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Regularized optimal transport cost using Grid = 0.10972004383802414
# List all 3D points in the Cartesian product.
x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]
xyz = jnp.stack(
[
jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
]
).transpose()
# Instantiates PointCloud with `batch_size` argument.
# Computations require being run in batches, otherwise memory would
# overflow. This is achieved by setting `batch_size` to 1024.
geometry_pointcloud = pointcloud.PointCloud(xyz, xyz, batch_size=1024)
prob_pointcloud = linear_problem.LinearProblem(geometry_pointcloud, a=a, b=b)
%timeit solver(prob_pointcloud).reg_ot_cost.block_until_ready()
out_pointcloud = solver(prob_pointcloud)
print(
f"Regularized optimal transport cost using PointCloud = {out_pointcloud.reg_ot_cost}"
)
6min 5s ± 5.8 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Regularized optimal transport cost using PointCloud = 0.10972030460834503
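As a final check that the two geometries describe the same problem, we can print the gap between the two regularized costs computed above; it should be on the order of numerical precision.
# Both costs were computed on the same points and weights; the gap is numerical noise.
print(abs(out_grid.reg_ot_cost - out_pointcloud.reg_ot_cost))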