{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Initialization Schemes for Input Convex Neural Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Input convex neural networks (ICNNs) are notoriously difficult to train. Bunne et al. therefore propose to use closed-form solutions between Gaussian approximations to derive relevant parameter initializations for ICNNs: given two measures $\\mu$ and $\\nu$, one can initialize the ICNN parameters so that the network is (initially) meaningful in the context of optimal transport (OT), namely so that its gradient approximately maps the source measure $\\mu$ onto the target measure $\\nu$. These initializations rely on the closed-form solutions available for Gaussian measures.\n", "In this notebook, we introduce the *identity* and *Gaussian approximation*-based initialization schemes, and illustrate how they can be used within the OTT library and its ICNN-based NeuralDual module." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", "    !pip install -q git+https://github.com/ott-jax/ott@main" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import optax\n", "import matplotlib.pyplot as plt\n", "from torch.utils.data import IterableDataset\n", "from torch.utils.data import DataLoader\n", "from ott.tools.sinkhorn_divergence import sinkhorn_divergence\n", "from ott.geometry import pointcloud\n", "from ott.solvers.nn import icnn, neuraldual" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Helper Functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us define some helper functions that we use in the subsequent analysis."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def plot_ot_map(neural_dual, source, target, inverse=False):\n", "    \"\"\"Plot data and learned optimal transport map.\"\"\"\n", "\n", "    def draw_arrows(a, b):\n", "        plt.arrow(\n", "            a[0], a[1], b[0] - a[0], b[1] - a[1], color=[0.5, 0.5, 1], alpha=0.3\n", "        )\n", "\n", "    if not inverse:\n", "        grad_state_s = neural_dual.transport(source)\n", "    else:\n", "        grad_state_s = neural_dual.inverse_transport(source)\n", "\n", "    fig = plt.figure()\n", "    ax = fig.add_subplot(111)\n", "\n", "    if not inverse:\n", "        ax.scatter(\n", "            target[:, 0],\n", "            target[:, 1],\n", "            color=\"#A7BED3\",\n", "            alpha=0.5,\n", "            label=r\"$target$\",\n", "        )\n", "        ax.scatter(\n", "            source[:, 0],\n", "            source[:, 1],\n", "            color=\"#1A254B\",\n", "            alpha=0.5,\n", "            label=r\"$source$\",\n", "        )\n", "        ax.scatter(\n", "            grad_state_s[:, 0],\n", "            grad_state_s[:, 1],\n", "            color=\"#F2545B\",\n", "            alpha=0.5,\n", "            label=r\"$\\nabla g(source)$\",\n", "        )\n", "    else:\n", "        ax.scatter(\n", "            target[:, 0],\n", "            target[:, 1],\n", "            color=\"#A7BED3\",\n", "            alpha=0.5,\n", "            label=r\"$source$\",\n", "        )\n", "        ax.scatter(\n", "            source[:, 0],\n", "            source[:, 1],\n", "            color=\"#1A254B\",\n", "            alpha=0.5,\n", "            label=r\"$target$\",\n", "        )\n", "        ax.scatter(\n", "            grad_state_s[:, 0],\n", "            grad_state_s[:, 1],\n", "            color=\"#F2545B\",\n", "            alpha=0.5,\n", "            label=r\"$\\nabla f(target)$\",\n", "        )\n", "\n", "    plt.legend()\n", "\n", "    for i in range(source.shape[0]):\n", "        draw_arrows(source[i, :], grad_state_s[i, :])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def get_optimizer(optimizer, lr, b1, b2, eps):\n", "    \"\"\"Return an optax optimizer based on config.\"\"\"\n", "\n", "    if optimizer == \"Adam\":\n", "        optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2, eps=eps)\n", "    elif optimizer == \"SGD\":\n", "        optimizer = optax.sgd(learning_rate=lr, momentum=None,
nesterov=False)\n", "    else:\n", "        raise NotImplementedError(f\"Optimizer {optimizer} not supported yet!\")\n", "\n", "    return optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Training and Validation Datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To test the ICNN initialization methods, we use the NeuralDual of the OTT library as an example. Here, we aim to compute the map between two toy datasets representing the source and target distributions. For more details on running the NeuralDual module, we refer the reader to [this](https://ott-jax.readthedocs.io/en/latest/notebooks/neural_dual.html) notebook.\n", "In this tutorial, the user can choose between the datasets `simple` (data clustered around one center), `circle` (two-dimensional Gaussians arranged on a circle), `square_five` (two-dimensional Gaussians in the corners of a square with one Gaussian in the center), and `square_four` (two-dimensional Gaussians in the corners of a square rotated by 45 degrees)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class ToyDataset(IterableDataset):\n", "    def __init__(self, name):\n", "        self.name = name\n", "\n", "    def __iter__(self):\n", "        return self.create_sample_generators()\n", "\n", "    def create_sample_generators(self, scale=5.0, variance=0.5):\n", "        # given name of dataset, select centers\n", "        if self.name == \"simple\":\n", "            # a single center at the origin, kept 2D so that indexing returns a point\n", "            centers = np.array([[0, 0]])\n", "\n", "        elif self.name == \"circle\":\n", "            centers = np.array(\n", "                [\n", "                    (1, 0),\n", "                    (-1, 0),\n", "                    (0, 1),\n", "                    (0, -1),\n", "                    (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", "                    (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", "                    (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", "                    (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", "                ]\n", "            )\n", "\n", "        elif self.name == \"square_five\":\n", "            centers = np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]])\n", "\n", "        elif self.name == \"square_four\":\n", "            centers = np.array([[1, 0], [0, 1], [-1, 0], [0,
-1]])\n", "\n", "        else:\n", "            raise NotImplementedError()\n", "\n", "        # create generator which randomly picks center and adds noise\n", "        centers = scale * centers\n", "        while True:\n", "            center = centers[np.random.choice(len(centers))]\n", "            # the noise is scaled by variance**2, i.e., a standard deviation of 0.25 by default\n", "            point = center + variance**2 * np.random.randn(2)\n", "\n", "            yield point\n", "\n", "\n", "def load_toy_data(\n", "    name_source: str,\n", "    name_target: str,\n", "    batch_size: int = 1024,\n", "    valid_batch_size: int = 1024,\n", "):\n", "    dataloaders = (\n", "        iter(DataLoader(ToyDataset(name_source), batch_size=batch_size)),\n", "        iter(DataLoader(ToyDataset(name_target), batch_size=batch_size)),\n", "        iter(DataLoader(ToyDataset(name_source), batch_size=valid_batch_size)),\n", "        iter(DataLoader(ToyDataset(name_target), batch_size=valid_batch_size)),\n", "    )\n", "    input_dim = 2\n", "    return dataloaders, input_dim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Experimental Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To solve the neural dual, we need to define our dataloaders. The only requirement is that the corresponding source and target train and validation datasets are *iterators*." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "(dataloader_source, dataloader_target, _, _), input_dim = load_toy_data(\n", "    \"simple\", \"circle\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visualize the initialization schemes, let's sample data from the source and target distributions."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "data_source = next(dataloader_source).numpy()\n", "data_target = next(dataloader_target).numpy()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# initialize optimizers\n", "optimizer_f = get_optimizer(\"Adam\", lr=0.0001, b1=0.5, b2=0.9, eps=0.00000001)\n", "optimizer_g = get_optimizer(\"Adam\", lr=0.0001, b1=0.5, b2=0.9, eps=0.00000001)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Identity Initialization Method" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define the architectures parameterizing the dual potentials $f$ and $g$. These need to be parameterized by ICNNs. You can adapt the size of the ICNNs by passing a sequence containing the hidden layer sizes. While ICNNs by default contain partially positive weights, we can solve the NeuralDual using approximations to this positivity constraint (via weight clipping and a weight penalization). For this, set positive weights to True in both the ICNN architecture and the NeuralDualSolver configuration. For more details on how to customize the ICNN architectures, we refer you to the documentation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first explore the identity initialization method. It is the default choice of the current ICNN and is data-independent, so no further arguments need to be passed to the ICNN architecture."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# initialize models using identity initialization (default)\n", "neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n", "neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6149d916e1ca484a94c4f5b218825e1d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# num_train_iters=0: no training steps are run, so the map plotted below reflects the initialization only\n", "neural_dual_solver = neuraldual.NeuralDualSolver(\n", "    input_dim, neural_f, neural_g, optimizer_f, optimizer_g, num_train_iters=0\n", ")\n", "neural_dual = neural_dual_solver(\n", "    dataloader_source, dataloader_target, dataloader_source, dataloader_target\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can plot the corresponding transport from source to target using the gradient of the learned potential NeuralDual.g, i.e., $\\nabla g(\\text{source})$, or from target to source via the gradient of the learned potential NeuralDual.f, i.e., $\\nabla f(\\text{target})$." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "