ICNN Dual Solver#

In this tutorial, we show how to learn the solution of the Kantorovich dual by parameterizing the two dual potentials \(f\) and \(g\) with input convex neural networks (ICNNs), a method developed by Makkuva et al. (2020). For more insight into the approach itself, we refer the reader to the original publication. Given dataloaders that sample from the source and the target distribution, OTT's NeuralDualSolver finds a pair of optimal potentials \(f\) and \(g\) that solve the corresponding dual of the optimal transport problem. Once a solution has been found, it can be used to transport unseen source samples to the target distribution (or vice versa), or to compute the corresponding distance between new source and target samples.
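
As a brief reminder (written with a generic convex potential \(\varphi\) so as not to clash with the \(f\)/\(g\) naming used below, and glossing over the exact sign conventions of the implementation), the Kantorovich dual for the squared-Euclidean cost admits a semi-dual form over convex functions, where \(\varphi^*\) denotes the convex conjugate and, by Brenier's theorem, \(\nabla\varphi\) is the optimal map pushing the source onto the target. Makkuva et al. (2020) approximate both \(\varphi\) and the argmax defining \(\varphi^*\) with ICNNs:

\[
\tfrac{1}{2} W_2^2(P, Q)
= \mathbb{E}_{X \sim P}\big[\tfrac{1}{2}\lVert X\rVert^2\big]
+ \mathbb{E}_{Y \sim Q}\big[\tfrac{1}{2}\lVert Y\rVert^2\big]
- \inf_{\varphi \,\text{convex}}\Big\{\mathbb{E}_{X \sim P}[\varphi(X)] + \mathbb{E}_{Y \sim Q}[\varphi^*(Y)]\Big\},
\qquad
\varphi^*(y) = \sup_{x}\,\langle x, y\rangle - \varphi(x).
\]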

[1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
import matplotlib.pyplot as plt
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from ott.geometry import pointcloud
from ott.core.neuraldual import NeuralDualSolver
from ott.core import icnn

Helper Functions#

Let us define some helper functions which we use for the subsequent analysis.

[2]:
def plot_ot_map(neural_dual, source, target, inverse=False):
    """Plot data and learned optimal transport map."""

    def draw_arrows(a, b):
        plt.arrow(a[0], a[1], b[0] - a[0], b[1] - a[1],
                  color=[0.5, 0.5, 1], alpha=0.3)

    if not inverse:
        grad_state_s = neural_dual.transport(source)
    else:
        grad_state_s = neural_dual.inverse_transport(source)

    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.scatter(target[:, 0], target[:, 1], color='#A7BED3',
               alpha=0.5, label=r'$target$')
    ax.scatter(source[:, 0], source[:, 1], color='#1A254B',
               alpha=0.5, label=r'$source$')
    if not inverse:
        ax.scatter(grad_state_s[:, 0], grad_state_s[:, 1], color='#F2545B',
                   alpha=0.5, label=r'$\nabla g(source)$')
    else:
        ax.scatter(grad_state_s[:, 0], grad_state_s[:, 1], color='#F2545B',
                   alpha=0.5, label=r'$\nabla f(target)$')

    plt.legend()

    for i in range(source.shape[0]):
        draw_arrows(source[i, :], grad_state_s[i, :])
[3]:
def get_optimizer(optimizer, lr, b1, b2, eps):
    """Returns an optax optimizer object for the given hyperparameters."""

    if optimizer == 'Adam':
        optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2, eps=eps)
    elif optimizer == 'SGD':
        optimizer = optax.sgd(learning_rate=lr, momentum=None, nesterov=False)
    else:
        raise NotImplementedError(
            f'Optimizer {optimizer} not supported yet!')

    return optimizer
[4]:
@jax.jit
def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):
    """Computes the Sinkhorn divergence between uniformly weighted clouds x and y."""
    # uniform weights over the two point clouds
    a = jnp.ones(len(x)) / len(x)
    b = jnp.ones(len(y)) / len(y)

    sdiv = sinkhorn_divergence(pointcloud.PointCloud, x, y, power=power,
                               epsilon=epsilon, a=a, b=b)
    return sdiv.divergence
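
As a quick, purely illustrative sanity check (not part of the original notebook), the divergence between a random point cloud and a slightly perturbed copy of itself should be close to zero:

[ ]:
# Illustrative only: compare a random cloud with a mildly perturbed copy.
rng = jax.random.PRNGKey(0)
pts = jax.random.normal(rng, (128, 2))
print(sinkhorn_loss(pts, pts + 1e-2))  # expected to be close to 0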

Setup Training and Validation Datasets#

We apply the NeuralDual to compute the transport between toy datasets. In this tutorial, you can choose between the datasets simple (data clustered around one center), circle (two-dimensional Gaussians arranged on a circle), square_five (two-dimensional Gaussians on the corners of a square, with one additional Gaussian in the center), and square_four (two-dimensional Gaussians on the corners of a square).

[5]:
class ToyDataset(IterableDataset):
    def __init__(self, name):
        self.name = name

    def __iter__(self):
        return self.create_sample_generators()

    def create_sample_generators(self, scale=5.0, variance=0.5):
        # given name of dataset, select centers
        if self.name == "simple":
            centers = np.array([0, 0])

        elif self.name == "circle":
            centers = np.array(
                [
                    (1, 0),
                    (-1, 0),
                    (0, 1),
                    (0, -1),
                    (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
                    (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
                    (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
                    (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
                ]
            )

        elif self.name == "square_five":
            centers = np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]])

        elif self.name == "square_four":
            centers = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])

        else:
            raise NotImplementedError()

        # create generator which randomly picks center and adds noise
        centers = scale * centers
        while True:
            center = centers[np.random.choice(len(centers))]
            point = center + variance**2 * np.random.randn(2)

            yield point


def load_toy_data(name_source: str,
                  name_target: str,
                  batch_size: int = 1024,
                  valid_batch_size: int = 1024):
    dataloaders = (
      iter(DataLoader(ToyDataset(name_source), batch_size=batch_size)),
      iter(DataLoader(ToyDataset(name_target), batch_size=batch_size)),
      iter(DataLoader(ToyDataset(name_source), batch_size=valid_batch_size)),
      iter(DataLoader(ToyDataset(name_target), batch_size=valid_batch_size)),
    )
    input_dim = 2
    return dataloaders, input_dim

Solve Neural Dual#

In order to solve the neural dual, we need to define our dataloaders. The only requirement is that the source and target dataloaders, for both training and validation, are iterators that yield batches of samples.

[6]:
(dataloader_source, dataloader_target, _, _), input_dim = load_toy_data('simple', 'circle')
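
The torch DataLoader above is just one convenient way of producing such iterators. Since, as stated above, only iterators over batches are required, a torch-free NumPy generator along the lines of the following sketch should work equally well (illustrative only; numpy_loader is not part of OTT):

[ ]:
def numpy_loader(dataset_name, batch_size=1024):
    """Yields batches of shape (batch_size, 2) drawn from a ToyDataset."""
    sample_generator = iter(ToyDataset(dataset_name))
    while True:
        yield np.stack([next(sample_generator) for _ in range(batch_size)])

# e.g. numpy_source = numpy_loader('simple'); next(numpy_source).shape == (1024, 2)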

Next, we define the architectures parameterizing the dual potentials \(f\) and \(g\), which need to be ICNNs. You can adapt the size of the ICNNs by passing a sequence of hidden layer sizes. While ICNNs contain partially positive weights by default, the NeuralDual can also be solved using approximations of this positivity constraint (via weight clipping and a weight penalization); for this, set the positive weights option to True in both the ICNN architecture and the NeuralDualSolver configuration. For more details on how to customize the ICNN architectures, we refer you to the documentation.

[7]:
# initialize models
neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64])
neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64])

# initialize optimizers
optimizer_f = get_optimizer('Adam', lr=0.0001, b1=0.5, b2=0.9, eps=0.00000001)
optimizer_g = get_optimizer('Adam', lr=0.0001, b1=0.5, b2=0.9, eps=0.00000001)
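
Since the ICNNs are Flax modules, their parameters can be initialized and inspected directly before handing the models to the solver. The following is a purely illustrative sketch (the solver performs its own initialization internally):

[ ]:
# Purely illustrative: evaluate the untrained potential f on a dummy batch
# and print the shapes of its parameters.
rng = jax.random.PRNGKey(0)
dummy_batch = jnp.ones((8, 2))
params_f = neural_f.init(rng, dummy_batch)
potential_values = neural_f.apply(params_f, dummy_batch)
print(jax.tree_util.tree_map(jnp.shape, params_f))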

We then initialize the NeuralDualSolver by passing the two ICNN models parameterizing \(f\) and \(g\), as well as the input dimension of the data and the number of training iterations to execute. Once the NeuralDualSolver is initialized, we obtain the NeuralDual by passing the corresponding dataloaders to it, which returns the optimal NeuralDual for the problem. As our training and validation datasets do not differ here, we pass (dataloader_source, dataloader_target) for both the training and validation steps. For more details on how to configure the NeuralDualSolver, we refer you to the documentation.

[8]:
neural_dual_solver = NeuralDualSolver(
    input_dim, neural_f, neural_g, optimizer_f, optimizer_g, num_train_iters=5000)
neural_dual = neural_dual_solver(
    dataloader_source, dataloader_target, dataloader_source, dataloader_target)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
100%|███████████████████████████████████████| 5000/5000 [36:32<00:00,  2.28it/s]

Evaluate Neural Dual#

After training has completed successfully, we can evaluate the NeuralDual on unseen incoming data. We first sample a new batch from the source and target distribution.

[9]:
data_source = next(dataloader_source).numpy()
data_target = next(dataloader_target).numpy()

Now we can plot the corresponding transport from source to target using the gradient of the learned potential NeuralDual.g, i.e., \(\nabla g(\text{source})\), or from target to source via the gradient of the learned potential NeuralDual.f, i.e., \(\nabla f(\text{target})\).

[10]:
plot_ot_map(neural_dual, data_source, data_target, inverse=False)
../_images/notebooks_neural_dual_22_0.png
[11]:
plot_ot_map(neural_dual, data_target, data_source, inverse=True)
../_images/notebooks_neural_dual_23_0.png

We further test how close the predicted samples are to the sampled data.

First for potential \(g\), transporting source to target samples. Ideally the resulting Sinkhorn distance is close to 0.

[12]:
pred_target = neural_dual.transport(data_source)
print(f'Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_target, data_target)}')
Sinkhorn distance between predictions and data samples: 1.6507648229599

Then for potential \(f\), transporting target to source samples. Again, the resulting Sinkhorn distance should be close to 0.

[13]:
pred_source = neural_dual.inverse_transport(data_target)
print(f'Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_source, data_source)}')
Sinkhorn distance between predictions and data samples: 0.07880353927612305

Besides computing the transport and mapping source to target samples or vice versa, we can also compute the overall distance between new source and target samples.

[14]:
neural_dual_dist = neural_dual.distance(data_source, data_target)
print(f'Neural dual distance between source and target data: {neural_dual_dist}')
Neural dual distance between source and target data: 22.147186279296875

This can be compared to the primal Sinkhorn distance between the two samples:

[15]:
sinkhorn_dist = sinkhorn_loss(data_source, data_target)
print(f'Sinkhorn distance between source and target data: {sinkhorn_dist}')
Sinkhorn distance between source and target data: 22.19419288635254
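
The two estimates should roughly agree; as a small illustrative addition (not part of the original notebook), their relative difference can be printed directly:

[ ]:
# Illustrative: relative difference between the neural dual and Sinkhorn estimates.
rel_diff = abs(neural_dual_dist - sinkhorn_dist) / sinkhorn_dist
print(f'Relative difference between the two estimates: {rel_diff}')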