Sinkhorn Barycenters#

This notebook is run on Colab using a GPU.

Import Toolboxes#

[ ]:
import numpy as np
!pip install nilearn
import nilearn
from nilearn import datasets
from nilearn import plotting
from nilearn import image
from nilearn.image import get_data
[ ]:
!pip install ott-jax
# import JAX and OTT
import jax
import jax.numpy as jnp
import ott
from ott.geometry import grid
from ott.core import discrete_barycenter
[ ]:
# misc
import matplotlib.pyplot as plt
plt.rc('font', size = 20)

Import neuroimaging data using nilearn.#

We recover a few MRI datapoints…

[ ]:
n_subjects = 4
dataset_files = datasets.fetch_oasis_vbm(n_subjects=n_subjects)
gm_imgs = np.array(dataset_files.gray_matter_maps)

… and plot their gray matter densities.

[ ]:
for i in range(n_subjects):
  plotting.plot_epi(gm_imgs[i])
  plt.show()
../_images/notebooks_Sinkhorn_Barycenters_8_0.png
../_images/notebooks_Sinkhorn_Barycenters_8_1.png
../_images/notebooks_Sinkhorn_Barycenters_8_2.png
../_images/notebooks_Sinkhorn_Barycenters_8_3.png

Represent data as histograms#

We add a small constant (1e-2) to those gray matter densities, normalize them so that each sums to 1, and check the grid size.

[ ]:
a = jnp.array(get_data(gm_imgs)).transpose((3, 0, 1, 2))
grid_size = a.shape[1:4]
a = a.reshape((n_subjects, -1)) + 1e-2
a = a / np.sum(a,axis=1)[:, np.newaxis]
print('Grid size: ', grid_size)
Grid size:  (91, 109, 91)
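
The same flatten, offset and normalize steps can be checked on a toy array. This is a minimal sketch: the random data, the tiny \(2 \times 3 \times 2\) grid, and the names `toy`, `flat` and `hist` are illustrative assumptions, not part of the notebook's data. The comment on why the offset is used is also an assumption about its purpose.

```python
import numpy as np

# Hypothetical stand-in for the MRI volumes: 4 "subjects" on a 2 x 3 x 2 grid.
rng = np.random.default_rng(0)
toy = rng.random((4, 2, 3, 2))

# Flatten each volume into one row and add the same 1e-2 offset as above.
# Assumption: the offset is there to keep every bin strictly positive.
flat = toy.reshape((toy.shape[0], -1)) + 1e-2

# Divide each row by its total mass so that it becomes a histogram.
hist = flat / flat.sum(axis=1, keepdims=True)

print("shape:", hist.shape)
print("row sums:", hist.sum(axis=1))
```

Each row of `hist` is then a probability vector over the flattened grid, which is the layout the barycenter computation below expects for `a`.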

Instantiate a grid geometry to compute \(W_p^p\)#

We instantiate the grid geometry corresponding to these data points, which live on a grid of size \(91 \times 109 \times 91\), for a total dimension \(d=902629\). Rather than stretch these voxel histograms to fit the \([0,1]^3\) hypercube, we use a simpler rescaled grid, \([0, 0.9] \times [0, 1.08] \times [0, 0.9]\), with increments of 1/100.
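
The arithmetic behind these numbers can be reproduced in a few lines of plain Python. This is only a sanity-check sketch; the names `axes`, `extents` and `d` are illustrative.

```python
import math

grid_size = (91, 109, 91)

# Total number of voxels, i.e. the dimension of each histogram.
d = math.prod(grid_size)

# One axis per dimension, with increments of 1/100 starting at 0.
axes = [[i / 100 for i in range(n)] for n in grid_size]

# First and last coordinate along each axis.
extents = [(axis[0], axis[-1]) for axis in axes]

print("d =", d)
print("extents =", extents)
```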

We endow points on that 3D grid with the Custom cost function defined below: we use a \(p\)-norm raised to the power \(p\), with \(p = 1.1\) slightly larger than 1, following previous works [1, 2] on brain signals.

We use an 𝜀 scheduler that will decrease the regularization strength from 0.1 down to 1e-4 with a decay factor of 0.95.
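
To get a feel for such a schedule, here is a plain-Python sketch of a geometric decay. It assumes that the value at iteration \(t\) is \(\max(\text{target}, \text{init} \cdot \text{decay}^t)\); the helper `epsilon_at` is hypothetical, and the scheduler in OTT may differ in its exact formula.

```python
def epsilon_at(iteration, target=1e-4, init=1e-1, decay=0.95):
    """Hypothetical helper: geometric decay from `init`, floored at `target`."""
    return max(target, init * decay ** iteration)

# Regularization strength over the first 200 iterations.
schedule = [epsilon_at(t) for t in range(200)]

# First iteration at which the schedule has reached its target value.
first_at_target = next(t for t, eps in enumerate(schedule) if eps == 1e-4)

print("start:", schedule[0], "end:", schedule[-1])
print("target reached at iteration:", first_at_target)
```

Under this assumed formula, the early iterations work with a heavily regularized problem, and only the later ones use the target value.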

[ ]:
@jax.tree_util.register_pytree_node_class
class Custom(ott.geometry.costs.CostFn):
  """Custom function."""
  def pairwise(self, x, y):
    return jnp.sum(jnp.abs(x-y) ** 1.1)

# Instantiate Grid Geometry of suitable size, epsilon parameter and cost.
g_grid = grid.Grid(x = [jnp.arange(0, n)/100 for n in grid_size],
                   cost_fns=[Custom()],
                   epsilon=ott.geometry.epsilon_scheduler.Epsilon(
                       target=1e-4, init=1e-1, decay=0.95))

Compute their regularized \(W_p^p\) iso-barycenter#

A small trick: if we jit and run the discrete_barycenter function directly with a small 𝜀, it takes a long time because it is solving a hard problem and jitting the function at the same time. It is slightly more efficient to jit it on an easy problem first, and then run the problem with the target 𝜀 we need.
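
The mechanism behind this trick can be illustrated on a toy function. The sketch below assumes that the regularization value enters the jitted function as a traced argument rather than a static one, so changing it does not trigger a new compilation. The function `soft_kernel_mass` and the counter `trace_count` are hypothetical and stand in for the real solver.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (and compiles) the function


@jax.jit
def soft_kernel_mass(x, epsilon):
    """Hypothetical stand-in for a solver parameterized by epsilon."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return jnp.sum(jnp.exp(-x / epsilon))


x = jnp.linspace(0.0, 1.0, 8)

# Warm-up call on an "easy" value: this is where tracing and compilation happen.
warmup = soft_kernel_mass(x, 1.0).block_until_ready()

# Second call with the value we care about: the compiled code is reused.
result = soft_kernel_mass(x, 1e-4).block_until_ready()

print("number of traces:", trace_count)
```

The counter only increases when JAX retraces the function, so it gives a cheap way to check whether a second call paid the compilation cost again.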

[ ]:
%%time
g_grid._epsilon.target=1
barycenter = discrete_barycenter.discrete_barycenter(g_grid, a)

CPU times: user 3.8 s, sys: 1.99 s, total: 5.78 s
Wall time: 5.27 s
[ ]:
%%time
g_grid._epsilon.target=1e-4
barycenter = discrete_barycenter.discrete_barycenter(g_grid, a)
CPU times: user 7min 16s, sys: 6min 50s, total: 14min 7s
Wall time: 14min 6s

Plot decrease of marginal error#

Computing the barycenter of \(N\) histograms involves [3] solving \(N\) OT problems that all point towards the same, but unknown, marginal. The convergence of that algorithm can be monitored by evaluating the distance between the marginals of these different transport matrices and that common marginal. Upon convergence, that distance should be close to 0.
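
To make this monitoring concrete, here is a small NumPy sketch of a generic iterative Bregman projection scheme on a 1D grid. It is an illustration under stated assumptions (toy histograms, squared-distance cost, fixed 𝜀, equal weights), and it is not meant to reproduce what discrete_barycenter does internally. All names (`bump`, `K`, `u`, `v`, `b`) are illustrative.

```python
import numpy as np

# Toy 1D setup (all values are illustrative assumptions).
n, eps = 50, 0.05
x = np.linspace(0.0, 1.0, n)
K = np.exp(-((x[:, None] - x[None, :]) ** 2) / eps)  # Gibbs kernel


def bump(center, width=0.05):
    """Hypothetical helper: normalized Gaussian bump with a small offset."""
    h = np.exp(-((x - center) ** 2) / (2 * width**2)) + 1e-2
    return h / h.sum()


a = np.stack([bump(0.25), bump(0.75)])  # N = 2 input histograms
w = np.full(len(a), 1.0 / len(a))       # equal weights (iso-barycenter)

v = np.ones_like(a)
errors = []
for _ in range(300):
    u = a / (v @ K.T)            # plan k now has a_k as its first marginal
    KTu = u @ K                  # row k holds K^T u_k
    marginals = v * KTu          # second marginal of each plan
    b = np.exp(w @ np.log(KTu))  # weighted geometric mean: common marginal
    # Total L1 gap between each plan's second marginal and the common one.
    errors.append(np.abs(marginals - b).sum())
    v = b / KTu                  # plan k now has b as its second marginal

print("first error:", errors[0])
print("last error:", errors[-1])
print("barycenter mass:", b.sum())
```

Here `errors` plays the role of the marginal error plotted below: it measures how far the \(N\) transport plans are from agreeing on one common marginal.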

[ ]:
plt.figure(figsize=(8,5))
errors = barycenter.errors[:-1]
# The x-axis is scaled by 10, the number of iterations between recorded errors.
plt.plot(np.arange(errors.size) * 10, errors, lw=3)
plt.title('Marginal error decrease in barycenter computation')
plt.yscale("log")
plt.xlabel('Iterations')
plt.ylabel('Marginal Error')
plt.show()
../_images/notebooks_Sinkhorn_Barycenters_17_0.png

Plot the barycenter itself#

[ ]:
def data_to_nii(x):
  return nilearn.image.new_img_like(
      gm_imgs[0], data=np.array(x.reshape(grid_size)))
plotting.plot_epi(data_to_nii(barycenter.histogram))
plt.show()
../_images/notebooks_Sinkhorn_Barycenters_19_0.png

Euclidean barycenter, for reference#

[ ]:
plotting.plot_epi(data_to_nii(np.mean(a,axis=0)))
plt.show()
../_images/notebooks_Sinkhorn_Barycenters_21_1.png