# Gromov-Wasserstein

## Contents

%load_ext autoreload


# Gromov-Wasserstein

In this tutorial, we present the GromovWasserstein solver. The goal of the Gromov-Wasserstein (GW) problem is to match two point clouds taken from different spaces, each endowed with its own geometry. We illustrate this use case by aligning a 2-d point cloud with a 3-d point cloud; see also the GW for Multi-omics tutorial for a more challenging application to single-cell omics.

import sys

!pip install -q git+https://github.com/ott-jax/ott@main

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from IPython import display
from matplotlib import animation, cm

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein


## Matching across spaces

We use the GromovWasserstein solver to match a spiral in 2-d to a Swiss roll in 3-d, both generated below.

def sample_spiral(
    n,
    min_radius=3.0,
    max_radius=10.0,
    *,
    key,
    min_angle=0.0,
    max_angle=10.0,
    noise=1.0,
):
    """Sample ``n`` noisy points along a 2-d spiral.

    Both the angle and the (noise-free) radius grow linearly with the point
    index, so index 0 is the innermost point and index ``n - 1`` the outermost.

    Args:
      n: number of points to sample.
      min_radius: noise-free radius of the first point.
      max_radius: noise-free radius of the last point.
      key: ``jax.random`` PRNG key used to draw the noise.
      min_angle: angle (in radians) of the first point.
      max_angle: angle (in radians) of the last point.
      noise: scale of the Gaussian perturbation added to the radius; the x and
        y coordinates get independent perturbations.

    Returns:
      Array of shape ``(n, 2)`` holding the (x, y) coordinates.
    """
    radii = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    # Row 0 perturbs the radius used for x, row 1 the radius used for y:
    # x_i = (r_i + e0_i) cos(a_i), y_i = (r_i + e1_i) sin(a_i).
    # Named `eps` so the `noise` scale argument is not shadowed.
    eps = jax.random.normal(key, (2, n)) * noise
    x = (radii + eps[0]) * jnp.cos(angles)
    y = (radii + eps[1]) * jnp.sin(angles)
    return jnp.stack([x, y], axis=1)


def sample_swiss_roll(
    n,
    min_radius=3.0,
    max_radius=10.0,
    *,
    key,
    length,
    min_angle=0.0,
    max_angle=10.0,
    noise=0.1,
):
    """Sample ``n`` noisy points on a 3-d Swiss roll.

    A 2-d spiral from ``sample_spiral`` provides the first and third
    coordinates. The second coordinate is drawn uniformly in
    ``[0, length)``.

    Args:
      n: number of points to sample.
      min_radius: noise-free radius of the innermost point of the spiral.
      max_radius: noise-free radius of the outermost point of the spiral.
      key: indexable pair of PRNG keys. ``key[0]`` drives the spiral noise and
        ``key[1]`` the uniform second coordinate.
      length: extent of the roll along the second axis.
      min_angle: angle (in radians) of the first spiral point.
      max_angle: angle (in radians) of the last spiral point.
      noise: radial noise scale forwarded to ``sample_spiral``.

    Returns:
      Array of shape ``(n, 3)``.
    """
    spiral = sample_spiral(
        n,
        min_radius,
        max_radius,
        key=key[0],
        min_angle=min_angle,
        max_angle=max_angle,
        noise=noise,
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    # Insert the new axis between the spiral's two coordinates.
    return jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))

def plot(
    swiss_roll, spiral, colormap_angles_swiss_roll, colormap_angles_spiral
):
    """Show the 2-d spiral (left panel) next to the 3-d Swiss roll (right panel).

    Args:
      swiss_roll: 2-d array whose first three columns are drawn as x, y, z.
      spiral: 2-d array whose first two columns are drawn as x, y.
      colormap_angles_swiss_roll: ``c`` argument of the Swiss-roll scatter
        (e.g. a single colour name).
      colormap_angles_spiral: ``c`` argument of the spiral scatter.
    """
    fig = plt.figure(figsize=(11, 5))
    # Left panel: the axes must be created before scattering into them.
    ax = fig.add_subplot(1, 2, 1)
    ax.scatter(spiral[:, 0], spiral[:, 1], c=colormap_angles_spiral)
    ax.grid()
    # Right panel: 3-d axes with a low elevation so the roll is seen side-on.
    ax = fig.add_subplot(1, 2, 2, projection="3d")
    ax.view_init(7, -80)
    ax.scatter(
        swiss_roll[:, 0],
        swiss_roll[:, 1],
        swiss_roll[:, 2],
        c=colormap_angles_swiss_roll,
    )
    plt.show()

