{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ICNN Dual Solver" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we explore how to learn the solution of the Kantorovich dual by parameterizing the two dual potentials $f$ and $g$ with two input convex neural networks (ICNNs), a method developed by Makkuva et al. For more insights into the approach itself, we refer the reader to the original publication.\n", "Given dataloaders containing samples from the *source* and *target* distributions, OTT's `NeuralDualSolver` finds the pair of optimal potentials $f$ and $g$ that solve the dual of the optimal transport problem. Once a solution has been found, it can be used to transport unseen source samples to the target distribution (or vice versa), or to compute the corresponding distance between new source and target samples." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " !pip install -q git+https://github.com/ott-jax/ott@main" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import optax\n", "import matplotlib.pyplot as plt\n", "from torch.utils.data import IterableDataset\n", "from torch.utils.data import DataLoader\n", "from ott.tools.sinkhorn_divergence import sinkhorn_divergence\n", "from ott.geometry import pointcloud\n", "from ott.core.neuraldual import NeuralDualSolver\n", "from ott.core import icnn" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Helper Functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us define some helper functions that we use in the subsequent analysis." 
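, "\n", "\n", "For reference, `sinkhorn_loss` is meant to evaluate the usual (debiased) Sinkhorn divergence between the uniform empirical measures $\\alpha$ and $\\beta$ supported on `x` and `y`, with the squared Euclidean cost (`power=2.0`) and regularization strength $\\varepsilon$ (`epsilon=0.1`). Assuming OTT's `sinkhorn_divergence` follows this standard definition, it reads\n", "\n", "$$\n", "S_\\varepsilon(\\alpha, \\beta) = \\mathrm{OT}_\\varepsilon(\\alpha, \\beta) - \\frac{1}{2}\\mathrm{OT}_\\varepsilon(\\alpha, \\alpha) - \\frac{1}{2}\\mathrm{OT}_\\varepsilon(\\beta, \\beta),\n", "$$\n", "\n", "where $\\mathrm{OT}_\\varepsilon$ denotes the entropy-regularized transport cost. The two self-transport terms remove the entropic bias, so that $S_\\varepsilon(\\alpha, \\alpha) = 0$."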
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def plot_ot_map(neural_dual, source, target, inverse=False):\n", " \"\"\"Plot data and learned optimal transport map.\"\"\"\n", "\n", " def draw_arrows(a, b):\n", " plt.arrow(\n", " a[0], a[1], b[0] - a[0], b[1] - a[1], color=[0.5, 0.5, 1], alpha=0.3\n", " )\n", "\n", " if not inverse:\n", " grad_state_s = neural_dual.transport(source)\n", " else:\n", " grad_state_s = neural_dual.inverse_transport(source)\n", "\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", "\n", " if not inverse:\n", " ax.scatter(\n", " target[:, 0],\n", " target[:, 1],\n", " color=\"#A7BED3\",\n", " alpha=0.5,\n", " label=r\"$target$\",\n", " )\n", " ax.scatter(\n", " source[:, 0],\n", " source[:, 1],\n", " color=\"#1A254B\",\n", " alpha=0.5,\n", " label=r\"$source$\",\n", " )\n", " ax.scatter(\n", " grad_state_s[:, 0],\n", " grad_state_s[:, 1],\n", " color=\"#F2545B\",\n", " alpha=0.5,\n", " label=r\"$\\nabla g(source)$\",\n", " )\n", " else:\n", " ax.scatter(\n", " target[:, 0],\n", " target[:, 1],\n", " color=\"#A7BED3\",\n", " alpha=0.5,\n", " label=r\"$source$\",\n", " )\n", " ax.scatter(\n", " source[:, 0],\n", " source[:, 1],\n", " color=\"#1A254B\",\n", " alpha=0.5,\n", " label=r\"$target$\",\n", " )\n", " ax.scatter(\n", " grad_state_s[:, 0],\n", " grad_state_s[:, 1],\n", " color=\"#F2545B\",\n", " alpha=0.5,\n", " label=r\"$\\nabla f(target)$\",\n", " )\n", "\n", " plt.legend()\n", "\n", " for i in range(source.shape[0]):\n", " draw_arrows(source[i, :], grad_state_s[i, :])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def get_optimizer(optimizer, lr, b1, b2, eps):\n", " \"\"\"Returns an optax optimizer based on the given config.\"\"\"\n", "\n", " if optimizer == \"Adam\":\n", " optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2, eps=eps)\n", " elif optimizer == \"SGD\":\n", " optimizer = optax.sgd(learning_rate=lr, momentum=None, 
nesterov=False)\n", " else:\n", " raise NotImplementedError(f\"Optimizer {optimizer} not supported yet!\")\n", "\n", " return optimizer" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):\n", " \"\"\"Computes the Sinkhorn divergence between uniform point clouds x and y.\"\"\"\n", " a = jnp.ones(len(x)) / len(x)\n", " b = jnp.ones(len(y)) / len(y)\n", "\n", " sdiv = sinkhorn_divergence(\n", " pointcloud.PointCloud, x, y, power=power, epsilon=epsilon, a=a, b=b\n", " )\n", " return sdiv.divergence" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set Up Training and Validation Datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We apply the neural dual to compute the transport between toy datasets. In this tutorial, you can choose between the following datasets: `simple` (a single Gaussian centered at the origin), `circle` (eight two-dimensional Gaussians arranged on a circle), `square_five` (five two-dimensional Gaussians, one at each corner of a square and one in the center), and `square_four` (four two-dimensional Gaussians on the coordinate axes, i.e. at the corners of a square rotated by 45 degrees)." 
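, "\n", "\n", "Concretely, each toy dataset is a uniform mixture of isotropic two-dimensional Gaussians. Writing $c_1, \\dots, c_K$ for the centers of the chosen dataset, $s$ for `scale` and $\\sigma$ for the noise standard deviation, the generator in `ToyDataset` draws\n", "\n", "$$\n", "k \\sim \\mathrm{Uniform}\\{1, \\dots, K\\}, \\qquad \\xi \\sim \\mathcal{N}(0, I_2), \\qquad x = s\\, c_k + \\sigma\\, \\xi.\n", "$$\n", "\n", "Note that the code sets $\\sigma$ to `variance**2`, so with the defaults `scale=5.0` and `variance=0.5` we get $\\sigma = 0.25$, and every center other than the origin lies at distance $5$ from it, except the corners of `square_five`, which lie at distance $5\\sqrt{2}$."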
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class ToyDataset(IterableDataset):\n", " def __init__(self, name):\n", " self.name = name\n", "\n", " def __iter__(self):\n", " return self.create_sample_generators()\n", "\n", " def create_sample_generators(self, scale=5.0, variance=0.5):\n", " # given name of dataset, select centers\n", " if self.name == \"simple\":\n", " centers = np.array([[0, 0]])\n", "\n", " elif self.name == \"circle\":\n", " centers = np.array(\n", " [\n", " (1, 0),\n", " (-1, 0),\n", " (0, 1),\n", " (0, -1),\n", " (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", " (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", " (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", " (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", " ]\n", " )\n", "\n", " elif self.name == \"square_five\":\n", " centers = np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]])\n", "\n", " elif self.name == \"square_four\":\n", " centers = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])\n", "\n", " else:\n", " raise NotImplementedError(f\"Dataset {self.name} not supported!\")\n", "\n", " # create generator which randomly picks a center and adds Gaussian noise;\n", " # note that the noise standard deviation is variance**2\n", " centers = scale * centers\n", " while True:\n", " center = centers[np.random.choice(len(centers))]\n", " point = center + variance**2 * np.random.randn(2)\n", "\n", " yield point\n", "\n", "\n", "def load_toy_data(\n", " name_source: str,\n", " name_target: str,\n", " batch_size: int = 1024,\n", " valid_batch_size: int = 1024,\n", "):\n", " \"\"\"Returns source/target train and validation iterators and the data dimension.\"\"\"\n", " dataloaders = (\n", " iter(DataLoader(ToyDataset(name_source), batch_size=batch_size)),\n", " iter(DataLoader(ToyDataset(name_target), batch_size=batch_size)),\n", " iter(DataLoader(ToyDataset(name_source), batch_size=valid_batch_size)),\n", " iter(DataLoader(ToyDataset(name_target), batch_size=valid_batch_size)),\n", " )\n", " input_dim = 2\n", " return dataloaders, input_dim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solve Neural Dual" ] }, { "cell_type": "markdown", "metadata": {}, "source": 
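[ "Before setting up the solver, it may help to recall the objective. For the quadratic cost, Brenier's theorem states that (under mild regularity assumptions on the source) the optimal map from the source $\\mu$ to the target $\\nu$ is the gradient of a convex potential. Writing $W_2^2$ for the squared 2-Wasserstein distance, a standard reformulation of the dual, sketched here in this notebook's convention ($\\nabla g$ maps source to target and $\\nabla f$ maps target to source, as in the labels of `plot_ot_map`), is\n", "\n", "$$\n", "\\frac{1}{2} W_2^2(\\mu, \\nu) = \\frac{1}{2}\\mathbb{E}_{\\mu}\\|X\\|^2 + \\frac{1}{2}\\mathbb{E}_{\\nu}\\|Y\\|^2 - \\min_{g \\text{ convex}} \\, \\max_{f \\text{ convex}} \\Big( \\mathbb{E}_{\\mu}\\big[g(X)\\big] + \\mathbb{E}_{\\nu}\\big[\\langle Y, \\nabla f(Y) \\rangle - g(\\nabla f(Y))\\big] \\Big).\n", "$$\n", "\n", "The inner maximization over $f$ stands in for the convex conjugate $g^*(y) = \\sup_x \\, \\langle x, y \\rangle - g(x)$, since $\\langle y, \\nabla f(y) \\rangle - g(\\nabla f(y)) \\le g^*(y)$ for any $f$, with equality when $\\nabla f(y)$ attains the supremum. This is a sketch of the minimax idea behind the approach of Makkuva et al.; the exact loss implemented by `NeuralDualSolver` may differ in signs, constants, or regularization terms." ] }, { "cell_type": "markdown", "metadata": {}, "source": 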
[ "In order to solve the neural dual, we need to define our dataloaders. The only requirement is that the source and target training and validation datasets are *iterators*." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "(dataloader_source, dataloader_target, _, _), input_dim = load_toy_data(\n", " \"square_five\", \"square_four\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define the architectures parameterizing the dual potentials $f$ and $g$, both of which need to be ICNNs. You can adapt the size of each ICNN by passing a sequence of hidden layer sizes (`dim_hidden`). While, by default, part of the ICNN weights are constrained to be positive (which, together with convex non-decreasing activations, makes the network convex in its input), we can also solve the neural dual using approximations of this positivity constraint (via weight clipping and a weight penalty). To do so, set the positive-weights option to `True` in both the ICNN architecture and the `NeuralDualSolver` configuration. For more details on how to customize the ICNN architectures, we refer you to the documentation." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# initialize models\n", "neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n", "neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n", "\n", "# initialize optimizers\n", "optimizer_f = get_optimizer(\"Adam\", lr=0.001, b1=0.5, b2=0.9, eps=1e-8)\n", "optimizer_g = get_optimizer(\"Adam\", lr=0.001, b1=0.5, b2=0.9, eps=1e-8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then initialize the `NeuralDualSolver` by passing the two ICNN models parameterizing $f$ and $g$, as well as by specifying the input dimension of the data and the number of training iterations. Once the `NeuralDualSolver` is initialized, we obtain the NeuralDual by passing the corresponding dataloaders to it, which trains the two potentials and returns the resulting NeuralDual for the problem. 
Since our training and validation datasets do not differ here, we pass (dataloader_source, dataloader_target) for both the training and validation steps. For more details on how to configure the `NeuralDualSolver`, we refer you to the documentation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Executing the following cell might take up to 15 minutes per 5000 training iterations, depending on your system." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31b3dff5bd2840b0b358f91fdb2b117b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/15000 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_ot_map(neural_dual, data_source, data_target, inverse=False)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "