Sinkhorn in All Flavors#

We provide in this example a detailed walk-through of some of the functionalities of the Sinkhorn algorithm, including the computation of sinkhorn_divergence().

import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
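import functools
import time

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

import ott
from ott.geometry import geometry, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import acceleration, sinkhorn
from ott.tools.sinkhorn_divergence import sinkhorn_divergence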

From texts to word histograms#

We adapt a Keras NLP tutorial to preprocess raw text (here a subset of texts from the newsgroup20 database) and turn it into histograms over word embeddings; see here for detailed preprocessing.

This yields 653 histograms supported on up to 4000 words, each word being represented by a 50-dimensional vector.

# X contains 4000 word embeddings in dimension 50; HIST is a 653 x 4000 (row-normalized) matrix of histograms.
print(
    f"{HIST.shape[0]} texts supported on up to {HIST.shape[1]} words of dimension {X.shape[1]}"
)
653 texts supported on up to 4000 words of dimension 50
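To make the data layout concrete, here is a minimal sketch of how a row-normalized histogram matrix such as HIST could be built from raw word counts. The toy `counts` array and the helper name `row_normalize` are illustrative assumptions; they are not part of the original preprocessing.

```python
import numpy as np


def row_normalize(counts: np.ndarray) -> np.ndarray:
    """Turn a (n_texts, n_words) count matrix into row-stochastic histograms."""
    counts = np.asarray(counts, dtype=np.float64)
    totals = counts.sum(axis=1, keepdims=True)
    # Guard against empty texts so that we never divide by zero.
    return counts / np.where(totals > 0, totals, 1.0)


# Hypothetical toy data: 3 texts over a 5-word vocabulary.
counts = np.array([[2, 0, 1, 1, 0], [0, 3, 0, 0, 1], [1, 1, 1, 1, 1]])
toy_hist = row_normalize(counts)
print(toy_hist.sum(axis=1))  # each row sums to 1
```

Each row is then a probability vector over the vocabulary, which is the form expected for the weights a and b of an OT problem.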

Pairwise Sinkhorn divergences#

Before setting a value for epsilon, let’s get a feel for what the PointCloud of embeddings looks like in terms of distances.

geom = pointcloud.PointCloud(X)
print(
    "median:",
    geom.median_cost_matrix,
    " mean:",
    geom.mean_cost_matrix,
    " max:",
    jnp.max(geom.cost_matrix),
)
median: 0.40351653  mean: 0.41272432  max: 1.4388262
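The statistics above can be reproduced by hand on a small example. The sketch below assumes the squared Euclidean cost (the default cost of a PointCloud, as far as we know) and uses random points as a hypothetical stand-in for the embeddings X; `sq_euclidean_cost` is an illustrative helper, not a library function.

```python
import numpy as np


def sq_euclidean_cost(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between rows of x and rows of y."""
    x_norm = np.sum(x**2, axis=1)[:, None]
    y_norm = np.sum(y**2, axis=1)[None, :]
    # Clip tiny negative values caused by floating-point cancellation.
    return np.maximum(x_norm + y_norm - 2.0 * x @ y.T, 0.0)


rng = np.random.default_rng(0)
points = rng.normal(size=(6, 3))  # hypothetical stand-in for the embeddings X
toy_cost = sq_euclidean_cost(points, points)
print(
    "median:", np.median(toy_cost),
    " mean:", toy_cost.mean(),
    " max:", toy_cost.max(),
)
```

These summary statistics give the scale against which a value of epsilon should be judged.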

Store the \(4000 \times 4000\) cost matrix once and for all.

cost = geom.cost_matrix

We now define a jitted version of sinkhorn_divergence(), using a double vmap() to compute in one go the pairwise matrix of Sinkhorn divergences between two sets of histograms. Jitting is essential for efficiency: wrap your computations in jit() if you want to run at scale. Note also that we set a looser convergence threshold of \(10^{-2}\) (the default would be \(10^{-3}\)) to ensure slightly faster execution. Finally, passing cost three times is not a bug: each call to sinkhorn_divergence() triggers three regularized OT problems solved with Sinkhorn, and in principle each could have its own cost matrix. Here all three share the same matrix, since the histograms are supported on the same set of words by design.
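For reference, the three regularized OT problems combine as in the standard definition of the Sinkhorn divergence from the literature. The notation \(\mathrm{OT}_\varepsilon\) for the regularized OT cost is introduced here for illustration and does not appear in the text above.

```latex
\[
S_\varepsilon(a, b)
  = \mathrm{OT}_\varepsilon(a, b)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(a, a)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(b, b)
\]
```

The first term uses the cost between the two supports, while the two correction terms use the cost of each support against itself, hence the three cost arguments.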

sink_div = jax.jit(
    jax.vmap(
        lambda HIST_1, HIST_2, cost, epsilon: jax.vmap(
            lambda hist_1, hist_2, cost, epsilon: sinkhorn_divergence(
                geometry.Geometry,
                cost,
                cost,
                cost,
                epsilon=epsilon,
                a=hist_1,
                b=hist_2,
                sinkhorn_kwargs={"threshold": 1e-2},
            ).divergence,
            in_axes=[0, None, None, None],
        )(HIST_1, HIST_2, cost, epsilon),
        in_axes=[None, 0, None, None],
    )
)
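The nested vmap() pattern used above can be seen more easily on a toy function. The following sketch replaces the Sinkhorn divergence with a plain squared distance; `sq_dist`, `U` and `V` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def sq_dist(u, v):
    """Scalar function of two vectors, standing in for a divergence."""
    return jnp.sum((u - v) ** 2)


# Inner vmap loops over the rows of the second argument,
# outer vmap loops over the rows of the first argument.
pairwise = jax.jit(
    jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))
)

U = jnp.arange(6.0).reshape(3, 2)
V = jnp.ones((4, 2))
out = pairwise(U, V)
print(out.shape)  # (3, 4)
```

An axis set to None in in_axes is broadcast rather than mapped, which is how the shared cost matrix and epsilon are passed unchanged to every pair.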

When epsilon is set to None, the algorithm defaults to \(0.05\) times the mean cost described in the geometry. This is not a magic number, but rather a simple statistic of the scale of the problem. We recommend that you tune epsilon yourself, but using None might avoid common issues (such as running Sinkhorn with a very small epsilon while the cost matrices are large).

print("Default epsilon is: ", geom.epsilon)
Default epsilon is:  0.020636216
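The printed default can be checked directly against the mean cost reported earlier. This is only arithmetic on the two numbers shown in this example.

```python
mean_cost = 0.41272432  # mean of the cost matrix printed earlier
default_epsilon = 0.05 * mean_cost
print(f"{default_epsilon:.6f}")  # 0.020636
```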

We now compute a pairwise \(30 \times 30\) matrix of Sinkhorn divergences (900 divergences in total), between two different sets of 30 texts.

HIST_a = jnp.array(HIST[0:30])
HIST_b = jnp.array(HIST[-30:])
print(HIST_a.shape, HIST_b.shape, cost.shape)
(30, 4000) (30, 4000) (4000, 4000)

Dry run with large epsilon value to force jit() compilation before computing timings. This only makes sense within this tutorial.

DIV = sink_div(HIST_a, HIST_b, cost, 10000)

We now carry out divergence computations and plot their matrix for various epsilon.

DIV, ran_in = [], []
epsilons = [None, 5e-2, 1e-1]
for epsilon in epsilons:
    tic = time.perf_counter()
    DIV.append(sink_div(HIST_a, HIST_b, cost, epsilon).block_until_ready())
    toc = time.perf_counter()
    ran_in.append(toc - tic)

Notice how a smaller epsilon greatly increases the running time. Larger epsilon values result in less spiked values with, however, a similar relative pattern. As epsilon grows, the sinkhorn_divergence() converges to a quantity directly related to the energy distance / MMD.
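The large-epsilon limit can be illustrated without running Sinkhorn at all. The sketch below assumes that the entropic term is measured relative to the product of the marginals, so that the optimal plans tend to the independent couplings as epsilon grows; under that assumption, and with a shared squared Euclidean cost, the limit reduces to the squared distance between the two weighted means. All data here is hypothetical toy data.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(7, 3))  # hypothetical shared support
cost = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)

a = rng.random(7)
a /= a.sum()
b = rng.random(7)
b /= b.sum()

# Transport costs of the independent couplings a b^T, a a^T and b b^T.
limit_div = a @ cost @ b - 0.5 * (a @ cost @ a) - 0.5 * (b @ cost @ b)

# Squared Euclidean distance between the weighted means of a and b.
mean_gap = np.sum((a @ x - b @ x) ** 2)

print(limit_div, mean_gap)
```

Both printed numbers agree up to floating-point error, which is the MMD-type quantity for this particular cost.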

fig, axes = plt.subplots(1, 3, figsize=(12, 6))
fig.tight_layout()
axes = [axes[0], axes[1], axes[2]]
vmin = min([jnp.min(div) for div in DIV])
vmax = max([jnp.max(div) for div in DIV])

for epsilon, DIV_, ran_in_, ax_ in zip(epsilons, DIV, ran_in, axes):
    im = ax_.imshow(DIV_, vmin=vmin, vmax=vmax)
    eps = f" ({geom.epsilon:.4f})" if epsilon is None else ""
    ax_.set_title(
        r"$\varepsilon$ = " + str(epsilon) + eps + f"\n {ran_in_:.2f} s"
    )
    ax_.axis("off")

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)

plt.show()
[Figure: heatmaps of the \(30 \times 30\) Sinkhorn divergence matrices for each epsilon, with running times in the titles.]

The impact of \(\varepsilon\) on convergence#

We study in more detail how epsilon impacts the convergence of the algorithm. We now restrict our attention to the Sinkhorn algorithm (not the divergence). We define first a my_sinkhorn function to handle computations of Sinkhorn with suitable parameters for this notebook.

def my_sinkhorn(
    geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray, **kwargs
):
    prob = linear_problem.LinearProblem(geom, a=a, b=b)
    # recomputing error every iteration for plots
    # more iterations than the default setting to see full curves
    solver = sinkhorn.Sinkhorn(
        inner_iterations=1, max_iterations=10_000, **kwargs
    )
    return solver(prob)
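To see what the solver iterates on, here is a minimal log-domain Sinkhorn sketch in NumPy that records the marginal error at every iteration. It is an illustration under simplifying assumptions (strictly positive weights, dense cost matrix) and not the OTT implementation; `toy_sinkhorn` and all toy inputs are hypothetical.

```python
import numpy as np


def logsumexp(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(z))) along `axis`."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(
        np.sum(np.exp(z - z_max), axis=axis)
    )


def toy_sinkhorn(cost, a, b, epsilon, threshold=1e-3, max_iterations=10_000):
    """Return dual potentials, coupling and the error at each iteration."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    log_a, log_b = np.log(a), np.log(b)  # assumes a > 0 and b > 0
    errors = []
    for _ in range(max_iterations):
        # Alternate updates of the two dual potentials.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        coupling = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
        # Columns match b right after the g-update; track the row marginal.
        errors.append(np.abs(coupling.sum(axis=1) - a).sum())
        if errors[-1] < threshold:
            break
    return f, g, coupling, np.array(errors)


rng = np.random.default_rng(0)
toy_x, toy_y = rng.random((5, 2)), rng.random((4, 2))
toy_cost = ((toy_x[:, None, :] - toy_y[None, :, :]) ** 2).sum(-1)
toy_a, toy_b = np.full(5, 1 / 5), np.full(4, 1 / 4)

_, _, toy_coupling, toy_errors = toy_sinkhorn(
    toy_cost, toy_a, toy_b, epsilon=0.5, threshold=1e-6
)
print(len(toy_errors), toy_errors[-1])
```

Recomputing the error at every iteration, as done here and with inner_iterations=1 above, is convenient for plotting but adds overhead that one would normally avoid.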

We now select two text histograms, aiming for texts that are supported on more than 1000 words each.

ind = jnp.argsort(jnp.sum(jnp.array(HIST) > 0, axis=1))
a, b = HIST[ind[-2]], HIST[ind[-1]]
print(
    f"Histogram a supported on {jnp.sum(a >0)} words, b on {jnp.sum(b >0)} words"
)
Histogram a supported on 1121 words, b on 1162 words

We start by looking more closely into the time needed for Sinkhorn to converge for various epsilon values.

out_eps, leg_eps = [], []
epsilons = [1e-3, 0.3 * 1e-2, 1e-2, 0.3 * 1e-1, 1e-1]
ran_in = np.zeros((len(epsilons),))
for i, epsilon in enumerate(epsilons):
    tic = time.perf_counter()
    out_eps.append(my_sinkhorn(geometry.Geometry(cost, epsilon=epsilon), a, b))
    toc = time.perf_counter()
    ran_in[i] = toc - tic
    leg_eps.append(
        r"$\varepsilon$" + f"= {epsilon}, cost = {out_eps[-1].reg_ot_cost:.2f}"
    )

These execution times can then be plotted, to result in the following graph:

plt.plot(epsilons, ran_in, marker="s", markersize=10, linewidth=3)
[Figure: running time as a function of epsilon.]

We now take a closer look at the actual convergence curves of the error of the Sinkhorn algorithm (i.e. marginal error). We introduce a plot_results function to visualize this convergence (see here).

We can now look more closely into epsilon’s impact. Convergence is slower with smaller regularization: there is a trade-off between speed and how close to the original LP solution we want to be. In the absence of a strong opinion on how small the regularization should be, we advise that you start with a larger epsilon, since this makes your life substantially easier!

plot_results(
    out_eps,
    leg_eps,
    title=r"Iterations needed to converge for various $\varepsilon$",
    xlabel="iterations",
    ylabel="error",
)
[Figure: error versus iterations for various \(\varepsilon\).]

Speeding up Sinkhorn#

Fixed momentum#

[Thibault et al., 2021] proposed to use a Momentum term to (hopefully) accelerate the convergence of the Sinkhorn algorithm. This is controlled by the momentum parameter when calling Sinkhorn. We vary that parameter along with various epsilon regularization strengths. As can be seen below, a momentum parameter larger than 1.0 (also known as extrapolation or over-relaxation) helps, but can also be more unstable.
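The effect of a fixed momentum can be sketched on a toy problem. In the snippet below each new potential is a weighted combination of the previous one and the plain Sinkhorn update, with weight omega; a value of 1.0 recovers standard Sinkhorn. This is an illustrative NumPy sketch under the assumption that momentum acts on the dual potentials in this way. It is not the library code, and `sinkhorn_with_momentum` and the toy data are hypothetical.

```python
import numpy as np


def lse(z, axis):
    """Numerically stable log-sum-exp along `axis`."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.exp(z - z_max).sum(axis=axis))


def sinkhorn_with_momentum(
    cost, a, b, epsilon, omega=1.0, threshold=1e-6, max_iterations=10_000
):
    """Return (iterations used, final marginal error)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    error = np.inf
    for it in range(1, max_iterations + 1):
        f_new = epsilon * (np.log(a) - lse((g[None, :] - cost) / epsilon, axis=1))
        f = (1.0 - omega) * f + omega * f_new  # relaxed update of f
        g_new = epsilon * (np.log(b) - lse((f[:, None] - cost) / epsilon, axis=0))
        g = (1.0 - omega) * g + omega * g_new  # relaxed update of g
        coupling = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
        # With omega != 1 neither marginal is exact, so both are tracked.
        error = (
            np.abs(coupling.sum(axis=1) - a).sum()
            + np.abs(coupling.sum(axis=0) - b).sum()
        )
        if error < threshold:
            return it, error
    return max_iterations, error


rng = np.random.default_rng(0)
pts_x, pts_y = rng.random((5, 2)), rng.random((4, 2))
toy_cost = ((pts_x[:, None, :] - pts_y[None, :, :]) ** 2).sum(-1)
w_a, w_b = np.full(5, 1 / 5), np.full(4, 1 / 4)

results = {
    omega: sinkhorn_with_momentum(toy_cost, w_a, w_b, epsilon=0.5, omega=omega)
    for omega in (0.8, 1.0, 1.2)
}
for omega, (iters, err) in results.items():
    print(f"omega={omega}: {iters} iterations, error={err:.2e}")
```

On such an easy problem the differences are small; the benefit of over-relaxation is expected mostly when plain Sinkhorn converges slowly, that is, for small epsilon.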

We first compute baseline curves for three \(\varepsilon\) values:

epsilons = [1e-4, 1e-3, 1e-2]
out_baseline, leg_baseline = [], []
for epsilon in epsilons:
    out_baseline.append(
        my_sinkhorn(geometry.Geometry(cost, epsilon=epsilon), a, b)
    )
    leg_baseline.append("Baseline")

We now test momentum values lower and larger than 1. We run the computations first.

out_mom, leg_mom = [], []
for i, epsilon in enumerate(epsilons):
    out_mom.append([out_baseline[i]])  # initialize with baseline
    leg_mom.append([leg_baseline[i]])  # initialize with baseline
    for mom in [0.8, 1.05, 1.1, 1.3]:
        out_mom[i].append(
            my_sinkhorn(
                geometry.Geometry(cost, epsilon=epsilon),
                a,
                b,
                momentum=acceleration.Momentum(value=mom),
            )
        )
        leg_mom[i].append(f"Momentum : {mom}")

Plot them next.

for i, epsilon in enumerate(epsilons):
    plot_results(
        out_mom[i],
        leg_mom[i],
        title=r"Fixed Momentum, $\varepsilon$=" + str(epsilon),
        xlabel="iterations",
        ylabel="error",
    )
[Figures: error versus iterations with fixed momentum, one plot per \(\varepsilon\).]

You might have noticed in the first set of curves that the values for momentum \(1.1\) and \(1.3\) are not displayed. For that small \(\varepsilon=0.0001\), the error has diverged from the first update.

[out_mom[0][3].errors], [
    out_mom[0][4].errors
]  # Computation diverges from first iteration for small epsilon, high momentum.
([DeviceArray([inf, -1., -1., ..., -1., -1., -1.], dtype=float32)],
 [DeviceArray([inf, -1., -1., ..., -1., -1., -1.], dtype=float32)])
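In the output above, the entries equal to -1 appear to be placeholders for error slots that were never computed, because the run stopped early. Under that assumption, a small helper can keep only the errors that were actually recorded before plotting or inspecting them; `computed_errors` is a hypothetical name, not a library function.

```python
import numpy as np


def computed_errors(errors) -> np.ndarray:
    """Drop the -1 placeholders and keep only the recorded errors."""
    errors = np.asarray(errors)
    return errors[errors > -1]


# Hypothetical error vectors mimicking a converged and a diverged run.
converged_run = np.array([0.5, 0.1, 0.01, -1.0, -1.0])
diverged_run = np.array([np.inf, -1.0, -1.0, -1.0, -1.0])

print(computed_errors(converged_run))
print(np.isinf(computed_errors(diverged_run)).all())
```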

Adaptive momentum#

[Lehmann et al., 2021] propose a simple rule to update the Momentum term adaptively, after a few Sinkhorn iterations, by tracking the convergence of the algorithm to compute a momentum parameter. We test this approach with various epsilon values.
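One way such a rule can work is to estimate the contraction rate of the iterations from two successive errors and plug it into the classical optimal over-relaxation formula. The sketch below follows that idea; whether the library uses exactly this expression, including the cap at 0.99, is an assumption, and `lehmann_momentum` is a hypothetical helper.

```python
import math


def lehmann_momentum(
    prev_error: float, last_error: float, inner_iterations: int = 1
) -> float:
    """Momentum weight estimated from the ratio of two successive errors."""
    # Estimated contraction over `inner_iterations` steps, capped below 1.
    ratio = min(last_error / prev_error, 0.99)
    rate = ratio ** (1.0 / inner_iterations)  # per-iteration rate
    return 2.0 / (1.0 + math.sqrt(1.0 - rate))


fast = lehmann_momentum(prev_error=1.0, last_error=0.5)
slow = lehmann_momentum(prev_error=1.0, last_error=0.99)
print(fast, slow)
```

The slower the observed convergence (ratio close to 1), the closer the weight gets to 2; fast convergence yields a weight close to 1, that is, almost no extrapolation.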

out_chg_mom, leg_chg_mom = [], []
for i, epsilon in enumerate(epsilons):
    out_chg_mom.append([out_baseline[i]])
    leg_chg_mom.append([leg_baseline[i]])
    for chg_momentum_from in [10, 20, 50, 200, 1000]:
        out_chg_mom[i].append(
            my_sinkhorn(
                geometry.Geometry(cost, epsilon=epsilon),
                a,
                b,
                momentum=acceleration.Momentum(start=chg_momentum_from),
            )
        )
        leg_chg_mom[i].append(f"Change after {chg_momentum_from} it.")

As can be seen in the curves below, this seems to be a very effective and robust way to speed up the algorithm.

for i, epsilon in enumerate(epsilons):
    plot_results(
        out_chg_mom[i],
        leg_chg_mom[i],
        title=r"Adaptive Momentum, $\varepsilon$=" + str(epsilon),
        xlabel="iterations",
        ylabel="error",
    )
[Figures: error versus iterations with adaptive momentum, one plot per \(\varepsilon\).]

\(\varepsilon\) decay#

It is also possible to use so-called Epsilon decay, which consists of starting the Sinkhorn iterations with a large \(\varepsilon\) value that is progressively decreased using a multiplicative update.
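A multiplicative schedule of this kind can be sketched in a few lines. The snippet assumes that the regularization starts at `init`, is multiplied by `decay` at every iteration and is never allowed to drop below the target value; the exact schedule used by the library may differ, and `epsilon_schedule` is a hypothetical helper.

```python
def epsilon_schedule(target: float, init: float, decay: float, num_iterations: int):
    """Regularization used at each iteration: init * decay**t, floored at target."""
    return [max(target, init * decay**t) for t in range(num_iterations)]


schedule = epsilon_schedule(target=1e-2, init=5 * 1e-2, decay=0.8, num_iterations=10)
for t, eps in enumerate(schedule):
    print(t, eps)
```

With a decay of 0.8 and an initial value of five times the target, the target is reached after eight iterations; a decay closer to 1 keeps the problem heavily regularized for longer.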

out_scaling, leg_scaling = [], []
for i, epsilon in enumerate(epsilons):
    out_scaling.append([out_baseline[i]])
    leg_scaling.append([leg_baseline[i]])
    for decay in [0.8, 0.95]:
        for init in [5, 50, 100]:
            out_scaling[i].append(
                my_sinkhorn(
                    geometry.Geometry(
                        cost, epsilon=epsilon, init=init * epsilon, decay=decay
                    ),
                    a,
                    b,
                )
            )
            leg_scaling[i].append(
                rf"Decay: {decay}, Init: {init} $\varepsilon$"
            )
for i, epsilon in enumerate(epsilons):
    plot_results(
        out_scaling[i],
        leg_scaling[i],
        title=rf"Decay, $\varepsilon$=" + str(epsilon),
        xlabel="iterations",
        ylabel="error",
    )
[Figures: error versus iterations with \(\varepsilon\) decay, one plot per \(\varepsilon\).]

Anderson acceleration#

Using AndersonAcceleration on the Sinkhorn algorithm provides mixed results, worsening performance for smaller epsilon regularization, and slightly improving it as the regularization gets larger.

out_anderson, leg_anderson = [], []

for i, epsilon in enumerate(epsilons):
    out_anderson.append([out_baseline[i]])
    leg_anderson.append([leg_baseline[i]])
    for memory in [3, 5, 8, 15]:
        out_anderson[i].append(
            my_sinkhorn(
                geometry.Geometry(cost, epsilon=epsilon),
                a,
                b,
                anderson=acceleration.AndersonAcceleration(memory=memory),
            )
        )
        leg_anderson[i].append(f"Anderson Acceleration: {memory}")
for i, epsilon in enumerate(epsilons):
    plot_results(
        out_anderson[i],
        leg_anderson[i],
        title=r"Anderson Acceleration, $\varepsilon$=" + str(epsilon),
        xlabel="iterations",
        ylabel="error",
    )
[Figures: error versus iterations with Anderson acceleration, one plot per \(\varepsilon\).]

Decay and momentum#

An interesting direction to accelerate convergence is to update the Momentum after the Epsilon decay schedule has converged.

out_mixed, leg_mixed = [], []
for i, epsilon in enumerate(epsilons):
    out_mixed.append([out_baseline[i]])
    leg_mixed.append([leg_baseline[i]])
    for decay, init, chg_momentum_from in [
        [0.5, 10, 10],
        [0.7, 5, 20],
        [0.9, 10, 50],
        [0.99, 2, 100],
    ]:
        out_mixed[i].append(
            my_sinkhorn(
                geometry.Geometry(
                    cost, epsilon=epsilon, init=init * epsilon, decay=decay
                ),
                a,
                b,
                momentum=acceleration.Momentum(start=chg_momentum_from),
            )
        )
        leg_mixed[i].append(
            rf"Decay: {decay}, Init: {init}$\varepsilon$, ChgMom: {chg_momentum_from}"
        )
for i, epsilon in enumerate(epsilons):
    plot_results(
        out_mixed[i],
        leg_mixed[i],
        title=rf"Mixed strategy, $\varepsilon$=" + str(epsilon),
        xlabel="iterations",
        ylabel="error",
    )
[Figures: error versus iterations with the mixed decay and momentum strategy, one plot per \(\varepsilon\).]