Wasserstein Barycenter


import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from ott.problems.linear import barycenter_problem
from ott.solvers.linear import continuous_barycenter, sinkhorn

We illustrate in this notebook how to use the FreeWassersteinBarycenter solver to compute the Wasserstein barycenter of either one or multiple probability distributions. We start by generating a few 2D point clouds of varying support size.

ns = (193, 20, 27, 5)  # number of points per cloud
offsets = (
    jnp.array((3.0, 3.0)),
    jnp.array((-3.0, 3.0)),
    jnp.array((0.0, -3.0)),
    jnp.array((0.0, 5.0)),
)  # offsets for each cloud
d = 2  # dimension of the points
k = 13  # number of points in the barycenter's support

keys = jax.random.split(jax.random.key(0), 4)

point_clouds = []
weights = []

for key, n, offset in zip(keys, ns, offsets):
    k1, k2 = jax.random.split(key)
    point_clouds.append(jax.random.normal(k1, (n, d)) + offset)
    weight = jax.random.uniform(k2, (n,))
    weight /= weight.sum()
    weights.append(weight)

flattened_points = jnp.concatenate(point_clouds, axis=0)
flattened_weights = jnp.concatenate(weights, axis=0)

A Wasserstein barycenter problem is defined by a list of (weighted) point clouds and an epsilon regularization. Here the point clouds are passed as a single flattened array, with the size of each cloud given in num_per_segment, and epsilon is set automatically by a scaling rule because we pass None.

bprob = barycenter_problem.FreeBarycenterProblem(
    y=flattened_points, b=flattened_weights, num_per_segment=ns, epsilon=None
)
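To make the flattened layout concrete, the sketch below rebuilds per-cloud segments from a flattened array using the cumulative sums of the segment sizes. It uses small stand-in arrays and hypothetical names (sizes, flat_points, segments) rather than the notebook's variables, and it only illustrates the layout; it does not claim to reproduce how the library slices its inputs internally.

```python
import jax.numpy as jnp
import numpy as np

# Stand-in data (assumption): three tiny 2D clouds with 3, 1 and 2 points.
sizes = (3, 1, 2)
clouds = [jnp.full((n, 2), float(i)) for i, n in enumerate(sizes)]
cloud_weights = [jnp.ones(n) / n for n in sizes]

# Flatten as in the notebook: concatenate along the point axis.
flat_points = jnp.concatenate(clouds, axis=0)  # shape (6, 2)
flat_weights = jnp.concatenate(cloud_weights, axis=0)  # shape (6,)

# Segment boundaries are the cumulative sums of the sizes.
ends = np.cumsum(sizes)
starts = ends - np.asarray(sizes)

segments = []
segment_masses = []
for s, e in zip(starts, ends):
    s, e = int(s), int(e)
    segments.append(flat_points[s:e])
    segment_masses.append(float(flat_weights[s:e].sum()))

for i, (seg, mass) in enumerate(zip(segments, segment_masses)):
    print(f"segment {i}: {seg.shape[0]} points, total mass {mass:.3f}")
```

Each segment keeps its own normalization, so the weights of every cloud sum to one within its slice rather than over the whole flattened array.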

Next, we instantiate the solver that will be used to compute the barycenter. We rely on the default number of iterations and tolerance, but we must set the linear solver used in the inner loop, here the Sinkhorn algorithm.

solver = continuous_barycenter.FreeWassersteinBarycenter(
    linear_solver=sinkhorn.Sinkhorn()
)
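To give some intuition for the role of the inner linear solver, here is a minimal sketch of one outer iteration of a free-support barycenter scheme in plain JAX. It assumes a squared Euclidean cost, fixed uniform barycenter weights, and a fixed epsilon. The functions sinkhorn_plan and barycenter_step are hypothetical helpers written for this illustration; they are not the library's implementation, which may differ in its updates, defaults and stopping rules.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_plan(x, a, y, b, epsilon=1.0, num_iters=200):
    """Entropic OT plan between (x, a) and (y, b), squared Euclidean cost."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, potentials):
        f, g = potentials
        # Log-domain updates of the two dual potentials.
        f = -epsilon * logsumexp(
            (g[None, :] - cost) / epsilon + log_b[None, :], axis=1
        )
        g = -epsilon * logsumexp(
            (f[:, None] - cost) / epsilon + log_a[:, None], axis=0
        )
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    log_plan = (f[:, None] + g[None, :] - cost) / epsilon
    return jnp.exp(log_plan + log_a[:, None] + log_b[None, :])


def barycenter_step(x, a, clouds, cloud_weights, lambdas, epsilon=1.0):
    """One outer update: each barycenter point moves to the weighted
    average of its barycentric projections onto the input clouds."""
    x_new = jnp.zeros_like(x)
    for y, b, lam in zip(clouds, cloud_weights, lambdas):
        plan = sinkhorn_plan(x, a, y, b, epsilon)  # inner linear OT solve
        projection = (plan @ y) / plan.sum(axis=1, keepdims=True)
        x_new = x_new + lam * projection
    return x_new


# Toy data (assumption): two clouds, barycenter supported on 4 points.
key_x, key_1, key_2 = jax.random.split(jax.random.PRNGKey(0), 3)
clouds = [
    jax.random.normal(key_1, (30, 2)) + 3.0,
    jax.random.normal(key_2, (20, 2)) - 3.0,
]
cloud_weights = [jnp.ones(30) / 30, jnp.ones(20) / 20]
lambdas = (0.5, 0.5)  # relative importance of each cloud, sums to 1

x = jax.random.normal(key_x, (4, 2))
a = jnp.ones(4) / 4
for _ in range(5):
    x = barycenter_step(x, a, clouds, cloud_weights, lambdas)
print(x.shape)
```

Every outer step calls the linear solver once per input measure, which is why the choice of inner solver has to be made explicitly.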

We first jit the solver, marking bar_size as a static argument, and then apply it to the problem above.

jitted_solver = jax.jit(solver, static_argnames="bar_size")
out = jitted_solver(bprob, bar_size=k)
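A likely reason for declaring bar_size static is that it determines the shape of the barycenter array, and array shapes must be known when JAX traces a function. The toy function below shows the same pattern. init_support is a hypothetical helper written for this illustration and is unrelated to the library's internals.

```python
import jax
import jax.numpy as jnp


def init_support(points, size):
    """Use the first `size` rows of `points` as an initial support."""
    # `size` controls the output shape, so it has to be static under jit.
    return points[:size]


jitted_init = jax.jit(init_support, static_argnames="size")

points = jnp.arange(20.0).reshape(10, 2)
small = jitted_init(points, size=3)
large = jitted_init(points, size=5)  # a new size triggers a new compilation
print(small.shape, large.shape)
```

Because each distinct static value compiles a separate version of the function, changing the support size between calls incurs a recompilation.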

The out object contains information about the barycenter itself, notably its points (out.x), whether the inner linear solvers converged, and how the cost evolved over the outer iterations.

print("Shape of barycenter : ", out.x.shape)
print(
    "Convergence of inner loop iterations :", out.all_linear_solvers_converged
)
print("Converged in:", out.num_iters, "outer iterations.")
print("Objective: ", out.costs_along_iterations)
Shape of barycenter :  (13, 2)
Convergence of inner loop iterations : True
Converged in: 5 outer iterations.
Objective:  [23.27366  15.290319 15.27664  15.271326 15.267798]
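One simple way to read the objective values is to look at the relative decrease between consecutive outer iterations. The sketch below does this for the values printed above. Whether the solver's own stopping rule uses this exact quantity is an assumption that is not checked here.

```python
import jax.numpy as jnp

# Objective values copied from the output printed above.
costs = jnp.array([23.27366, 15.290319, 15.27664, 15.271326, 15.267798])

# Relative decrease between consecutive outer iterations.
rel_decrease = (costs[:-1] - costs[1:]) / costs[:-1]
print(rel_decrease)
```

Most of the decrease happens in the first outer iteration, and later iterations only refine the support slightly.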

Visualize results

base_size = 500
for i, (weight, point_cloud) in enumerate(zip(weights, point_clouds)):
    plt.scatter(
        point_cloud[:, 0],
        point_cloud[:, 1],
        s=base_size * weight,
        label="point cloud " + str(i + 1),
    )
plt.scatter(
    out.x[:, 0],
    out.x[:, 1],
    s=base_size * out.a,
    c="black",
    marker="s",
    label="Wasserstein Barycenter",
)
plt.legend()
plt.grid(True)
plt.show()
[Figure: scatter plot of the four point clouds and their Wasserstein barycenter (black squares).]

The Wasserstein barycenter problem can also be instantiated on a single measure, which yields a uniformly weighted variant of the \(k\)-means algorithm, as described in [Cuturi and Doucet, 2014]. To do so, define a problem that contains only that measure.

bprob = barycenter_problem.FreeBarycenterProblem(
    y=point_clouds[0], b=weights[0], num_per_segment=(ns[0],)
)
out = jitted_solver(bprob, bar_size=k)
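For contrast with the uniformly weighted variant, here is a minimal sketch of one iteration of standard (weighted) Lloyd-style k-means in JAX. In this classical scheme each centroid captures whatever mass its nearest points carry, so cluster masses are free to differ. lloyd_step and the toy data are hypothetical and written only for this comparison.

```python
import jax
import jax.numpy as jnp


def lloyd_step(centroids, points, weights):
    """One weighted Lloyd iteration: hard assignment, then centroid update."""
    sq_dists = jnp.sum(
        (points[:, None, :] - centroids[None, :, :]) ** 2, axis=-1
    )
    nearest = jnp.argmin(sq_dists, axis=1)
    assign = jax.nn.one_hot(nearest, centroids.shape[0])
    masses = weights @ assign  # mass captured by each centroid
    sums = (assign * weights[:, None]).T @ points  # weighted sums per cluster
    safe_masses = jnp.maximum(masses, 1e-12)[:, None]
    # Keep a centroid in place if no point was assigned to it.
    new_centroids = jnp.where(masses[:, None] > 0, sums / safe_masses, centroids)
    return new_centroids, masses


# Toy data (assumption): 50 uniformly weighted 2D points, 3 centroids.
points = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
weights = jnp.ones(50) / 50
centroids = points[:3]

for _ in range(10):
    centroids, masses = lloyd_step(centroids, points, weights)
print("cluster masses:", masses)
```

In the single-measure barycenter problem, the barycenter weights are instead held uniform, so every support point must receive the same share of the input mass.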

When displaying the result, one sees that the barycenter is supported on far fewer points than the input cloud (13 instead of 193), and that all of its points carry the same weight.

plt.scatter(
    point_clouds[0][:, 0],
    point_clouds[0][:, 1],
    s=base_size * weights[0],
)
plt.scatter(
    out.x[:, 0],
    out.x[:, 1],
    s=base_size * out.a,
    c="black",
    marker="s",
    label="Uniform Wasserstein Barycenter",
)
plt.legend()
plt.grid(True)
plt.show()
[Figure: scatter plot of the first point cloud and its uniformly weighted Wasserstein barycenter (black squares).]