Low-Rank GW
We use the low-rank (LR) Gromov-Wasserstein solver proposed by [Scetbon et al., 2022] as a follow-up to the LR Sinkhorn solver of [Scetbon et al., 2021]. See Low-rank Sinkhorn for more information.
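For reference, LR solvers restrict the coupling to a factored form [Scetbon et al., 2021]. The statement below uses the notation of those papers (Q, R, g and the rank r are not otherwise defined in this notebook) and shows in one line why the factored coupling keeps the marginals a and b.

```latex
P = Q \,\mathrm{diag}(1/g)\, R^\top, \qquad
Q \in \mathbb{R}_+^{n \times r},\quad R \in \mathbb{R}_+^{m \times r},\quad g \in \mathbb{R}_+^{r},
\\
Q \mathbf{1}_r = a,\quad Q^\top \mathbf{1}_n = g,\quad
R \mathbf{1}_r = b,\quad R^\top \mathbf{1}_m = g,
\\
P \mathbf{1}_m = Q\,\mathrm{diag}(1/g)\,R^\top \mathbf{1}_m = Q\,\mathrm{diag}(1/g)\,g = Q \mathbf{1}_r = a,
\qquad
P^\top \mathbf{1}_n = R\,\mathrm{diag}(1/g)\,Q^\top \mathbf{1}_n = R \mathbf{1}_r = b.
```

Storing Q, R and g takes (n + m) r + r numbers instead of n m; with the sizes used below (n = 24, m = 17, r = 6) that is 252 numbers instead of 408.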
import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein
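def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):
    # Split the key so that every random draw uses an independent subkey.
    rngs = jax.random.split(rng, 5)
    x = jax.random.uniform(rngs[0], (n, d1))  # n points in dimension d1
    y = jax.random.uniform(rngs[1], (m, d2))  # m points in dimension d2
    a = jax.random.uniform(rngs[2], (n,))
    b = jax.random.uniform(rngs[3], (m,))
    # Normalize the weights into probability vectors.
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    # m extra points in dimension d1, compared with x in the fused (linear) term.
    z = jax.random.uniform(rngs[4], (m, d1))
    return x, y, a, b, z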
rng = jax.random.PRNGKey(0)
n, m, d1, d2 = 24, 17, 2, 3
x, y, a, b, z = create_points(rng, n, m, d1, d2)
Create two point clouds of different sizes and dimensions, plus a third geometry between x and z that contributes a linear term, to formulate a fused problem [Vayer et al., 2020].
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, z)
prob = quadratic_problem.QuadraticProblem(
    geom_xx,
    geom_yy,
    geom_xy=geom_xy,
    a=a,
    b=b,
)
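To make the fused objective concrete, here is a minimal NumPy sketch that evaluates it for a given coupling P by brute force. It assumes squared Euclidean costs inside each point cloud (the PointCloud default), the squared loss between the two intra-domain costs, and a linear term weighted by a fused_penalty factor of 1.0, which we understand to be QuadraticProblem's default. The helper names sq_dists and fused_gw_objective are hypothetical, and the values are not expected to match the costs OTT reports, since the solvers apply their own regularization.

```python
import numpy as np


def sq_dists(u, v):
    # Pairwise squared Euclidean distances, shape (len(u), len(v)).
    return ((u[:, None, :] - v[None, :, :]) ** 2).sum(axis=-1)


def fused_gw_objective(P, x, y, z, fused_penalty=1.0):
    # Assumed objective: squared-loss GW term + fused_penalty * linear term.
    Cx = sq_dists(x, x)  # (n, n) costs within x
    Cy = sq_dists(y, y)  # (m, m) costs within y
    M = sq_dists(x, z)   # (n, m) cross costs for the linear (fused) term
    # Quadratic term: sum_{i,j,k,l} (Cx[i,k] - Cy[j,l])**2 * P[i,j] * P[k,l]
    diff = Cx[:, None, :, None] - Cy[None, :, None, :]  # indexed (i, j, k, l)
    quad = np.einsum("ijkl,ij,kl->", diff**2, P, P)
    lin = np.sum(M * P)
    return quad + fused_penalty * lin


rng = np.random.default_rng(0)
n, m, d1, d2 = 24, 17, 2, 3
x = rng.uniform(size=(n, d1))
y = rng.uniform(size=(m, d2))
z = rng.uniform(size=(m, d1))
a = np.full(n, 1.0 / n)  # uniform weights, for simplicity
b = np.full(m, 1.0 / m)
P_indep = np.outer(a, b)  # trivial independent coupling, used as a baseline
print(fused_gw_objective(P_indep, x, y, z))
```

The four-index sum costs O(n²m²) time and memory, which is fine at this size but is exactly what practical solvers avoid.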
Solve the problem with the GromovWasserstein solver class; setting rank=6 makes it use the low-rank LRSinkhorn solver for its inner problems.
solver = gromov_wasserstein.GromovWasserstein(rank=6)
ot_gwlr = solver(prob)
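The rank=6 argument constrains the coupling to the factored form P = Q diag(1/g) Rᵀ of [Scetbon et al., 2021]. The NumPy sketch below only illustrates that parametrization: it builds hypothetical factors Q and R with a uniform inner marginal g by Sinkhorn-style diagonal scaling of random positive matrices, which is not how the LR solver optimizes them, and checks that the resulting P has marginals a and b and rank at most r.

```python
import numpy as np


def scale_to_marginals(K, row, col, n_iter=500):
    # Diagonally rescale a positive matrix K so that its row sums match `row`
    # and its column sums match `col` (requires row.sum() == col.sum()).
    u = np.ones_like(row)
    for _ in range(n_iter):
        v = col / (K.T @ u)
        u = row / (K @ v)
    return u[:, None] * K * v[None, :]


def low_rank_coupling(Q, R, g):
    # P = Q diag(1/g) R^T; dividing by g broadcasts over the r columns of Q.
    return (Q / g) @ R.T


rng = np.random.default_rng(0)
n, m, r = 24, 17, 6
a = rng.uniform(size=n)
a /= a.sum()
b = rng.uniform(size=m)
b /= b.sum()
g = np.full(r, 1.0 / r)  # uniform inner marginal, chosen for illustration
Q = scale_to_marginals(rng.uniform(0.5, 1.5, size=(n, r)), a, g)
R = scale_to_marginals(rng.uniform(0.5, 1.5, size=(m, r)), b, g)
P = low_rank_coupling(Q, R, g)
print(np.linalg.matrix_rank(P))  # at most r
```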
For comparison, also solve it with the widely used entropic GromovWasserstein solver, here with epsilon=0.05.
solver = gromov_wasserstein.GromovWasserstein(epsilon=0.05)
ot_gw = solver(prob)
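Entropic GW solvers typically alternate between linearizing the quadratic objective around the current coupling and solving an entropic linear OT problem with the resulting cost. For the squared loss this linearized cost has a closed form that avoids the four-index sum; it equals half the gradient of the GW term, and its inner product with P recovers the GW term itself. The NumPy sketch below computes it and checks it against brute force; the helper names are hypothetical, symmetric cost matrices are assumed, and we do not claim it mirrors OTT's internal code.

```python
import numpy as np


def sq_dists(u, v):
    # Pairwise squared Euclidean distances, shape (len(u), len(v)).
    return ((u[:, None, :] - v[None, :, :]) ** 2).sum(axis=-1)


def linearized_gw_cost(Cx, Cy, P):
    # L[i, j] = sum_k Cx[i,k]**2 a[k] + sum_l Cy[j,l]**2 b[l]
    #           - 2 * sum_{k,l} Cx[i,k] P[k,l] Cy[j,l],
    # where a, b are the marginals of P. Assumes Cx and Cy are symmetric.
    a, b = P.sum(axis=1), P.sum(axis=0)
    return (Cx**2 @ a)[:, None] + (Cy**2 @ b)[None, :] - 2 * Cx @ P @ Cy.T


rng = np.random.default_rng(0)
x = rng.uniform(size=(7, 2))
y = rng.uniform(size=(5, 3))
Cx, Cy = sq_dists(x, x), sq_dists(y, y)
P = rng.uniform(size=(7, 5))
P /= P.sum()  # an arbitrary coupling of its own marginals

L = linearized_gw_cost(Cx, Cy, P)  # O(n^2 m + n m^2) instead of O(n^2 m^2)
gw_fast = np.sum(L * P)

# Brute-force evaluation over all (i, j, k, l) for comparison.
diff = Cx[:, None, :, None] - Cy[None, :, None, :]
gw_brute = np.einsum("ijkl,ij,kl->", diff**2, P, P)
print(gw_fast, gw_brute)
```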
The transport matrices and final costs plotted below show that the two solvers produce quantitatively similar outputs.
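def plot_ot(ot, leg):
    # Show the coupling matrix as a heatmap.
    plt.imshow(ot.matrix, cmap="Purples")
    plt.colorbar()
    # Report the last positive entry of the per-iteration cost history.
    plt.title(leg + " cost: " + str(ot.costs[ot.costs > 0][-1]))
    plt.show()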
plot_ot(ot_gwlr, "Low rank")
plot_ot(ot_gw, "Entropic")
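The similarity can also be checked numerically rather than by eye. This hypothetical helper reports the L1 distance between two couplings and their numerical ranks; final_cost mirrors the filter used in plot_ot, assuming that unused entries of the cost history are stored as non-positive placeholders.

```python
import numpy as np


def final_cost(costs):
    # Keep the positive entries of the cost history and return the last one.
    costs = np.asarray(costs)
    return costs[costs > 0][-1]


def compare_couplings(P1, P2, rtol=1e-6):
    # Return (L1 distance, numerical rank of P1, numerical rank of P2).
    P1, P2 = np.asarray(P1), np.asarray(P2)
    l1 = float(np.abs(P1 - P2).sum())

    def num_rank(P):
        s = np.linalg.svd(P, compute_uv=False)  # singular values, descending
        return int(np.sum(s > rtol * s[0]))

    return l1, num_rank(P1), num_rank(P2)


# Synthetic demo: an independent (rank-1) coupling vs a diagonal (full-rank) one.
a = np.full(4, 0.25)
P_indep = np.outer(a, a)
P_diag = np.diag(a)
print(compare_couplings(P_indep, P_diag))
print(final_cost([0.9, 0.7, 0.65, -1.0, -1.0]))
```

In this notebook one could call compare_couplings(ot_gwlr.matrix, ot_gw.matrix) and final_cost(ot_gwlr.costs); np.asarray converts the JAX arrays. By construction, the low-rank coupling should have numerical rank at most 6, whereas the entropic coupling carries no such rank constraint.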

