Low-rank GW


This tutorial gives a minimal example of the low-rank (LR) Gromov-Wasserstein solver in action. This quadratic OT solver was introduced in [Scetbon et al., 2022] as a follow-up to the (linear) LR Sinkhorn solver of [Scetbon et al., 2021]; see Low-rank Sinkhorn.

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

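Helper function that samples three point clouds (x with n points in d1 dimensions, y with m points in d2 dimensions, and z with m points in d1 dimensions), together with random probability weights a and b for x and y.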

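def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):
    rngs = jax.random.split(rng, 5)
    x = jax.random.uniform(rngs[0], (n, d1))  # n points in d1 dimensions
    y = jax.random.uniform(rngs[1], (m, d2))  # m points in d2 dimensions
    # Random weights, normalized into probability vectors.
    a = jax.random.uniform(rngs[2], (n,))
    b = jax.random.uniform(rngs[3], (m,))
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    # m points in d1 dimensions, so that z can be compared directly with x.
    z = jax.random.uniform(rngs[4], (m, d1))
    return x, y, a, b, z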


rng = jax.random.key(0)
n, m, d1, d2 = 24, 17, 2, 3
x, y, a, b, z = create_points(rng, n, m, d1, d2)

After creating two point clouds in 2-d and 3-d, we add a third geometry, the cost between x and the auxiliary 2-d point cloud z (which has as many points as y), to formulate a fused problem [Vayer et al., 2020].

geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, z)
prob = quadratic_problem.QuadraticProblem(
    geom_xx,
    geom_yy,
    geom_xy=geom_xy,
    a=a,
    b=b,
    fused_penalty=1.0,
)
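With fused_penalty=1.0, the objective adds a linear term on geom_xy to the Gromov-Wasserstein term; as we understand OTT's convention, fused_penalty multiplies that linear term. The sketch below evaluates such an objective for a given coupling P, assuming squared-Euclidean costs (the PointCloud default) and the squared loss between intra-domain costs. The helpers sq_euclidean, fused_gw_naive and fused_gw_fast are hypothetical and not part of OTT: the first sum runs over the full n×m×n×m tensor, while the second expands the square so that only n×n, m×m and n×m matrices are formed.

```python
import jax
import jax.numpy as jnp


def sq_euclidean(x, y):
    # Hypothetical helper: pairwise squared Euclidean distances, shape (len(x), len(y)).
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def fused_gw_naive(cx, cy, cxy, p, fused_penalty):
    # Hypothetical helper: sum_{i,j,k,l} (cx[i,k] - cy[j,l])**2 * p[i,j] * p[k,l],
    # computed by materializing the full (n, m, n, m) tensor.
    diff = cx[:, None, :, None] - cy[None, :, None, :]
    quad = jnp.einsum("ijkl,ij,kl->", diff**2, p, p)
    # Linear (Wasserstein) term on the cross cost, weighted by fused_penalty.
    return quad + fused_penalty * jnp.sum(cxy * p)


def fused_gw_fast(cx, cy, cxy, p, fused_penalty):
    # Hypothetical helper: same value after expanding the square.
    p1, p2 = p.sum(axis=1), p.sum(axis=0)  # marginals of p
    quad = p1 @ (cx**2) @ p1 + p2 @ (cy**2) @ p2 - 2.0 * jnp.sum((cx @ p @ cy.T) * p)
    return quad + fused_penalty * jnp.sum(cxy * p)


n, m = 24, 17
keys = jax.random.split(jax.random.key(0), 5)
x = jax.random.uniform(keys[0], (n, 2))
y = jax.random.uniform(keys[1], (m, 3))
z = jax.random.uniform(keys[2], (m, 2))
a = jax.random.uniform(keys[3], (n,))
a = a / a.sum()
b = jax.random.uniform(keys[4], (m,))
b = b / b.sum()

p = jnp.outer(a, b)  # independent coupling, feasible for (a, b)
cx, cy, cxy = sq_euclidean(x, x), sq_euclidean(y, y), sq_euclidean(x, z)
naive = fused_gw_naive(cx, cy, cxy, p, fused_penalty=1.0)
fast = fused_gw_fast(cx, cy, cxy, p, fused_penalty=1.0)
```

The expanded form costs O(n²m + nm²) operations instead of O(n²m²), which matters once the point clouds grow beyond toy sizes.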

Solve the problem using the LRGromovWasserstein solver.

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6)
ot_gwlr = solver(prob)
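The rank=6 argument bounds the rank of the returned coupling. In the low-rank solvers of [Scetbon et al., 2021] and [Scetbon et al., 2022], the coupling is parameterized as P = Q diag(1/g) R^T, where g is a probability vector of size r, Q is an n×r coupling between a and g, and R is an m×r coupling between b and g. The sketch below is not how the solver fits these factors: it simply builds random factors of this form with Sinkhorn scalings (the helpers scale_to_marginals and low_rank_coupling are hypothetical) and checks that P has marginals a and b and rank at most r.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def scale_to_marginals(log_k, row, col, num_iters=1000):
    # Hypothetical helper: Sinkhorn scaling in log-space that turns the positive
    # matrix exp(log_k) into one with row sums `row` and column sums `col`.
    def body(_, uv):
        u, v = uv
        u = jnp.log(row) - logsumexp(log_k + v[None, :], axis=1)
        v = jnp.log(col) - logsumexp(log_k + u[:, None], axis=0)
        return u, v

    init = (jnp.zeros_like(row), jnp.zeros_like(col))
    u, v = jax.lax.fori_loop(0, num_iters, body, init)
    return jnp.exp(log_k + u[:, None] + v[None, :])


def low_rank_coupling(q, r, g):
    # Hypothetical helper: P = Q diag(1/g) R^T, of rank at most len(g).
    return (q / g[None, :]) @ r.T


n, m, rank = 24, 17, 6
keys = jax.random.split(jax.random.key(0), 5)
a = jax.random.uniform(keys[0], (n,))
a = a / a.sum()
b = jax.random.uniform(keys[1], (m,))
b = b / b.sum()
g = jax.random.uniform(keys[2], (rank,)) + 0.5  # keep inner weights away from 0
g = g / g.sum()

# Random factors: Q couples a with g, R couples b with g.
q = scale_to_marginals(0.5 * jax.random.normal(keys[3], (n, rank)), a, g)
r = scale_to_marginals(0.5 * jax.random.normal(keys[4], (m, rank)), b, g)
p = low_rank_coupling(q, r, g)
```

The row sums of P reduce to those of Q because the columns of R sum to g (and symmetrically for the column sums), which is why the same inner marginal g appears in both factors.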

For comparison, we also run the entropic GromovWasserstein solver, with epsilon=0.05.

linear_solver = sinkhorn.Sinkhorn()
solver = gromov_wasserstein.GromovWasserstein(linear_solver, epsilon=0.05)
ot_gw = solver(prob)
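As far as we know, the entropic solver alternates between linearizing the quadratic objective at the current coupling and solving an entropic linear OT problem with Sinkhorn, following Peyré, Cuturi and Solomon (2016). The sketch below implements that iterated-Sinkhorn scheme for the plain (non-fused) square-loss problem only; it is not OTT's implementation, and the helpers sq_euclidean, sinkhorn_log and entropic_gw, the iteration counts and the starting coupling a b^T are illustrative assumptions.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sq_euclidean(x, y):
    # Hypothetical helper: pairwise squared Euclidean distances.
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


@functools.partial(jax.jit, static_argnames="num_iters")
def sinkhorn_log(cost, a, b, epsilon, num_iters=2000):
    # Hypothetical helper: log-domain Sinkhorn, alternately fitting the
    # row sums to a and the column sums to b.
    def body(_, fg):
        f, g = fg
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def entropic_gw(cx, cy, a, b, epsilon, outer_iters=20):
    # Hypothetical helper: start from the independent coupling a b^T.
    p = jnp.outer(a, b)
    for _ in range(outer_iters):
        # Linearized cost at p: sum_{k,l} (cx[i,k] - cy[j,l])**2 * p[k,l].
        cost = (
            (cx**2 @ p.sum(axis=1))[:, None]
            + (cy**2 @ p.sum(axis=0))[None, :]
            - 2.0 * cx @ p @ cy.T
        )
        p = sinkhorn_log(cost, a, b, epsilon)
    return p


keys = jax.random.split(jax.random.key(0), 4)
x = jax.random.uniform(keys[0], (24, 2))
y = jax.random.uniform(keys[1], (17, 3))
a = jax.random.uniform(keys[2], (24,))
a = a / a.sum()
b = jax.random.uniform(keys[3], (17,))
b = b / b.sum()

p = entropic_gw(sq_euclidean(x, x), sq_euclidean(y, y), a, b, epsilon=0.05)
```

Working in log-space keeps the Sinkhorn updates stable for small epsilon, where the kernel exp(-cost / epsilon) would otherwise underflow in float32.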

Both solvers return couplings with similar primal costs, shown in the title of each plot.

def plot_ot(ot, leg):
    plt.imshow(ot.matrix, cmap="Purples")
    plt.colorbar()
    plt.title(f"{leg} cost: {ot.primal_cost:.4f}")
    plt.show()


plot_ot(ot_gwlr, "Low-rank")
plot_ot(ot_gw, "Entropic")
[Figures: coupling matrices of the low-rank and entropic solvers, each titled with its primal cost.]