Gromov-Wasserstein

Gromov-Wasserstein#

In this tutorial, we present the GromovWasserstein entropic solver, as presented in [Peyré et al., 2016]. The GW problem [Mémoli, 2011] proposes to match two point clouds that may lie in two different spaces, each endowed with its own geometry, using a quadratic assignment problem [Koopmans and Beckmann, 1957].

We illustrate this by aligning a 2-d to a 3-d point cloud, see also GW for Multi-omics for a more realistic application to single-cell omics.

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
from IPython import display
from matplotlib import animation

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn
from ott.solvers.quadratic import gromov_wasserstein

Matching across spaces#

We generate samples from a spiral in 2-d and a Swiss roll in 3-d.

def sample_spiral(
    n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0
):
    radius = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    data = []
    noise = jax.random.normal(key, (2, n)) * noise
    for i in range(n):
        x = (radius[i] + noise[0, i]) * jnp.cos(angles[i])
        y = (radius[i] + noise[1, i]) * jnp.sin(angles[i])
        data.append([x, y])
    data = jnp.array(data)
    return data


def sample_swiss_roll(
    n, min_radius, max_radius, length, key, min_angle=0, max_angle=10, noise=0.1
):
    spiral = sample_spiral(
        n, min_radius, max_radius, key[0], min_angle, max_angle, noise
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))
    return swiss_roll


def plot(
    swiss_roll, spiral, colormap_angles_swiss_roll, colormap_angles_spiral
):
    fig = plt.figure(figsize=(11, 5))
    ax = fig.add_subplot(1, 2, 1)
    ax.scatter(spiral[:, 0], spiral[:, 1], c=colormap_angles_spiral)
    ax.grid()
    ax = fig.add_subplot(1, 2, 2, projection="3d")
    ax.view_init(7, -80)
    ax.scatter(
        swiss_roll[:, 0],
        swiss_roll[:, 1],
        swiss_roll[:, 2],
        c=colormap_angles_swiss_roll,
    )
    ax.set_adjustable("box")
    plt.show()
# Generation parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
min_radius = 3
max_radius = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3
# Seed
rng, *subrngs = jax.random.split(jax.random.key(14), 4)
spiral = sample_spiral(
    n_spiral,
    min_radius,
    max_radius,
    key=subrngs[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    min_radius,
    max_radius,
    key=subrngs[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)

plot(swiss_roll, spiral, "blue", "green")
../../_images/ffa8a7b36232cd50a3cb5116f6cce8ab860f240a14ff297fd52d70f3a88d0827.png

We define two point clouds to describe each of these point clouds, each using (by default) the SqEuclidean cost function.

# Instantiate the Quadratic Alignment Problem
geom_xx = pointcloud.PointCloud(x=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

# Instantiate a jitted Gromov-Wasserstein solver
linear_solver = sinkhorn.Sinkhorn()
solver = jax.jit(
    gromov_wasserstein.GromovWasserstein(
        linear_solver, epsilon=100, store_inner_errors=True
    )
)
out = solver(prob)

has_converged = bool(out.linear_convergence[out.n_iters - 1])
print(f"{out.n_iters} outer iterations were needed.")
print(f"The last Sinkhorn iteration has converged: {has_converged}")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.617

The resulting transport matrix between the two point clouds can be displayed:

coupling_matrix = out.matrix
fig = plt.figure(figsize=(8, 6))
plt.imshow(coupling_matrix, cmap="Purples")
plt.xlabel(
    "IDs of samples from the Swiss roll", fontsize=14
)  # IDs are ordered from center to outer part
plt.ylabel(
    "ID of samples from the spiral", fontsize=14
)  # IDs are ordered from center to outer part
plt.colorbar()
plt.show()
../../_images/5f661a4c7d6ae3da865d654922727b1a493b2a6861ff1a006413107ec0d9d9d1.png

The larger the regularization parameter epsilon is, the more diffuse the transport matrix becomes, as we can see in the animation below.

# Animates the transport matrix
fig = plt.figure(figsize=(8, 6))
im = plt.imshow(coupling_matrix, cmap="Purples")
plt.xlabel(
    "IDs of samples from the Swiss roll", fontsize=14
)  # IDs are ordered from center to outer part
plt.ylabel(
    "IDs of samples from the spiral", fontsize=14
)  # IDs are ordered from center to outer part
plt.colorbar()


# Initialization function
def init():
    im.set_data(np.zeros(coupling_matrix.shape))
    return [im]


# Animation function
def animate(i):
    array = im.get_array()
    linear_solver = sinkhorn.Sinkhorn()
    solver = gromov_wasserstein.GromovWasserstein(
        linear_solver, epsilon=i, max_iterations=20
    )
    out = solver(prob)
    im.set_array(out.matrix)
    im.set_clim(0, jnp.max(out.matrix[:]))
    plt.title("epsilon = " + str(i))
    return [im]


# Call the animator
anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=[
        50.0,
        100.0,
        200.0,
        500.0,
        1000.0,
        2000.0,
        5000.0,
        10000.0,
        20000.0,
    ],
    interval=1500,
    blit=True,
)

html = display.HTML(anim.to_jshtml())
display.display(html)
plt.close()