# Generation parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
min_radius = 3
max_radius = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3

# Seed: subrngs[0] is used for the spiral; subrngs[1:] (two keys) is handed
# to the Swiss-roll sampler.
rng = jax.random.PRNGKey(14)
rng, *subrngs = jax.random.split(rng, 4)

# The spiral's angular range is offset by `angle_shift` w.r.t. the Swiss roll's.
spiral = sample_spiral(
    n_spiral,
    min_radius=min_radius,
    max_radius=max_radius,
    key=subrngs[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    min_radius=min_radius,
    max_radius=max_radius,
    key=subrngs[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)

plot(swiss_roll, spiral, "blue", "green")


We describe each of the two point clouds by its own geometry, a PointCloud that uses (by default) the squared-Euclidean cost function.

# One geometry per point cloud (squared-Euclidean cost by default). GW only
# compares pairwise costs within each space, never across the two spaces.
geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)

# Instantiate the Quadratic Alignment Problem from the two geometries
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

# Instantiate a jit-compiled Gromov-Wasserstein solver
solver = jax.jit(
    gromov_wasserstein.GromovWasserstein(
        epsilon=100.0, max_iterations=20, store_inner_errors=True
    )
)
out = solver(prob)

# `linear_convergence` is indexed per outer iteration; n_iters - 1 is the last
# iteration that actually ran.
has_converged = bool(out.linear_convergence[out.n_iters - 1])
print(f"{out.n_iters} outer iterations were needed.")
print(f"The last Sinkhorn iteration has converged: {has_converged}")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")

5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.611


The resulting transport matrix between the two point clouds is as follows:

# Coupling between the spiral (rows, y-axis) and the Swiss roll (columns,
# x-axis). Sample IDs are ordered from the center to the outer part.
transport = out.matrix
fig = plt.figure(figsize=(8, 6))
plt.imshow(transport, cmap="Purples")
plt.xlabel("IDs of samples from the Swiss roll", fontsize=14)
plt.ylabel("IDs of samples from the spiral", fontsize=14)
plt.colorbar()
plt.show()


The larger the regularization parameter epsilon is, the more diffuse the transport matrix becomes, as we can see in the animation below.

# Figure for the animation: `im` is created once here; the animation
# callbacks below only swap its data.
fig = plt.figure(figsize=(8, 6))
im = plt.imshow(transport, cmap="Purples")

# Axis labels (sample IDs run from the center to the outer part).
plt.xlabel("IDs of samples from the Swiss roll", fontsize=14)
plt.ylabel("IDs of samples from the spiral", fontsize=14)
plt.colorbar()

def init():
    """Initialization callback: blank the image before the first frame.

    Returns:
      List with the single artist that changed, as blitting requires.
    """
    blank = np.zeros(transport.shape)
    im.set_data(blank)
    return [im]

def animate(i):
    """Animation callback: solve GW with ``epsilon=i`` and display the coupling.

    ``FuncAnimation`` calls this once per entry of ``frames`` and passes that
    entry (here an epsilon value) as ``i``.

    Returns:
      List with the single updated artist, as blitting requires.
    """
    geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
    geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)
    # Build the problem from these geometries instead of relying on a
    # global `prob` defined in another cell.
    problem = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

    solver = gromov_wasserstein.GromovWasserstein(epsilon=i, max_iterations=20)
    coupling = solver(problem).matrix

    im.set_array(coupling)
    # Rescale the colour map per frame: larger epsilon spreads mass and
    # lowers the peak value.
    im.set_clim(0, float(jnp.max(coupling)))
    return [im]

# Each frame is an epsilon value, shown for 1.5 s; regularization increases
# from frame to frame.
epsilons = [70.0, 100.0, 200.0, 500.0, 750.0, 1000.0, 2000.0, 10000.0, 50000.0]
anim = animation.FuncAnimation(
    fig, animate, frames=epsilons, init_func=init, interval=1500, blit=True
)

# Embed the animation as an HTML/JavaScript player, then close the figure.
html = display.HTML(anim.to_jshtml())
display.display(html)
plt.close()