diff --git a/README.md b/README.md
index 52b914f..4ac4bb5 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,61 @@
-# causalvae
+
+## Deep Causal Varitional Autoencoder
+
+To train a supervised variational autoencoder using Deepmind's [dSprites](https://github.com/deepmind/dsprites-dataset) dataset.
+
+dSprites is a dataset of sprites, which are 2D shapes procedurally generated from 5 ground truth independent "factors." These factors are color, shape, scale, rotation, x and y positions of a sprite.
+
+All possible combinations of these variables are present exactly once, generating N = 737280 total images.
+
+Factors and their values:
+
+* Shape: 3 values {square, ellipse, heart}
+* Scale: 6 values linearly spaced in (0.5, 1)
+* Orientation: 40 values in (0, 2pi)
+* Position X: 32 values in (0, 1)
+* Position Y: 32 values in (0, 1)
+
+There is a sixth factor for color, but it is white for every image in this dataset.
+
+The purpose of this dataset was to evaluate the ability of disentanglement methods. In these methods, you treat these factors as latent and then try to "disentangle" them in the latent representation.
+
+However, in this project, these factors are not treated as latent, but are included as labels in the model training. Further, a causal story is invented that relates these factors and the images in a DAG
+
+
+
+Structural causal model is of the form:
+
+
+
+The image variable will be a 64 x 64 array. The noise term for the image variable will be the traditional Gaussian random variable. The structural assignment *g* for the image variable will be the decoder.
+
+
+## Work:
+* Built a Structural causal model that articulates a causal story relating shape, orientation, scale, X, Y, and the data.
+* Resampled the dataset to get a new dataset with an empirical distribution that is faithful to the DAG and is entailed by the SCM
+* To implement a causal VAE using [Pyro](http://pyro.ai/) by extending the primitive version of VAE. The VAE is fully supervised.
+* Finally used the trained model to answer some counterfactual queries, for example, "given this image of a heart with this orientation, position, and scale, what would it have looked like if it were a square?"
+
+## Optimization:
+* The code is made compatible for GPU for faster processing.
+* The learned weights are saved to avoid training frequently to enhance development efficiency.
+
+## Results
+* Achieved good reconstruction accuracy using vanilla VAE -
+
+
+* Trained VAE and made sure it recognises changes in the latent dimensions(Manually changed latent variables before training) -
+
+
+* Built structural causal model and verfied for reconstruction accuracy -
+
+
+* Counterfactual queries (1) - Intervention on shape - Given a oval and certain (x,y) co-ordinates and orientation, how would it look it was a sqaure?
+
+
+* Counterfactual queries (1) - Intervention on shape, position of (x,y) -
+
+
+## Applications
+* DeepFakes :[Structured Disentangled Representations](https://arxiv.org/pdf/1804.02086.pdf)
+
diff --git a/causal_vae_dsprites.ipynb b/causal_vae_dsprites.ipynb
new file mode 100644
index 0000000..ec540fe
--- /dev/null
+++ b/causal_vae_dsprites.ipynb
@@ -0,0 +1,1851 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ },
+ "colab": {
+ "name": "causal_vae_dsprites.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "include_colab_link": true
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zSOMAHldq2xR",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Deep Causal Variational Inference\n",
+ "\n",
+ "### Introduction:\n",
+ "To train a supervised variational autoencoder using Deepmind's [dSprites](https://github.com/deepmind/dsprites-dataset) dataset.\n",
+ "\n",
+ "dSprites is a dataset of sprites, which are 2D shapes procedurally generated from 5 ground truth independent \"factors.\" These factors are color, shape, scale, rotation, x and y positions of a sprite.\n",
+ "\n",
+ "All possible combinations of these variables are present exactly once, generating N = 737280 total images.\n",
+ "\n",
+ "Factors and their values:\n",
+ "\n",
+ "* Shape: 3 values {square, ellipse, heart}\n",
+ "* Scale: 6 values linearly spaced in (0.5, 1)\n",
+ "* Orientation: 40 values in (0, 2$\\pi$)\n",
+ "* Position X: 32 values in (0, 1)\n",
+ "* Position Y: 32 values in (0, 1)\n",
+ "\n",
+ "\n",
+ "Further, the objective of any generative model is essentially to capture underlying data generative factors, the disentangled representation would mean a single latent unit being sensitive to variations in single generative factors\n",
+ "\n",
+ "\n",
+ "### Goal:\n",
+ "To include the latent factors as labels in the training and to invent a causal story that relates these factors and the images in a DAG.\n",
+ "\n",
+ "Reference \n",
+ "\n",
+ "[Structured Disentangled Representation](https://arxiv.org/pdf/1804.02086.pdf)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "CMhHGoIwCwkN",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "#Install dependencies\n",
+ "!pip3 install pyro-ppl\n",
+ "!pip3 install torch torchvision\n",
+ "!pip3 install pydrive --upgrade\n",
+ "!pip3 install tqdm"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "uNa3MO8GCqWg",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Load necessary libraries\n",
+ "from matplotlib import pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "\n",
+ "import os\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "import pyro\n",
+ "import pyro.distributions as dist\n",
+ "from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, EmpiricalMarginal\n",
+ "from pyro.optim import Adam, SGD\n",
+ "import torch.distributions.constraints as constraints\n",
+ "\n",
+ "# Change figure aesthetics\n",
+ "%matplotlib inline\n",
+ "sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})\n",
+ "\n",
+ "from ipywidgets import interact, interactive, fixed, interact_manual\n",
+ "import ipywidgets as widgets\n",
+ "\n",
+ "#to utilize GPU capabilities\n",
+ "USE_CUDA = True\n",
+ "\n",
+ "pyro.enable_validation(True)\n",
+ "pyro.distributions.enable_validation(False)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nkXRHmSqDFTy",
+ "colab_type": "code",
+ "outputId": "40eb56b2-2dcd-42d2-942f-ccf0e4c08982",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 124
+ }
+ },
+ "source": [
+ "# Mount Google drive to load data\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/gdrive')"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code\n",
+ "\n",
+ "Enter your authorization code:\n",
+ "··········\n",
+ "Mounted at /content/gdrive\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KRCLhKI7d2r1",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Mount G drive to access files\n",
+ "from pydrive.auth import GoogleAuth\n",
+ "from pydrive.drive import GoogleDrive\n",
+ "from google.colab import auth\n",
+ "from oauth2client.client import GoogleCredentials\n",
+ "\n",
+ "auth.authenticate_user()\n",
+ "gauth = GoogleAuth()\n",
+ "gauth.credentials = GoogleCredentials.get_application_default()\n",
+ "drive = GoogleDrive(gauth)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "cqyc4SqyDLIq",
+ "colab_type": "code",
+ "outputId": "3d558d65-597e-47dc-fe0f-7a9f9dad6945",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 262
+ }
+ },
+ "source": [
+ "# Hack to get all available GPU ram.\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "tf.test.gpu_device_name()\n",
+ "\n",
+ "!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi\n",
+ "!pip install gputil\n",
+ "!pip install psutil\n",
+ "!pip install humanize\n",
+ "import psutil\n",
+ "import humanize\n",
+ "import os\n",
+ "import GPUtil as GPU\n",
+ "GPUs = GPU.getGPUs()\n",
+ "# XXX: only one GPU on Colab and isn’t guaranteed\n",
+ "gpu = GPUs[0]\n",
+ "def printm():\n",
+ " process = psutil.Process(os.getpid())\n",
+ " print(\"Gen RAM Free: \" + humanize.naturalsize( psutil.virtual_memory().available ), \" | Proc size: \" + humanize.naturalsize( process.memory_info().rss))\n",
+ " print(\"GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB\".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))\n",
+ "printm()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Collecting gputil\n",
+ " Downloading https://files.pythonhosted.org/packages/ed/0e/5c61eedde9f6c87713e89d794f01e378cfd9565847d4576fa627d758c554/GPUtil-1.4.0.tar.gz\n",
+ "Building wheels for collected packages: gputil\n",
+ " Building wheel for gputil (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for gputil: filename=GPUtil-1.4.0-cp36-none-any.whl size=7411 sha256=0ecd5811e41f36cf83f484d14611cf7cf4232041ade307328b392bc99b263545\n",
+ " Stored in directory: /root/.cache/pip/wheels/3d/77/07/80562de4bb0786e5ea186911a2c831fdd0018bda69beab71fd\n",
+ "Successfully built gputil\n",
+ "Installing collected packages: gputil\n",
+ "Successfully installed gputil-1.4.0\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.6/dist-packages (5.4.8)\n",
+ "Requirement already satisfied: humanize in /usr/local/lib/python3.6/dist-packages (0.5.1)\n",
+ "Gen RAM Free: 25.2 GB | Proc size: 2.2 GB\n",
+ "GPU RAM Free: 11111MB | Used: 330MB | Util 3% | Total 11441MB\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nSkXleG1CqWj",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class Encoder(nn.Module):\n",
+ "\t\"\"\"\n",
+ " MLPs (multi-layered perceptrons or simple feed-forward networks)\n",
+ " where the provided activation parameter is used on every linear layer except\n",
+ " for the output layer where we use the provided output_activation parameter\n",
+ "\t\"\"\"\n",
+ "\tdef __init__(self, image_dim, label_dim, z_dim):\n",
+ "\t\tsuper(Encoder, self).__init__()\n",
+ "\t\t#setup image and label dimensions from the dataset\n",
+ "\t\tself.image_dim = image_dim\n",
+ "\t\tself.label_dim = label_dim\n",
+ "\t\tself.z_dim = z_dim\n",
+ "\t\t# setup the three linear transformations used\n",
+ "\t\tself.fc1 = nn.Linear(self.image_dim+self.label_dim, 1000)\n",
+ "\t\tself.fc2 = nn.Linear(1000, 1000)\n",
+ "\t\tself.fc31 = nn.Linear(1000, z_dim) # mu values\n",
+ "\t\tself.fc32 = nn.Linear(1000, z_dim) # sigma values\n",
+ "\t\t# setup the non-linearities\n",
+ "\t\tself.softplus = nn.Softplus()\n",
+ "\n",
+ "\tdef forward(self, xs, ys):\n",
+ "\t\txs = xs.reshape(-1, self.image_dim)\n",
+ "\t\t#now concatenate the image and label\n",
+ "\t\tinputs = torch.cat((xs,ys), -1)\n",
+ "\t\t# then compute the hidden units\n",
+ "\t\thidden1 = self.softplus(self.fc1(inputs))\n",
+ "\t\thidden2 = self.softplus(self.fc2(hidden1))\n",
+ "\t\t# then return a mean vector and a (positive) square root covariance\n",
+ "\t\t# each of size batch_size x z_dim\n",
+ "\t\tz_loc = self.fc31(hidden2)\n",
+ "\t\tz_scale = torch.exp(self.fc32(hidden2))\n",
+ "\t\treturn z_loc, z_scale\n",
+ "\n",
+ "\n",
+ "class Decoder(nn.Module):\n",
+ "\tdef __init__(self, image_dim, label_dim, z_dim):\n",
+ "\t\tsuper(Decoder, self).__init__()\n",
+ "\t\t# setup the two linear transformations used\n",
+ "\t\thidden_dim = 1000\n",
+ "\t\tself.fc1 = nn.Linear(z_dim+label_dim, hidden_dim)\n",
+ "\t\tself.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
+ "\t\tself.fc3 = nn.Linear(hidden_dim, hidden_dim)\n",
+ "\t\tself.fc4 = nn.Linear(hidden_dim, image_dim)\n",
+ "\t\t# setup the non-linearities\n",
+ "\t\tself.softplus = nn.Softplus()\n",
+ "\t\tself.sigmoid = nn.Sigmoid()\n",
+ "\n",
+ "\tdef forward(self, zs, ys):\n",
+ "\t\tinputs = torch.cat((zs, ys),-1)\n",
+ "\t\t# then compute the hidden units\n",
+ "\t\thidden1 = self.softplus(self.fc1(inputs))\n",
+ "\t\thidden2 = self.softplus(self.fc2(hidden1))\n",
+ "\t\thidden3 = self.softplus(self.fc3(hidden2))\n",
+ "\t\t# return the parameter for the output Bernoulli\n",
+ "\t\t# each is of size batch_size x 784\n",
+ "\t\tloc_img = self.sigmoid(self.fc4(hidden3))\n",
+ "\t\treturn loc_img"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Yqj05Mv8CqWm",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class CVAE(nn.Module):\n",
+ "\t\"\"\"\n",
+ "\tThis class encapsulates the parameters (neural networks) and models & guides \n",
+ "\tneeded to train a supervised variational auto-encoder \n",
+ "\t\"\"\"\n",
+ "\tdef __init__(self, config_enum=None, use_cuda=False, aux_loss_multiplier=None):\n",
+ " \n",
+ "\t\tsuper(CVAE, self).__init__()\n",
+ "\t\tself.image_dim = 64**2\n",
+ "\t\tself.label_shape = np.array((1,3,6,40,32,32))\n",
+ "\t\tself.label_names = np.array(('color', 'shape', 'scale', 'orientation', 'posX', 'posY'))\n",
+ "\t\tself.label_dim = np.sum(self.label_shape)\n",
+ "\t\tself.z_dim = 50 \n",
+ "\t\tself.allow_broadcast = config_enum == 'parallel'\n",
+ "\t\tself.use_cuda = use_cuda\n",
+ "\t\tself.aux_loss_multiplier = aux_loss_multiplier\n",
+ "\t # define and instantiate the neural networks representing\n",
+ "\t # the paramters of various distributions in the model\n",
+ "\t\tself.setup_networks()\n",
+ "\n",
+ "\tdef setup_networks(self):\n",
+ "\t\t\"\"\"\n",
+ "\t\tSetup and initialize Encoder and decoder units\n",
+ "\t\t\"\"\"\n",
+ "\t\tself.encoder = Encoder(self.image_dim, self.label_dim, self.z_dim)\n",
+ "\t\tself.decoder = Decoder(self.image_dim, self.label_dim, self.z_dim)\n",
+ "\t\t# using GPUs for faster training of the networks\n",
+ "\t\tif self.use_cuda:\n",
+ "\t\t self.cuda()\n",
+ "\n",
+ "\tdef model(self, xs, ys):\n",
+ "\t\tpyro.module(\"cvae\", self)\n",
+ "\t\tbatch_size = xs.size(0)\n",
+ "\t\toptions = dict(dtype=xs.dtype, device=xs.device)\n",
+ "\t\twith pyro.plate(\"data\"):\n",
+ "\t\t\tprior_loc = torch.zeros(batch_size, self.z_dim, **options)\n",
+ "\t\t\tprior_scale = torch.ones(batch_size, self.z_dim, **options)\n",
+ "\t\t\tzs = pyro.sample(\"z\", dist.Normal(prior_loc, prior_scale).to_event(1))\n",
+ "\t\t\t# if the label y (which digit to write) is supervised, sample from the\n",
+ "\t\t\t# constant prior, otherwise, observe the value (i.e. score it against the constant prior)\n",
+ "\t\t\tloc = self.decoder.forward(zs, self.remap_y(ys))\n",
+ "\t\t\tpyro.sample(\"x\", dist.Bernoulli(loc).to_event(1), obs=xs)\n",
+ "\t\t \t# return the loc so we can visualize it later\n",
+ "\t\t\treturn loc\n",
+ "\n",
+ "\tdef guide(self, xs, ys):\n",
+ "\t\twith pyro.plate(\"data\"):\n",
+ "\t\t\t# sample (and score) the latent handwriting-style with the variational\n",
+ "\t\t\t# distribution q(z|x) = normal(loc(x),scale(x))\n",
+ "\t\t\tloc, scale = self.encoder.forward(xs, self.remap_y(ys))\n",
+ "\t\t\tpyro.sample(\"z\", dist.Normal(loc, scale).to_event(1))\n",
+ "\n",
+ "\tdef remap_y(self, ys):\n",
+ "\t\tnew_ys = []\n",
+ "\t\toptions = dict(dtype=ys.dtype, device=ys.device)\n",
+ "\t\tfor i, label_length in enumerate(self.label_shape):\n",
+ "\t\t prior = torch.ones(ys.size(0), label_length, **options) / (1.0 * label_length)\n",
+ "\t\t new_ys.append(pyro.sample(\"y_%s\" % self.label_names[i], dist.OneHotCategorical(prior), \n",
+ "\t\t obs=torch.nn.functional.one_hot(ys[:,i].to(torch.int64), int(label_length))))\n",
+ "\t\tnew_ys = torch.cat(new_ys, -1)\n",
+ "\t\treturn new_ys.to(torch.float32)\n",
+ "\n",
+ "\tdef reconstruct_image(self, xs, ys):\n",
+ "\t\t# backward\n",
+ "\t\tsim_z_loc, sim_z_scale = self.encoder.forward(xs, self.remap_y(ys))\n",
+ "\t\tzs = dist.Normal(sim_z_loc, sim_z_scale).to_event(1).sample()\n",
+ "\t\t# forward\n",
+ "\t\tloc = self.decoder.forward(zs, self.remap_y(ys))\n",
+ "\t\treturn dist.Bernoulli(loc).to_event(1).sample()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1mmMe3D8CqWo",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def setup_data_loaders(train_x, test_x, train_y, test_y, batch_size=128, use_cuda=False):\n",
+ "\ttrain_dset = torch.utils.data.TensorDataset(\n",
+ "\t torch.from_numpy(train_x.astype(np.float32)).reshape(-1, 4096),\n",
+ "\t torch.from_numpy(train_y.astype(np.float32))\n",
+ "\t)\n",
+ "\ttest_dset = torch.utils.data.TensorDataset(\n",
+ "\t torch.from_numpy(test_x.astype(np.float32)).reshape(-1, 4096),\n",
+ "\t torch.from_numpy(test_y.astype(np.float32))\n",
+ "\t) \n",
+ "\tkwargs = {'num_workers': 1, 'pin_memory': use_cuda}\n",
+ "\ttrain_loader = torch.utils.data.DataLoader(\n",
+ "\t dataset=train_dset, batch_size=batch_size, shuffle=False, **kwargs\n",
+ "\t)\n",
+ "\n",
+ "\ttest_loader = torch.utils.data.DataLoader(\n",
+ "\t dataset=test_dset, batch_size=batch_size, shuffle=False, **kwargs\n",
+ "\t)\n",
+ "\treturn {\"train\":train_loader, \"test\":test_loader}"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "azMSBtsnCqWr",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "dataset_zip = np.load(\n",
+ " '/content/gdrive/My Drive/data-science/causal-ml/projects/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',\n",
+ " encoding = 'bytes',\n",
+ " allow_pickle=True\n",
+ ")\n",
+ "\n",
+ "imgs = dataset_zip['imgs']\n",
+ "labels = dataset_zip['latents_classes']\n",
+ "label_sizes = dataset_zip['metadata'][()][b'latents_sizes']\n",
+ "label_names = dataset_zip['metadata'][()][b'latents_names']\n",
+ "\n",
+ "# Sample imgs randomly\n",
+ "indices_sampled = np.arange(imgs.shape[0])\n",
+ "np.random.shuffle(indices_sampled)\n",
+ "imgs_sampled = imgs[indices_sampled]\n",
+ "labels_sampled = labels[indices_sampled]\n",
+ "\n",
+ "data_loaders = setup_data_loaders(\n",
+ " imgs_sampled[1000:],\n",
+ " imgs_sampled[:1000],\n",
+ " labels_sampled[1000:],\n",
+ " labels_sampled[:1000],\n",
+ " batch_size=256,\n",
+ " use_cuda=USE_CUDA\n",
+ ")"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Eamf-n9hCqWt",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def train(svi, train_loader, use_cuda=False):\n",
+ "\t# initialize loss accumulator\n",
+ "\tepoch_loss = 0.\n",
+ "\t# do a training epoch over each mini-batch x returned\n",
+ "\t# by the data loader\n",
+ "\tfor xs,ys in train_loader:\n",
+ "\t # if on GPU put mini-batch into CUDA memory\n",
+ "\t if use_cuda:\n",
+ "\t xs = xs.cuda()\n",
+ "\t ys = ys.cuda()\n",
+ "\t # do ELBO gradient and accumulate loss\n",
+ "\t epoch_loss += svi.step(xs, ys)\n",
+ "\t# return epoch loss\n",
+ "\tnormalizer_train = len(train_loader.dataset)\n",
+ "\ttotal_epoch_loss_train = epoch_loss / normalizer_train\n",
+ "\treturn total_epoch_loss_train\n",
+ "\n",
+ "def evaluate(svi, test_loader, use_cuda=False):\n",
+ "\t# initialize loss accumulator\n",
+ "\ttest_loss = 0.\n",
+ "\t# compute the loss over the entire test set\n",
+ "\tfor xs, ys in test_loader:\n",
+ "\t # if on GPU put mini-batch into CUDA memory\n",
+ "\t if use_cuda:\n",
+ "\t xs = xs.cuda()\n",
+ "\t ys = ys.cuda()\n",
+ "\t # compute ELBO estimate and accumulate loss\n",
+ "\t test_loss += svi.evaluate_loss(xs, ys)\n",
+ "\tnormalizer_test = len(test_loader.dataset)\n",
+ "\ttotal_epoch_loss_test = test_loss / normalizer_test\n",
+ "\treturn total_epoch_loss_test"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "LR-hW5UzCqWv",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Run options\n",
+ "LEARNING_RATE = 1.0e-3\n",
+ "\n",
+ "# Run only for a single iteration for testing\n",
+ "NUM_EPOCHS = 10\n",
+ "TEST_FREQUENCY = 5"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "puMbph4idkYg",
+ "colab_type": "code",
+ "outputId": "928e47f6-fd32-49af-a8fa-5fcc4b99a8c1",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ }
+ },
+ "source": [
+ "#################################\n",
+ "### FOR SAVING AND LOADING MODEL\n",
+ "################################\n",
+ "# clear param store\n",
+ "\n",
+ "pyro.clear_param_store()\n",
+ "\n",
+ "network_path = \"/content/gdrive/My Drive/data-science/causal-ml/projects/trained_model.save\"\n",
+ "\n",
+ "PATH = \"trained_model.save\"\n",
+ "\n",
+ "# new model\n",
+ "# vae = CVAE(use_cuda=USE_CUDA)\n",
+ "\n",
+ "# save current model\n",
+ "# torch.save(vae.state_dict(), PATH)\n",
+ "\n",
+ "# to load params from trained model\n",
+ "vae = CVAE(use_cuda=USE_CUDA)\n",
+ "vae.load_state_dict(torch.load(network_path))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 12
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TtCW3_agE7N_",
+ "colab_type": "text"
+ },
+ "source": [
+ "### **DONT RUN THE BELOW CODE AS WE'VE ALREADY TRAINED THE MODEL AND WE'VE STORED THE NETWORK PARAMS**\n",
+ "\n",
+ "## =================================================================================="
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "hIuVKC2lCqWx",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "# clear param store\n",
+ "pyro.clear_param_store()\n",
+ "\n",
+ "# setup the VAE\n",
+ "vae = CVAE(use_cuda=USE_CUDA)\n",
+ "\n",
+ "# setup the optimizer\n",
+ "adam_args = {\"lr\": LEARNING_RATE}\n",
+ "optimizer = Adam(adam_args)\n",
+ "\n",
+ "# setup the inference algorithm\n",
+ "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())\n",
+ "\n",
+ "train_elbo = []\n",
+ "test_elbo = []\n",
+ "# training loop\n",
+ "\n",
+ "VERBOSE = True\n",
+ "pbar = tqdm(range(NUM_EPOCHS))\n",
+ "for epoch in pbar:\n",
+ " total_epoch_loss_train = train(svi, data_loaders[\"train\"], use_cuda=USE_CUDA)\n",
+ " train_elbo.append(-total_epoch_loss_train)\n",
+ " if VERBOSE:\n",
+ " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, total_epoch_loss_train))\n",
+ " if epoch % TEST_FREQUENCY == 0:\n",
+ " # report test diagnostics\n",
+ " total_epoch_loss_test = evaluate(svi, data_loaders[\"test\"], use_cuda=USE_CUDA)\n",
+ " test_elbo.append(-total_epoch_loss_test)\n",
+ " if VERBOSE:\n",
+ " print(\"[epoch %03d] average test loss: %.4f\" % (epoch, total_epoch_loss_test))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "navbYZazFD0E",
+ "colab_type": "text"
+ },
+ "source": [
+ "## =================================================================================="
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aFnfYMd3CqW0",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Visualizing the reconstruction accuracy of VAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "IZQ_mi_GMm_v",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nIt0T2eUCqW0",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "data_iter = iter(data_loaders[\"train\"])\n",
+ "xs, ys = next(data_iter)\n",
+ "\n",
+ "if USE_CUDA:\n",
+ " xs = xs.cuda()\n",
+ " ys = ys.cuda()\n",
+ "rs = vae.reconstruct_image(xs, ys)\n",
+ "if USE_CUDA:\n",
+ " xs = xs.cpu()\n",
+ " rs = rs.cpu()\n",
+ "originals = xs.numpy().reshape(-1, 64,64)\n",
+ "recons = rs.reshape(-1,64,64)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1YBT2lo7CqW2",
+ "colab_type": "code",
+ "outputId": "13118f8b-c9b0-4ebd-98b4-0b404b40074c",
+ "colab": {
+ "resources": {
+ "http://localhost:8080/nbextensions/google.colab/colabwidgets/controls.css": {
+ "data": "/* Copyright (c) Jupyter Development Team.
 * Distributed under the terms of the Modified BSD License.
 */

 /* We import all of these together in a single css file because the Webpack
loader sees only one file at a time. This allows postcss to see the variable
definitions when they are used. */

 /*-----------------------------------------------------------------------------
| Copyright (c) Jupyter Development Team.
| Distributed under the terms of the Modified BSD License.
|----------------------------------------------------------------------------*/

 /*
This file is copied from the JupyterLab project to define default styling for
when the widget styling is compiled down to eliminate CSS variables. We make one
change - we comment out the font import below.
*/

 /**
 * The material design colors are adapted from google-material-color v1.2.6
 * https://github.com/danlevan/google-material-color
 * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/dist/palette.var.css
 *
 * The license for the material design color CSS variables is as follows (see
 * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/LICENSE)
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2014 Dan Le Van
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

 /*
The following CSS variables define the main, public API for styling JupyterLab.
These variables should be used by all plugins wherever possible. In other
words, plugins should not define custom colors, sizes, etc unless absolutely
necessary. This enables users to change the visual theme of JupyterLab
by changing these variables.

Many variables appear in an ordered sequence (0,1,2,3). These sequences
are designed to work well together, so for example, `--jp-border-color1` should
be used with `--jp-layout-color1`. The numbers have the following meanings:

* 0: super-primary, reserved for special emphasis
* 1: primary, most important under normal situations
* 2: secondary, next most important under normal situations
* 3: tertiary, next most important under normal situations

Throughout JupyterLab, we are mostly following principles from Google's
Material Design when selecting colors. We are not, however, following
all of MD as it is not optimized for dense, information rich UIs.
*/

 /*
 * Optional monospace font for input/output prompt.
 */

 /* Commented out in ipywidgets since we don't need it. */

 /* @import url('https://fonts.googleapis.com/css?family=Roboto+Mono'); */

 /*
 * Added for compabitility with output area
 */

 :root {

  /* Borders

  The following variables, specify the visual styling of borders in JupyterLab.
   */

  /* UI Fonts

  The UI font CSS variables are used for the typography all of the JupyterLab
  user interface elements that are not directly user generated content.
  */ /* Base font size */ /* Ensures px perfect FontAwesome icons */

  /* Use these font colors against the corresponding main layout colors.
     In a light theme, these go from dark to light.
  */

  /* Use these against the brand/accent/warn/error colors.
     These will typically go from light to darker, in both a dark and light theme
   */

  /* Content Fonts

  Content font variables are used for typography of user generated content.
  */ /* Base font size */


  /* Layout

  The following are the main layout colors use in JupyterLab. In a light
  theme these would go from light to dark.
  */

  /* Brand/accent */

  /* State colors (warn, error, success, info) */

  /* Cell specific styles */
  /* A custom blend of MD grey and blue 600
   * See https://meyerweb.com/eric/tools/color-blend/#546E7A:1E88E5:5:hex */
  /* A custom blend of MD grey and orange 600
   * https://meyerweb.com/eric/tools/color-blend/#546E7A:F4511E:5:hex */

  /* Notebook specific styles */

  /* Console specific styles */

  /* Toolbar specific styles */
}

 /* Copyright (c) Jupyter Development Team.
 * Distributed under the terms of the Modified BSD License.
 */

 /*
 * We assume that the CSS variables in
 * https://github.com/jupyterlab/jupyterlab/blob/master/src/default-theme/variables.css
 * have been defined.
 */

 /* This file has code derived from PhosphorJS CSS files, as noted below. The license for this PhosphorJS code is:

Copyright (c) 2014-2017, PhosphorJS Contributors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

 /*
 * The following section is derived from https://github.com/phosphorjs/phosphor/blob/23b9d075ebc5b73ab148b6ebfc20af97f85714c4/packages/widgets/style/tabbar.css 
 * We've scoped the rules so that they are consistent with exactly our code.
 */

 .jupyter-widgets.widget-tab > .p-TabBar {
  display: -webkit-box;
  display: -ms-flexbox;
  display: flex;
  -webkit-user-select: none;
  -moz-user-select: none;
  -ms-user-select: none;
  user-select: none;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] {
  -webkit-box-orient: horizontal;
  -webkit-box-direction: normal;
      -ms-flex-direction: row;
          flex-direction: row;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] {
  -webkit-box-orient: vertical;
  -webkit-box-direction: normal;
      -ms-flex-direction: column;
          flex-direction: column;
}

 .jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {
  margin: 0;
  padding: 0;
  display: -webkit-box;
  display: -ms-flexbox;
  display: flex;
  -webkit-box-flex: 1;
      -ms-flex: 1 1 auto;
          flex: 1 1 auto;
  list-style-type: none;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] > .p-TabBar-content {
  -webkit-box-orient: horizontal;
  -webkit-box-direction: normal;
      -ms-flex-direction: row;
          flex-direction: row;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] > .p-TabBar-content {
  -webkit-box-orient: vertical;
  -webkit-box-direction: normal;
      -ms-flex-direction: column;
          flex-direction: column;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {
  display: -webkit-box;
  display: -ms-flexbox;
  display: flex;
  -webkit-box-orient: horizontal;
  -webkit-box-direction: normal;
      -ms-flex-direction: row;
          flex-direction: row;
  -webkit-box-sizing: border-box;
          box-sizing: border-box;
  overflow: hidden;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,
.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {
  -webkit-box-flex: 0;
      -ms-flex: 0 0 auto;
          flex: 0 0 auto;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel {
  -webkit-box-flex: 1;
      -ms-flex: 1 1 auto;
          flex: 1 1 auto;
  overflow: hidden;
  white-space: nowrap;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-hidden {
  display: none !important;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab {
  position: relative;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='horizontal'] .p-TabBar-tab {
  left: 0;
  -webkit-transition: left 150ms ease;
  transition: left 150ms ease;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='vertical'] .p-TabBar-tab {
  top: 0;
  -webkit-transition: top 150ms ease;
  transition: top 150ms ease;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab.p-mod-dragging {
  -webkit-transition: none;
  transition: none;
}

 /* End tabbar.css */

 :root { /* margin between inline elements */

    /* From Material Design Lite */
}

 .jupyter-widgets {
    margin: 2px;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    color: black;
    overflow: visible;
}

 .jupyter-widgets.jupyter-widgets-disconnected::before {
    line-height: 28px;
    height: 28px;
}

 .jp-Output-result > .jupyter-widgets {
    margin-left: 0;
    margin-right: 0;
}

 /* vbox and hbox */

 .widget-inline-hbox {
    /* Horizontal widgets */
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: horizontal;
    -webkit-box-direction: normal;
        -ms-flex-direction: row;
            flex-direction: row;
    -webkit-box-align: baseline;
        -ms-flex-align: baseline;
            align-items: baseline;
}

 .widget-inline-vbox {
    /* Vertical Widgets */
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: center;
        -ms-flex-align: center;
            align-items: center;
}

 .widget-box {
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    margin: 0;
    overflow: auto;
}

 .widget-gridbox {
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: grid;
    margin: 0;
    overflow: auto;
}

 .widget-hbox {
    -webkit-box-orient: horizontal;
    -webkit-box-direction: normal;
        -ms-flex-direction: row;
            flex-direction: row;
}

 .widget-vbox {
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
}

 /* General Button Styling */

 .jupyter-button {
    padding-left: 10px;
    padding-right: 10px;
    padding-top: 0px;
    padding-bottom: 0px;
    display: inline-block;
    white-space: nowrap;
    overflow: hidden;
    text-overflow: ellipsis;
    text-align: center;
    font-size: 13px;
    cursor: pointer;

    height: 28px;
    border: 0px solid;
    line-height: 28px;
    -webkit-box-shadow: none;
            box-shadow: none;

    color: rgba(0, 0, 0, .8);
    background-color: #EEEEEE;
    border-color: #E0E0E0;
    border: none;
}

 .jupyter-button i.fa {
    margin-right: 4px;
    pointer-events: none;
}

 .jupyter-button:empty:before {
    content: "\200b"; /* zero-width space */
}

 .jupyter-widgets.jupyter-button:disabled {
    opacity: 0.6;
}

 .jupyter-button i.fa.center {
    margin-right: 0;
}

 .jupyter-button:hover:enabled, .jupyter-button:focus:enabled {
    /* MD Lite 2dp shadow */
    -webkit-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .14),
                0 3px 1px -2px rgba(0, 0, 0, .2),
                0 1px 5px 0 rgba(0, 0, 0, .12);
            box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .14),
                0 3px 1px -2px rgba(0, 0, 0, .2),
                0 1px 5px 0 rgba(0, 0, 0, .12);
}

 .jupyter-button:active, .jupyter-button.mod-active {
    /* MD Lite 4dp shadow */
    -webkit-box-shadow: 0 4px 5px 0 rgba(0, 0, 0, .14),
                0 1px 10px 0 rgba(0, 0, 0, .12),
                0 2px 4px -1px rgba(0, 0, 0, .2);
            box-shadow: 0 4px 5px 0 rgba(0, 0, 0, .14),
                0 1px 10px 0 rgba(0, 0, 0, .12),
                0 2px 4px -1px rgba(0, 0, 0, .2);
    color: rgba(0, 0, 0, .8);
    background-color: #BDBDBD;
}

 .jupyter-button:focus:enabled {
    outline: 1px solid #64B5F6;
}

 /* Button "Primary" Styling */

 .jupyter-button.mod-primary {
    color: rgba(255, 255, 255, 1.0);
    background-color: #2196F3;
}

 .jupyter-button.mod-primary.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #1976D2;
}

 .jupyter-button.mod-primary:active {
    color: rgba(255, 255, 255, 1);
    background-color: #1976D2;
}

 /* Button "Success" Styling */

 .jupyter-button.mod-success {
    color: rgba(255, 255, 255, 1.0);
    background-color: #4CAF50;
}

 .jupyter-button.mod-success.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #388E3C;
 }

 .jupyter-button.mod-success:active {
    color: rgba(255, 255, 255, 1);
    background-color: #388E3C;
 }

 /* Button "Info" Styling */

 .jupyter-button.mod-info {
    color: rgba(255, 255, 255, 1.0);
    background-color: #00BCD4;
}

 .jupyter-button.mod-info.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #0097A7;
}

 .jupyter-button.mod-info:active {
    color: rgba(255, 255, 255, 1);
    background-color: #0097A7;
}

 /* Button "Warning" Styling */

 .jupyter-button.mod-warning {
    color: rgba(255, 255, 255, 1.0);
    background-color: #FF9800;
}

 .jupyter-button.mod-warning.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #F57C00;
}

 .jupyter-button.mod-warning:active {
    color: rgba(255, 255, 255, 1);
    background-color: #F57C00;
}

 /* Button "Danger" Styling */

 .jupyter-button.mod-danger {
    color: rgba(255, 255, 255, 1.0);
    background-color: #F44336;
}

 .jupyter-button.mod-danger.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #D32F2F;
}

 .jupyter-button.mod-danger:active {
    color: rgba(255, 255, 255, 1);
    background-color: #D32F2F;
}

 /* Widget Button*/

 .widget-button, .widget-toggle-button {
    width: 148px;
}

 /* Widget Label Styling */

 /* Override Bootstrap label css */

 .jupyter-widgets label {
    margin-bottom: 0;
    margin-bottom: initial;
}

 .widget-label-basic {
    /* Basic Label */
    color: black;
    font-size: 13px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    line-height: 28px;
}

 .widget-label {
    /* Label */
    color: black;
    font-size: 13px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    line-height: 28px;
}

 .widget-inline-hbox .widget-label {
    /* Horizontal Widget Label */
    color: black;
    text-align: right;
    margin-right: 8px;
    width: 80px;
    -ms-flex-negative: 0;
        flex-shrink: 0;
}

 .widget-inline-vbox .widget-label {
    /* Vertical Widget Label */
    color: black;
    text-align: center;
    line-height: 28px;
}

 /* Widget Readout Styling */

 .widget-readout {
    color: black;
    font-size: 13px;
    height: 28px;
    line-height: 28px;
    overflow: hidden;
    white-space: nowrap;
    text-align: center;
}

 .widget-readout.overflow {
    /* Overflowing Readout */

    /* From Material Design Lite
        shadow-key-umbra-opacity: 0.2;
        shadow-key-penumbra-opacity: 0.14;
        shadow-ambient-shadow-opacity: 0.12;
     */
    -webkit-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .2),
                        0 3px 1px -2px rgba(0, 0, 0, .14),
                        0 1px 5px 0 rgba(0, 0, 0, .12);

    box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .2),
                0 3px 1px -2px rgba(0, 0, 0, .14),
                0 1px 5px 0 rgba(0, 0, 0, .12);
}

 .widget-inline-hbox .widget-readout {
    /* Horizontal Readout */
    text-align: center;
    max-width: 148px;
    min-width: 72px;
    margin-left: 4px;
}

 .widget-inline-vbox .widget-readout {
    /* Vertical Readout */
    margin-top: 4px;
    /* as wide as the widget */
    width: inherit;
}

 /* Widget Checkbox Styling */

 .widget-checkbox {
    width: 300px;
    height: 28px;
    line-height: 28px;
}

 .widget-checkbox input[type="checkbox"] {
    margin: 0px 8px 0px 0px;
    line-height: 28px;
    font-size: large;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 0;
        flex-shrink: 0;
    -ms-flex-item-align: center;
        align-self: center;
}

 /* Widget Valid Styling */

 .widget-valid {
    height: 28px;
    line-height: 28px;
    width: 148px;
    font-size: 13px;
}

 .widget-valid i:before {
    line-height: 28px;
    margin-right: 4px;
    margin-left: 4px;

    /* from the fa class in FontAwesome: https://github.com/FortAwesome/Font-Awesome/blob/49100c7c3a7b58d50baa71efef11af41a66b03d3/css/font-awesome.css#L14 */
    display: inline-block;
    font: normal normal normal 14px/1 FontAwesome;
    font-size: inherit;
    text-rendering: auto;
    -webkit-font-smoothing: antialiased;
    -moz-osx-font-smoothing: grayscale;
}

 .widget-valid.mod-valid i:before {
    content: "\f00c";
    color: green;
}

 .widget-valid.mod-invalid i:before {
    content: "\f00d";
    color: red;
}

 .widget-valid.mod-valid .widget-valid-readout {
    display: none;
}

 /* Widget Text and TextArea Stying */

 .widget-textarea, .widget-text {
    width: 300px;
}

 .widget-text input[type="text"], .widget-text input[type="number"]{
    height: 28px;
    line-height: 28px;
}

 .widget-text input[type="text"]:disabled, .widget-text input[type="number"]:disabled, .widget-textarea textarea:disabled {
    opacity: 0.6;
}

 .widget-text input[type="text"], .widget-text input[type="number"], .widget-textarea textarea {
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    border: 1px solid #9E9E9E;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    padding: 4px 8px;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    -ms-flex-negative: 1;
        flex-shrink: 1;
    outline: none !important;
}

 .widget-textarea textarea {
    height: inherit;
    width: inherit;
}

 .widget-text input:focus, .widget-textarea textarea:focus {
    border-color: #64B5F6;
}

 /* Widget Slider */

 .widget-slider .ui-slider {
    /* Slider Track */
    border: 1px solid #BDBDBD;
    background: #BDBDBD;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    position: relative;
    border-radius: 0px;
}

 .widget-slider .ui-slider .ui-slider-handle {
    /* Slider Handle */
    outline: none !important; /* focused slider handles are colored - see below */
    position: absolute;
    background-color: white;
    border: 1px solid #9E9E9E;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    z-index: 1;
    background-image: none; /* Override jquery-ui */
}

 /* Override jquery-ui */

 .widget-slider .ui-slider .ui-slider-handle:hover, .widget-slider .ui-slider .ui-slider-handle:focus {
    background-color: #2196F3;
    border: 1px solid #2196F3;
}

 .widget-slider .ui-slider .ui-slider-handle:active {
    background-color: #2196F3;
    border-color: #2196F3;
    z-index: 2;
    -webkit-transform: scale(1.2);
            transform: scale(1.2);
}

 .widget-slider  .ui-slider .ui-slider-range {
    /* Interval between the two specified value of a double slider */
    position: absolute;
    background: #2196F3;
    z-index: 0;
}

 /* Shapes of Slider Handles */

 .widget-hslider .ui-slider .ui-slider-handle {
    width: 16px;
    height: 16px;
    margin-top: -7px;
    margin-left: -7px;
    border-radius: 50%;
    top: 0;
}

 .widget-vslider .ui-slider .ui-slider-handle {
    width: 16px;
    height: 16px;
    margin-bottom: -7px;
    margin-left: -7px;
    border-radius: 50%;
    left: 0;
}

 .widget-hslider .ui-slider .ui-slider-range {
    height: 8px;
    margin-top: -3px;
}

 .widget-vslider .ui-slider .ui-slider-range {
    width: 8px;
    margin-left: -3px;
}

 /* Horizontal Slider */

 .widget-hslider {
    width: 300px;
    height: 28px;
    line-height: 28px;

    /* Override the align-items baseline. This way, the description and readout
    still seem to align their baseline properly, and we don't have to have
    align-self: stretch in the .slider-container. */
    -webkit-box-align: center;
        -ms-flex-align: center;
            align-items: center;
}

 .widgets-slider .slider-container {
    overflow: visible;
}

 .widget-hslider .slider-container {
    height: 28px;
    margin-left: 6px;
    margin-right: 6px;
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
}

 .widget-hslider .ui-slider {
    /* Inner, invisible slide div */
    height: 4px;
    margin-top: 12px;
    width: 100%;
}

 /* Vertical Slider */

 .widget-vbox .widget-label {
    height: 28px;
    line-height: 28px;
}

 .widget-vslider {
    /* Vertical Slider */
    height: 200px;
    width: 72px;
}

 .widget-vslider .slider-container {
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
    margin-left: auto;
    margin-right: auto;
    margin-bottom: 6px;
    margin-top: 6px;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
}

 .widget-vslider .ui-slider-vertical {
    /* Inner, invisible slide div */
    width: 4px;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    margin-left: auto;
    margin-right: auto;
}

 /* Widget Progress Styling */

 .progress-bar {
    -webkit-transition: none;
    transition: none;
}

 .progress-bar {
    height: 28px;
}

 .progress-bar {
    background-color: #2196F3;
}

 .progress-bar-success {
    background-color: #4CAF50;
}

 .progress-bar-info {
    background-color: #00BCD4;
}

 .progress-bar-warning {
    background-color: #FF9800;
}

 .progress-bar-danger {
    background-color: #F44336;
}

 .progress {
    background-color: #EEEEEE;
    border: none;
    -webkit-box-shadow: none;
            box-shadow: none;
}

 /* Horisontal Progress */

 .widget-hprogress {
    /* Progress Bar */
    height: 28px;
    line-height: 28px;
    width: 300px;
    -webkit-box-align: center;
        -ms-flex-align: center;
            align-items: center;

}

 .widget-hprogress .progress {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    margin-top: 4px;
    margin-bottom: 4px;
    -ms-flex-item-align: stretch;
        align-self: stretch;
    /* Override bootstrap style */
    height: auto;
    height: initial;
}

 /* Vertical Progress */

 .widget-vprogress {
    height: 200px;
    width: 72px;
}

 .widget-vprogress .progress {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    width: 20px;
    margin-left: auto;
    margin-right: auto;
    margin-bottom: 0;
}

 /* Select Widget Styling */

 .widget-dropdown {
    height: 28px;
    width: 300px;
    line-height: 28px;
}

 .widget-dropdown > select {
    padding-right: 20px;
    border: 1px solid #9E9E9E;
    border-radius: 0;
    height: inherit;
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    outline: none !important;
    -webkit-box-shadow: none;
            box-shadow: none;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    vertical-align: top;
    padding-left: 8px;
	appearance: none;
	-webkit-appearance: none;
	-moz-appearance: none;
    background-repeat: no-repeat;
	background-size: 20px;
	background-position: right center;
    background-image: url("data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz4KPCEtLSBHZW5lcmF0b3I6IEFkb2JlIElsbHVzdHJhdG9yIDE5LjIuMSwgU1ZHIEV4cG9ydCBQbHVnLUluIC4gU1ZHIFZlcnNpb246IDYuMDAgQnVpbGQgMCkgIC0tPgo8c3ZnIHZlcnNpb249IjEuMSIgaWQ9IkxheWVyXzEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgeG1sbnM6eGxpbms9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkveGxpbmsiIHg9IjBweCIgeT0iMHB4IgoJIHZpZXdCb3g9IjAgMCAxOCAxOCIgc3R5bGU9ImVuYWJsZS1iYWNrZ3JvdW5kOm5ldyAwIDAgMTggMTg7IiB4bWw6c3BhY2U9InByZXNlcnZlIj4KPHN0eWxlIHR5cGU9InRleHQvY3NzIj4KCS5zdDB7ZmlsbDpub25lO30KPC9zdHlsZT4KPHBhdGggZD0iTTUuMiw1LjlMOSw5LjdsMy44LTMuOGwxLjIsMS4ybC00LjksNWwtNC45LTVMNS4yLDUuOXoiLz4KPHBhdGggY2xhc3M9InN0MCIgZD0iTTAtMC42aDE4djE4SDBWLTAuNnoiLz4KPC9zdmc+Cg");
}

 .widget-dropdown > select:focus {
    border-color: #64B5F6;
}

 .widget-dropdown > select:disabled {
    opacity: 0.6;
}

 /* To disable the dotted border in Firefox around select controls.
   See http://stackoverflow.com/a/18853002 */

 .widget-dropdown > select:-moz-focusring {
    color: transparent;
    text-shadow: 0 0 0 #000;
}

 /* Select and SelectMultiple */

 .widget-select {
    width: 300px;
    line-height: 28px;

    /* Because Firefox defines the baseline of a select as the bottom of the
    control, we align the entire control to the top and add padding to the
    select to get an approximate first line baseline alignment. */
    -webkit-box-align: start;
        -ms-flex-align: start;
            align-items: flex-start;
}

 .widget-select > select {
    border: 1px solid #9E9E9E;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
    outline: none !important;
    overflow: auto;
    height: inherit;

    /* Because Firefox defines the baseline of a select as the bottom of the
    control, we align the entire control to the top and add padding to the
    select to get an approximate first line baseline alignment. */
    padding-top: 5px;
}

 .widget-select > select:focus {
    border-color: #64B5F6;
}

 .wiget-select > select > option {
    padding-left: 4px;
    line-height: 28px;
    /* line-height doesn't work on some browsers for select options */
    padding-top: calc(28px - var(--jp-widgets-font-size) / 2);
    padding-bottom: calc(28px - var(--jp-widgets-font-size) / 2);
}

 /* Toggle Buttons Styling */

 .widget-toggle-buttons {
    line-height: 28px;
}

 .widget-toggle-buttons .widget-toggle-button {
    margin-left: 2px;
    margin-right: 2px;
}

 .widget-toggle-buttons .jupyter-button:disabled {
    opacity: 0.6;
}

 /* Radio Buttons Styling */

 .widget-radio {
    width: 300px;
    line-height: 28px;
}

 .widget-radio-box {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    margin-bottom: 8px;
}

 .widget-radio-box label {
    height: 20px;
    line-height: 20px;
    font-size: 13px;
}

 .widget-radio-box input {
    height: 20px;
    line-height: 20px;
    margin: 0 8px 0 1px;
    float: left;
}

 /* Color Picker Styling */

 .widget-colorpicker {
    width: 300px;
    height: 28px;
    line-height: 28px;
}

 .widget-colorpicker > .widget-colorpicker-input {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 1;
        flex-shrink: 1;
    min-width: 72px;
}

 .widget-colorpicker input[type="color"] {
    width: 28px;
    height: 28px;
    padding: 0 2px; /* make the color square actually square on Chrome on OS X */
    background: white;
    color: rgba(0, 0, 0, .8);
    border: 1px solid #9E9E9E;
    border-left: none;
    -webkit-box-flex: 0;
        -ms-flex-positive: 0;
            flex-grow: 0;
    -ms-flex-negative: 0;
        flex-shrink: 0;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    -ms-flex-item-align: stretch;
        align-self: stretch;
    outline: none !important;
}

 .widget-colorpicker.concise input[type="color"] {
    border-left: 1px solid #9E9E9E;
}

 .widget-colorpicker input[type="color"]:focus, .widget-colorpicker input[type="text"]:focus {
    border-color: #64B5F6;
}

 .widget-colorpicker input[type="text"] {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    outline: none !important;
    height: 28px;
    line-height: 28px;
    background: white;
    color: rgba(0, 0, 0, .8);
    border: 1px solid #9E9E9E;
    font-size: 13px;
    padding: 4px 8px;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    -ms-flex-negative: 1;
        flex-shrink: 1;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
}

 .widget-colorpicker input[type="text"]:disabled {
    opacity: 0.6;
}

 /* Date Picker Styling */

 .widget-datepicker {
    width: 300px;
    height: 28px;
    line-height: 28px;
}

 .widget-datepicker input[type="date"] {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 1;
        flex-shrink: 1;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    outline: none !important;
    height: 28px;
    border: 1px solid #9E9E9E;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    padding: 4px 8px;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
}

 .widget-datepicker input[type="date"]:focus {
    border-color: #64B5F6;
}

 .widget-datepicker input[type="date"]:invalid {
    border-color: #FF9800;
}

 .widget-datepicker input[type="date"]:disabled {
    opacity: 0.6;
}

 /* Play Widget */

 .widget-play {
    width: 148px;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
}

 .widget-play .jupyter-button {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    height: auto;
}

 .widget-play .jupyter-button:disabled {
    opacity: 0.6;
}

 /* Tab Widget */

 .jupyter-widgets.widget-tab {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
}

 .jupyter-widgets.widget-tab > .p-TabBar {
    /* Necessary so that a tab can be shifted down to overlay the border of the box below. */
    overflow-x: visible;
    overflow-y: visible;
}

 .jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {
    /* Make sure that the tab grows from bottom up */
    -webkit-box-align: end;
        -ms-flex-align: end;
            align-items: flex-end;
    min-width: 0;
    min-height: 0;
}

 .jupyter-widgets.widget-tab > .widget-tab-contents {
    width: 100%;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    margin: 0;
    background: white;
    color: rgba(0, 0, 0, .8);
    border: 1px solid #9E9E9E;
    padding: 15px;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    overflow: auto;
}

 .jupyter-widgets.widget-tab > .p-TabBar {
    font: 13px Helvetica, Arial, sans-serif;
    min-height: 25px;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {
    -webkit-box-flex: 0;
        -ms-flex: 0 1 144px;
            flex: 0 1 144px;
    min-width: 35px;
    min-height: 25px;
    line-height: 24px;
    margin-left: -1px;
    padding: 0px 10px;
    background: #EEEEEE;
    color: rgba(0, 0, 0, .5);
    border: 1px solid #9E9E9E;
    border-bottom: none;
    position: relative;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current {
    color: rgba(0, 0, 0, 1.0);
    /* We want the background to match the tab content background */
    background: white;
    min-height: 26px;
    -webkit-transform: translateY(1px);
            transform: translateY(1px);
    overflow: visible;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current:before {
    position: absolute;
    top: -1px;
    left: -1px;
    content: '';
    height: 2px;
    width: calc(100% + 2px);
    background: #2196F3;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:first-child {
    margin-left: 0;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:hover:not(.p-mod-current) {
    background: white;
    color: rgba(0, 0, 0, .8);
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon {
    margin-left: 4px;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon:before {
    font-family: FontAwesome;
    content: '\f00d'; /* close */
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,
.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel,
.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {
    line-height: 24px;
}

 /* Accordion Widget */

 .p-Collapse {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
}

 .p-Collapse-header {
    padding: 4px;
    cursor: pointer;
    color: rgba(0, 0, 0, .5);
    background-color: #EEEEEE;
    border: 1px solid #9E9E9E;
    padding: 10px 15px;
    font-weight: bold;
}

 .p-Collapse-header:hover {
    background-color: white;
    color: rgba(0, 0, 0, .8);
}

 .p-Collapse-open > .p-Collapse-header {
    background-color: white;
    color: rgba(0, 0, 0, 1.0);
    cursor: default;
    border-bottom: none;
}

 .p-Collapse .p-Collapse-header::before {
    content: '\f0da\00A0';  /* caret-right, non-breaking space */
    display: inline-block;
    font: normal normal normal 14px/1 FontAwesome;
    font-size: inherit;
    text-rendering: auto;
    -webkit-font-smoothing: antialiased;
    -moz-osx-font-smoothing: grayscale;
}

 .p-Collapse-open > .p-Collapse-header::before {
    content: '\f0d7\00A0'; /* caret-down, non-breaking space */
}

 .p-Collapse-contents {
    padding: 15px;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    border-left: 1px solid #9E9E9E;
    border-right: 1px solid #9E9E9E;
    border-bottom: 1px solid #9E9E9E;
    overflow: auto;
}

 .p-Accordion {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
}

 .p-Accordion .p-Collapse {
    margin-bottom: 0;
}

 .p-Accordion .p-Collapse + .p-Collapse {
    margin-top: 4px;
}

 /* HTML widget */

 .widget-html, .widget-htmlmath {
    font-size: 13px;
}

 .widget-html > .widget-html-content, .widget-htmlmath > .widget-html-content {
    /* Fill out the area in the HTML widget */
    -ms-flex-item-align: stretch;
        align-self: stretch;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 1;
        flex-shrink: 1;
    /* Makes sure the baseline is still aligned with other elements */
    line-height: 28px;
    /* Make it possible to have absolutely-positioned elements in the html */
    position: relative;
}

/*# sourceMappingURL=data:application/json;base64,{"version":3,"sources":["../node_modules/@jupyter-widgets/controls/css/widgets.css","../node_modules/@jupyter-widgets/controls/css/labvariables.css","../node_modules/@jupyter-widgets/controls/css/materialcolors.css","../node_modules/@jupyter-widgets/controls/css/widgets-base.css","../node_modules/@jupyter-widgets/controls/css/phosphor.css"],"names":[],"mappings":"AAAA;;GAEG;;CAEF;;kCAEiC;;CCNlC;;;+EAG+E;;CAE/E;;;;EAIE;;CCTF;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;;CDhBH;;;;;;;;;;;;;;;;;;;EAmBE;;CAGF;;GAEG;;CACF,yDAAyD;;CAC1D,yEAAyE;;CAEzE;;GAEG;;CAOH;;EAEE;;;KAGG;;EAQH;;;;IAIE,CAIwB,oBAAoB,CAGhB,0CAA0C;;EAGxE;;IAEE;;EAOF;;KAEG;;EAOH;;;IAGE,CAWwB,oBAAoB;;;EAU9C;;;;IAIE;;EAOF,kBAAkB;;EAYlB,+CAA+C;;EAsB/C,0BAA0B;EAa1B;4EAC0E;EAE1E;wEACsE;;EAGtE,8BAA8B;;EAK9B,6BAA6B;;EAI7B,6BAA6B;CAQ9B;;CEzMD;;GAEG;;CAEH;;;;GAIG;;CCRH;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;EA8BE;;CAEF;;;GAGG;;CAEH;EACE,qBAAc;EAAd,qBAAc;EAAd,cAAc;EACd,0BAA0B;EAC1B,uBAAuB;EACvB,sBAAsB;EACtB,kBAAkB;CACnB;;CAGD;EACE,+BAAoB;EAApB,8BAAoB;MAApB,wBAAoB;UAApB,oBAAoB;CACrB;;CAGD;EACE,6BAAuB;EAAvB,8BAAuB;MAAvB,2BAAuB;UAAvB,uBAAuB;CACxB;;CAGD;EACE,UAAU;EACV,WAAW;EACX,qBAAc;EAAd,qBAAc;EAAd,cAAc;EACd,oBAAe;MAAf,mBAAe;UAAf,eAAe;EACf,sBAAsB;CACvB;;CAGD;EACE,+BAAoB;EAApB,8BAAoB;MAApB,wBAAoB;UAApB,oBAAoB;CACrB;;CAGD;EACE,6BAAuB;EAAvB,8BAAuB;MAAvB,2BAAuB;UAAvB,uBAAuB;CACxB;;CAGD;EACE,qBAAc;EAAd,qBAAc;EAAd,cAAc;EACd,+BAAoB;EAApB,8BAAoB;MAApB,wBAAoB;UAApB,oBAAoB;EACpB,+BAAuB;UAAvB,uBAAuB;EACvB,iBAAiB;CAClB;;CAGD;;EAEE,oBAAe;MAAf,mBAAe;UAAf,eAAe;CAChB;;CAGD;EACE,oBAAe;MAAf,mBAAe;UAAf,eAAe;EACf,iBAAiB;EACjB,oBAAoB;CACrB;;CAGD;EACE,yBAAyB;CAC1B;;CAGD;EACE,mBAAmB;CACpB;;CAGD;EACE,QAAQ;EACR,oCAA4B;EAA5B,4BAA4B;CAC7B;;CAGD;EACE,OAAO;EACP,mCAA2B;EAA3B,2BAA2B;CAC5B;;CAGD;EACE,yBAAiB;EAAjB,iBAAiB;CAClB;;CAED,oBAAoB;;CD9GpB,QAUqC,oCAAoC;;IA2BrE,+BAA+B;CAIlC;;CAED;IACI,YAAiC;IACjC,+BAAuB;YAAvB,uBAAuB;IACvB,aAA+B;IAC/B,kBAAkB;CACrB;;CAED;IACI,kBAA6C;IAC7C,aAAwC;CAC3C;;CAED;IACI,eAAe;IACf,gBAAgB;CACnB;;CAED,mBAAmB;;CAEnB;IACI,wBAAwB;IACxB,+BAAuB;YAAvB,uBAAuB;IACvB,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,+BAAoB;IAApB,8BAAoB;QAApB,wBAAoB;YAApB,oBAAoB;IACpB,4BAAsB;QAAtB,yBAAsB;YAAtB,sBAAsB;CACzB;;CAED;IACI,sBAAsB;IACtB,+BAAuB;YAAvB,uBAAuB;IACvB,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,0BAAoB;QAApB,uBAAoB;YAApB,oBAAoB;CACvB;;CAED;IACI,+BAAuB;YAAvB,uBAAuB;IACvB,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,UAAU;IACV,eAAe;CAClB;;CAED;IACI,+BAAuB;YAAvB,uBAAuB;IACvB,cAAc;IACd,UAAU;IACV,eAAe;CAClB;;CAED;IACI,+BAAoB;IAApB,8BAAoB;QAApB,wBAAoB;YAApB,oBAAoB;CACvB;;CAED;IACI,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;CAC1B;;CAED,4BAA4B;;CAE5B;IACI,mBAAmB;IACnB,oBAAoB;IACpB,iBAAiB;IACjB,oBAAoB;IACpB,sBAAsB;IACtB,oBAAoB;IACpB,iBAAiB;IACjB,wBAAwB;IACxB,mBAAmB;IACnB,gBAAuC;IACvC,gBAAgB;;IAEhB,aAAwC;IACxC,kBAAkB;IAClB,kBAA6C;IAC7C,yBAAiB;YAAjB,iBAAiB;;IAEjB,yBAAgC;IAChC,0BAA0C;IAC1C,sBAAsC;IACtC,aAAa;CAChB;;CAED;IACI,kBAA8C;IAC9C,qBAAqB;CACxB;;CAED;IACI,iBAAiB,CAAC,sBAAsB;CAC3C;;CAED;IACI,aAA4C;CAC/C;;CAED;IACI,gBAAgB;CACnB;;CAED;IACI,wBAAwB;IACxB;;+CAE+E;YAF/E;;+CAE+E;CAClF;;CAED;IACI,wBAAwB;IACxB;;iDAE6E;YAF7E;;iDAE6E;IAC7E,yBAAgC;IAChC,0BAA0C;CAC7C;;CAED;IACI,2BAA8D;CACjE;;CAED,8BAA8B;;CAE9B;IACI,gCAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED,8BAA8B;;CAE9B;IACI,gCAAwC;IACxC,0BAA2C;CAC9C;;CAED;IACI,8BAAwC;IACxC,0BAA2C;EAC7C;;CAEF;IACI,8BAAwC;IACxC,0BAA2C;EAC7C;;CAED,2BAA2B;;CAE5B;IACI,gCAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED,8BAA8B;;CAE9B;IACI,gCAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED,6BAA6B;;CAE7B;IACI,gCAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED,kBAAkB;;CAElB;IACI,aAA4C;CAC/C;;CAED,0BAA0B;;CAE1B,kCAAkC;;CAClC;IACI,iBAAuB;IAAvB,uBAAuB;CAC1B;;CAED;IACI,iBAAiB;IACjB,aAAqC;IACrC,gBAAuC;IACvC,iBAAiB;IACjB,wBAAwB;IACxB,oBAAoB;IACpB,kBAA6C;CAChD;;CAED;IACI,WAAW;IACX,aAAqC;IACrC,gBAAuC;IACvC,iBAAiB;IACjB,wBAAwB;IACxB,oBAAoB;IACpB,kBAA6C;CAChD;;CAED;IACI,6BAA6B;IAC7B,aAAqC;IACrC,kBAAkB;IAClB,kBAA0D;IAC1D,YAA4C;IAC5C,qBAAe;QAAf,eAAe;CAClB;;CAED;IACI,2BAA2B;IAC3B,aAAqC;IACrC,mBAAmB;IACnB,kBAA6C;CAChD;;CAED,4BAA4B;;CAE5B;IACI,aAAuC;IACvC,gBAAuC;IACvC,aAAwC;IACxC,kBAA6C;IAC7C,iBAAiB;IACjB,oBAAoB;IACpB,mBAAmB;CACtB;;CAED;IACI,yBAAyB;;IAEzB;;;;OAIG;IACH;;uDAEoD;;IAMpD;;+CAE4C;CAC/C;;CAED;IACI,wBAAwB;IACxB,mBAAmB;IACnB,iBAAgD;IAChD,gBAA+C;IAC/C,iBAA6C;CAChD;;CAED;IACI,sBAAsB;IACtB,gBAA4C;IAC5C,2BAA2B;IAC3B,eAAe;CAClB;;CAED,6BAA6B;;CAE7B;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,wBAAgE;IAChE,kBAA6C;IAC7C,iBAAiB;IACjB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,4BAAmB;QAAnB,mBAAmB;CACtB;;CAED,0BAA0B;;CAE1B;IACI,aAAwC;IACxC,kBAA6C;IAC7C,aAA4C;IAC5C,gBAAuC;CAC1C;;CAED;IACI,kBAA6C;IAC7C,kBAA8C;IAC9C,iBAA6C;;IAE7C,0JAA0J;IAC1J,sBAAsB;IACtB,8CAA8C;IAC9C,mBAAmB;IACnB,qBAAqB;IACrB,oCAAoC;IACpC,mCAAmC;CACtC;;CAED;IACI,iBAAiB;IACjB,aAAa;CAChB;;CAED;IACI,iBAAiB;IACjB,WAAW;CACd;;CAED;IACI,cAAc;CACjB;;CAED,qCAAqC;;CAErC;IACI,aAAsC;CACzC;;CAED;IACI,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,aAA4C;CAC/C;;CAED;IACI,+BAAuB;YAAvB,uBAAuB;IACvB,0BAAwF;IACxF,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,iBAAsF;IACtF,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,aAAa,CAAC,iEAAiE;IAC/E,qBAAe;QAAf,eAAe;IACf,yBAAyB;CAC5B;;CAED;IACI,gBAAgB;IAChB,eAAe;CAClB;;CAED;IACI,sBAAyD;CAC5D;;CAED,mBAAmB;;CAEnB;IACI,kBAAkB;IAClB,0BAA4E;IAC5E,oBAAoC;IACpC,+BAAuB;YAAvB,uBAAuB;IACvB,mBAAmB;IACnB,mBAAmB;CACtB;;CAED;IACI,mBAAmB;IACnB,yBAAyB,CAAC,oDAAoD;IAC9E,mBAAmB;IACnB,wBAAmE;IACnE,0BAAiG;IACjG,+BAAuB;YAAvB,uBAAuB;IACvB,WAAW;IACX,uBAAuB,CAAC,wBAAwB;CACnD;;CAED,wBAAwB;;CACxB;IACI,0BAA+D;IAC/D,0BAAiG;CACpG;;CAED;IACI,0BAA+D;IAC/D,sBAA2D;IAC3D,WAAW;IACX,8BAAsB;YAAtB,sBAAsB;CACzB;;CAED;IACI,iEAAiE;IACjE,mBAAmB;IACnB,oBAAyD;IACzD,WAAW;CACd;;CAED,8BAA8B;;CAE9B;IACI,YAA4C;IAC5C,aAA6C;IAC7C,iBAAgJ;IAChJ,kBAAqG;IACrG,mBAAmB;IACnB,OAAO;CACV;;CAED;IACI,YAA4C;IAC5C,aAA6C;IAC7C,oBAAuG;IACvG,kBAAiJ;IACjJ,mBAAmB;IACnB,QAAQ;CACX;;CAED;IACI,YAA6D;IAC7D,iBAAyJ;CAC5J;;CAED;IACI,WAA4D;IAC5D,kBAA0J;CAC7J;;CAED,uBAAuB;;CAEvB;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;;IAE7C;;oDAEgD;IAChD,0BAAoB;QAApB,uBAAoB;YAApB,oBAAoB;CACvB;;CAED;IACI,kBAAkB;CACrB;;CAED;IACI,aAAwC;IACxC,iBAAwG;IACxG,kBAAyG;IACzG,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;CAClD;;CAED;IACI,gCAAgC;IAChC,YAAiD;IACjD,iBAAmG;IACnG,YAAY;CACf;;CAED,qBAAqB;;CAErB;IACI,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,qBAAqB;IACrB,cAA0C;IAC1C,YAA2C;CAC9C;;CAED;IACI,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;IAC/C,kBAAkB;IAClB,mBAAmB;IACnB,mBAA0G;IAC1G,gBAAuG;IACvG,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,gCAAgC;IAChC,WAAgD;IAChD,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,kBAAkB;IAClB,mBAAmB;CACtB;;CAED,6BAA6B;;CAE7B;IACI,yBAAyB;IAIzB,iBAAiB;CACpB;;CAED;IACI,aAAwC;CAC3C;;CAED;IACI,0BAAyC;CAC5C;;CAED;IACI,0BAA2C;CAC9C;;CAED;IACI,0BAAwC;CAC3C;;CAED;IACI,0BAAwC;CAC3C;;CAED;IACI,0BAAyC;CAC5C;;CAED;IACI,0BAA0C;IAC1C,aAAa;IACb,yBAAiB;YAAjB,iBAAiB;CACpB;;CAED,yBAAyB;;CAEzB;IACI,kBAAkB;IAClB,aAAwC;IACxC,kBAA6C;IAC7C,aAAsC;IACtC,0BAAoB;QAApB,uBAAoB;YAApB,oBAAoB;;CAEvB;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,gBAA4C;IAC5C,mBAA+C;IAC/C,6BAAoB;QAApB,oBAAoB;IACpB,8BAA8B;IAC9B,aAAgB;IAAhB,gBAAgB;CACnB;;CAED,uBAAuB;;CAEvB;IACI,cAA0C;IAC1C,YAA2C;CAC9C;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,YAA4C;IAC5C,kBAAkB;IAClB,mBAAmB;IACnB,iBAAiB;CACpB;;CAED,2BAA2B;;CAE3B;IACI,aAAwC;IACxC,aAAsC;IACtC,kBAA6C;CAChD;;CAED;IACI,oBAAoB;IACpB,0BAAwF;IACxF,iBAAiB;IACjB,gBAAgB;IAChB,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;IAC/C,aAAa,CAAC,iEAAiE;IAC/E,+BAAuB;YAAvB,uBAAuB;IACvB,yBAAyB;IACzB,yBAAiB;YAAjB,iBAAiB;IACjB,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,oBAAoB;IACpB,kBAAyD;CAC5D,iBAAiB;CACjB,yBAAyB;CACzB,sBAAsB;IACnB,6BAA6B;CAChC,sBAAsB;CACtB,kCAAkC;IAC/B,kuBAAmD;CACtD;;CACD;IACI,sBAAyD;CAC5D;;CAED;IACI,aAA4C;CAC/C;;CAED;6CAC6C;;CAC7C;IACI,mBAAmB;IACnB,wBAAwB;CAC3B;;CAED,+BAA+B;;CAE/B;IACI,aAAsC;IACtC,kBAA6C;;IAE7C;;kEAE8D;IAC9D,yBAAwB;QAAxB,sBAAwB;YAAxB,wBAAwB;CAC3B;;CAED;IACI,0BAAwF;IACxF,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;IAC/C,yBAAyB;IACzB,eAAe;IACf,gBAAgB;;IAEhB;;kEAE8D;IAC9D,iBAAiB;CACpB;;CAED;IACI,sBAAyD;CAC5D;;CAED;IACI,kBAA8C;IAC9C,kBAA6C;IAC7C,kEAAkE;IAClE,0DAAiF;IACjF,6DAAoF;CACvF;;CAID,4BAA4B;;CAE5B;IACI,kBAA6C;CAChD;;CAED;IACI,iBAAsC;IACtC,kBAAuC;CAC1C;;CAED;IACI,aAA4C;CAC/C;;CAED,2BAA2B;;CAE3B;IACI,aAAsC;IACtC,kBAA6C;CAChD;;CAED;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;IACrB,+BAAuB;YAAvB,uBAAuB;IACvB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,mBAA8D;CACjE;;CAED;IACI,aAA4C;IAC5C,kBAAiD;IACjD,gBAAuC;CAC1C;;CAED;IACI,aAA4C;IAC5C,kBAAiD;IACjD,oBAA4D;IAC5D,YAAY;CACf;;CAED,0BAA0B;;CAE1B;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,gBAA+C;CAClD;;CAED;IACI,YAAuC;IACvC,aAAwC;IACxC,eAAe,CAAC,6DAA6D;IAC7E,kBAAqD;IACrD,yBAAqC;IACrC,0BAAwF;IACxF,kBAAkB;IAClB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,+BAAuB;YAAvB,uBAAuB;IACvB,6BAAoB;QAApB,oBAAoB;IACpB,yBAAyB;CAC5B;;CAED;IACI,+BAA6F;CAChG;;CAED;IACI,sBAAyD;CAC5D;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,yBAAyB;IACzB,aAAwC;IACxC,kBAA6C;IAC7C,kBAAqD;IACrD,yBAAqC;IACrC,0BAAwF;IACxF,gBAAuC;IACvC,iBAAsF;IACtF,aAAa,CAAC,iEAAiE;IAC/E,qBAAe;QAAf,eAAe;IACf,+BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,aAA4C;CAC/C;;CAED,yBAAyB;;CAEzB;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,aAAa,CAAC,iEAAiE;IAC/E,yBAAyB;IACzB,aAAwC;IACxC,0BAAwF;IACxF,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,iBAAsF;IACtF,+BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,sBAAyD;CAC5D;;CAED;IACI,sBAAoC;CACvC;;CAED;IACI,aAA4C;CAC/C;;CAED,iBAAiB;;CAEjB;IACI,aAA4C;IAC5C,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;CACxB;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,aAAa;CAChB;;CAED;IACI,aAA4C;CAC/C;;CAED,gBAAgB;;CAEhB;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,yFAAyF;IACzF,oBAAoB;IACpB,oBAAoB;CACvB;;CAED;IACI,iDAAiD;IACjD,uBAAsB;QAAtB,oBAAsB;YAAtB,sBAAsB;IACtB,aAAa;IACb,cAAc;CACjB;;CAED;IACI,YAAY;IACZ,+BAAuB;YAAvB,uBAAuB;IACvB,UAAU;IACV,kBAAoC;IACpC,yBAAgC;IAChC,0BAA6D;IAC7D,cAA6C;IAC7C,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,eAAe;CAClB;;CAED;IACI,wCAA+D;IAC/D,iBAAmF;CACtF;;CAED;IACI,oBAAiD;QAAjD,oBAAiD;YAAjD,gBAAiD;IACjD,gBAAgB;IAChB,iBAAmF;IACnF,kBAAqD;IACrD,kBAA+C;IAC/C,kBAAkB;IAClB,oBAAoC;IACpC,yBAAgC;IAChC,0BAA6D;IAC7D,oBAAoB;IACpB,mBAAmB;CACtB;;CAED;IACI,0BAAgC;IAChC,gEAAgE;IAChE,kBAAoC;IACpC,iBAAuF;IACvF,mCAA8C;YAA9C,2BAA8C;IAC9C,kBAAkB;CACrB;;CAED;IACI,mBAAmB;IACnB,UAAuC;IACvC,WAAwC;IACxC,YAAY;IACZ,YAAoD;IACpD,wBAA+C;IAC/C,oBAAmC;CACtC;;CAED;IACI,eAAe;CAClB;;CAED;IACI,kBAAoC;IACpC,yBAAgC;CACnC;;CAED;IACI,iBAAiB;CACpB;;CAED;IACI,yBAAyB;IACzB,iBAAiB,CAAC,WAAW;CAChC;;CAED;;;IAGI,kBAAqD;CACxD;;CAED,sBAAsB;;CAEtB;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;CACxB;;CAED;IACI,aAAyC;IACzC,gBAAgB;IAChB,yBAAgC;IAChC,0BAA0C;IAC1C,0BAAqE;IACrE,mBAA+F;IAC/F,kBAAkB;CACrB;;CAED;IACI,wBAA0C;IAC1C,yBAAgC;CACnC;;CAED;IACI,wBAA0C;IAC1C,0BAAgC;IAChC,gBAAgB;IAChB,oBAAoB;CACvB;;CAED;IACI,sBAAsB,EAAE,qCAAqC;IAC7D,sBAAsB;IACtB,8CAA8C;IAC9C,mBAAmB;IACnB,qBAAqB;IACrB,oCAAoC;IACpC,mCAAmC;CACtC;;CAED;IACI,sBAAsB,CAAC,oCAAoC;CAC9D;;CAED;IACI,cAA6C;IAC7C,wBAA0C;IAC1C,yBAAgC;IAChC,+BAA0E;IAC1E,gCAA2E;IAC3E,iCAA4E;IAC5E,eAAe;CAClB;;CAED;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;CACxB;;CAED;IACI,iBAAiB;CACpB;;CAED;IACI,gBAAgB;CACnB;;CAID,iBAAiB;;CAEjB;IACI,gBAAuC;CAC1C;;CAED;IACI,0CAA0C;IAC1C,6BAAoB;QAApB,oBAAoB;IACpB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,kEAAkE;IAClE,kBAA6C;IAC7C,yEAAyE;IACzE,mBAAmB;CACtB","file":"controls.css","sourcesContent":["/* Copyright (c) Jupyter Development Team.\n * Distributed under the terms of the Modified BSD License.\n */\n\n /* We import all of these together in a single css file because the Webpack\nloader sees only one file at a time. This allows postcss to see the variable\ndefinitions when they are used. */\n\n@import \"./labvariables.css\";\n@import \"./widgets-base.css\";\n","/*-----------------------------------------------------------------------------\n| Copyright (c) Jupyter Development Team.\n| Distributed under the terms of the Modified BSD License.\n|----------------------------------------------------------------------------*/\n\n/*\nThis file is copied from the JupyterLab project to define default styling for\nwhen the widget styling is compiled down to eliminate CSS variables. We make one\nchange - we comment out the font import below.\n*/\n\n@import \"./materialcolors.css\";\n\n/*\nThe following CSS variables define the main, public API for styling JupyterLab.\nThese variables should be used by all plugins wherever possible. In other\nwords, plugins should not define custom colors, sizes, etc unless absolutely\nnecessary. This enables users to change the visual theme of JupyterLab\nby changing these variables.\n\nMany variables appear in an ordered sequence (0,1,2,3). These sequences\nare designed to work well together, so for example, `--jp-border-color1` should\nbe used with `--jp-layout-color1`. The numbers have the following meanings:\n\n* 0: super-primary, reserved for special emphasis\n* 1: primary, most important under normal situations\n* 2: secondary, next most important under normal situations\n* 3: tertiary, next most important under normal situations\n\nThroughout JupyterLab, we are mostly following principles from Google's\nMaterial Design when selecting colors. We are not, however, following\nall of MD as it is not optimized for dense, information rich UIs.\n*/\n\n\n/*\n * Optional monospace font for input/output prompt.\n */\n /* Commented out in ipywidgets since we don't need it. */\n/* @import url('https://fonts.googleapis.com/css?family=Roboto+Mono'); */\n\n/*\n * Added for compabitility with output area\n */\n:root {\n  --jp-icon-search: none;\n  --jp-ui-select-caret: none;\n}\n\n\n:root {\n\n  /* Borders\n\n  The following variables, specify the visual styling of borders in JupyterLab.\n   */\n\n  --jp-border-width: 1px;\n  --jp-border-color0: var(--md-grey-700);\n  --jp-border-color1: var(--md-grey-500);\n  --jp-border-color2: var(--md-grey-300);\n  --jp-border-color3: var(--md-grey-100);\n\n  /* UI Fonts\n\n  The UI font CSS variables are used for the typography all of the JupyterLab\n  user interface elements that are not directly user generated content.\n  */\n\n  --jp-ui-font-scale-factor: 1.2;\n  --jp-ui-font-size0: calc(var(--jp-ui-font-size1)/var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size1: 13px; /* Base font size */\n  --jp-ui-font-size2: calc(var(--jp-ui-font-size1)*var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size3: calc(var(--jp-ui-font-size2)*var(--jp-ui-font-scale-factor));\n  --jp-ui-icon-font-size: 14px; /* Ensures px perfect FontAwesome icons */\n  --jp-ui-font-family: \"Helvetica Neue\", Helvetica, Arial, sans-serif;\n\n  /* Use these font colors against the corresponding main layout colors.\n     In a light theme, these go from dark to light.\n  */\n\n  --jp-ui-font-color0: rgba(0,0,0,1.0);\n  --jp-ui-font-color1: rgba(0,0,0,0.8);\n  --jp-ui-font-color2: rgba(0,0,0,0.5);\n  --jp-ui-font-color3: rgba(0,0,0,0.3);\n\n  /* Use these against the brand/accent/warn/error colors.\n     These will typically go from light to darker, in both a dark and light theme\n   */\n\n  --jp-inverse-ui-font-color0: rgba(255,255,255,1);\n  --jp-inverse-ui-font-color1: rgba(255,255,255,1.0);\n  --jp-inverse-ui-font-color2: rgba(255,255,255,0.7);\n  --jp-inverse-ui-font-color3: rgba(255,255,255,0.5);\n\n  /* Content Fonts\n\n  Content font variables are used for typography of user generated content.\n  */\n\n  --jp-content-font-size: 13px;\n  --jp-content-line-height: 1.5;\n  --jp-content-font-color0: black;\n  --jp-content-font-color1: black;\n  --jp-content-font-color2: var(--md-grey-700);\n  --jp-content-font-color3: var(--md-grey-500);\n\n  --jp-ui-font-scale-factor: 1.2;\n  --jp-ui-font-size0: calc(var(--jp-ui-font-size1)/var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size1: 13px; /* Base font size */\n  --jp-ui-font-size2: calc(var(--jp-ui-font-size1)*var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size3: calc(var(--jp-ui-font-size2)*var(--jp-ui-font-scale-factor));\n\n  --jp-code-font-size: 13px;\n  --jp-code-line-height: 1.307;\n  --jp-code-padding: 5px;\n  --jp-code-font-family: monospace;\n\n\n  /* Layout\n\n  The following are the main layout colors use in JupyterLab. In a light\n  theme these would go from light to dark.\n  */\n\n  --jp-layout-color0: white;\n  --jp-layout-color1: white;\n  --jp-layout-color2: var(--md-grey-200);\n  --jp-layout-color3: var(--md-grey-400);\n\n  /* Brand/accent */\n\n  --jp-brand-color0: var(--md-blue-700);\n  --jp-brand-color1: var(--md-blue-500);\n  --jp-brand-color2: var(--md-blue-300);\n  --jp-brand-color3: var(--md-blue-100);\n\n  --jp-accent-color0: var(--md-green-700);\n  --jp-accent-color1: var(--md-green-500);\n  --jp-accent-color2: var(--md-green-300);\n  --jp-accent-color3: var(--md-green-100);\n\n  /* State colors (warn, error, success, info) */\n\n  --jp-warn-color0: var(--md-orange-700);\n  --jp-warn-color1: var(--md-orange-500);\n  --jp-warn-color2: var(--md-orange-300);\n  --jp-warn-color3: var(--md-orange-100);\n\n  --jp-error-color0: var(--md-red-700);\n  --jp-error-color1: var(--md-red-500);\n  --jp-error-color2: var(--md-red-300);\n  --jp-error-color3: var(--md-red-100);\n\n  --jp-success-color0: var(--md-green-700);\n  --jp-success-color1: var(--md-green-500);\n  --jp-success-color2: var(--md-green-300);\n  --jp-success-color3: var(--md-green-100);\n\n  --jp-info-color0: var(--md-cyan-700);\n  --jp-info-color1: var(--md-cyan-500);\n  --jp-info-color2: var(--md-cyan-300);\n  --jp-info-color3: var(--md-cyan-100);\n\n  /* Cell specific styles */\n\n  --jp-cell-padding: 5px;\n  --jp-cell-editor-background: #f7f7f7;\n  --jp-cell-editor-border-color: #cfcfcf;\n  --jp-cell-editor-background-edit: var(--jp-ui-layout-color1);\n  --jp-cell-editor-border-color-edit: var(--jp-brand-color1);\n  --jp-cell-prompt-width: 100px;\n  --jp-cell-prompt-font-family: 'Roboto Mono', monospace;\n  --jp-cell-prompt-letter-spacing: 0px;\n  --jp-cell-prompt-opacity: 1.0;\n  --jp-cell-prompt-opacity-not-active: 0.4;\n  --jp-cell-prompt-font-color-not-active: var(--md-grey-700);\n  /* A custom blend of MD grey and blue 600\n   * See https://meyerweb.com/eric/tools/color-blend/#546E7A:1E88E5:5:hex */\n  --jp-cell-inprompt-font-color: #307FC1;\n  /* A custom blend of MD grey and orange 600\n   * https://meyerweb.com/eric/tools/color-blend/#546E7A:F4511E:5:hex */\n  --jp-cell-outprompt-font-color: #BF5B3D;\n\n  /* Notebook specific styles */\n\n  --jp-notebook-padding: 10px;\n  --jp-notebook-scroll-padding: 100px;\n\n  /* Console specific styles */\n\n  --jp-console-background: var(--md-grey-100);\n\n  /* Toolbar specific styles */\n\n  --jp-toolbar-border-color: var(--md-grey-400);\n  --jp-toolbar-micro-height: 8px;\n  --jp-toolbar-background: var(--jp-layout-color0);\n  --jp-toolbar-box-shadow: 0px 0px 2px 0px rgba(0,0,0,0.24);\n  --jp-toolbar-header-margin: 4px 4px 0px 4px;\n  --jp-toolbar-active-background: var(--md-grey-300);\n}\n","/**\n * The material design colors are adapted from google-material-color v1.2.6\n * https://github.com/danlevan/google-material-color\n * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/dist/palette.var.css\n *\n * The license for the material design color CSS variables is as follows (see\n * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/LICENSE)\n *\n * The MIT License (MIT)\n *\n * Copyright (c) 2014 Dan Le Van\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to deal\n * in the Software without restriction, including without limitation the rights\n * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n * copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n * SOFTWARE.\n */\n:root {\n  --md-red-50: #FFEBEE;\n  --md-red-100: #FFCDD2;\n  --md-red-200: #EF9A9A;\n  --md-red-300: #E57373;\n  --md-red-400: #EF5350;\n  --md-red-500: #F44336;\n  --md-red-600: #E53935;\n  --md-red-700: #D32F2F;\n  --md-red-800: #C62828;\n  --md-red-900: #B71C1C;\n  --md-red-A100: #FF8A80;\n  --md-red-A200: #FF5252;\n  --md-red-A400: #FF1744;\n  --md-red-A700: #D50000;\n\n  --md-pink-50: #FCE4EC;\n  --md-pink-100: #F8BBD0;\n  --md-pink-200: #F48FB1;\n  --md-pink-300: #F06292;\n  --md-pink-400: #EC407A;\n  --md-pink-500: #E91E63;\n  --md-pink-600: #D81B60;\n  --md-pink-700: #C2185B;\n  --md-pink-800: #AD1457;\n  --md-pink-900: #880E4F;\n  --md-pink-A100: #FF80AB;\n  --md-pink-A200: #FF4081;\n  --md-pink-A400: #F50057;\n  --md-pink-A700: #C51162;\n\n  --md-purple-50: #F3E5F5;\n  --md-purple-100: #E1BEE7;\n  --md-purple-200: #CE93D8;\n  --md-purple-300: #BA68C8;\n  --md-purple-400: #AB47BC;\n  --md-purple-500: #9C27B0;\n  --md-purple-600: #8E24AA;\n  --md-purple-700: #7B1FA2;\n  --md-purple-800: #6A1B9A;\n  --md-purple-900: #4A148C;\n  --md-purple-A100: #EA80FC;\n  --md-purple-A200: #E040FB;\n  --md-purple-A400: #D500F9;\n  --md-purple-A700: #AA00FF;\n\n  --md-deep-purple-50: #EDE7F6;\n  --md-deep-purple-100: #D1C4E9;\n  --md-deep-purple-200: #B39DDB;\n  --md-deep-purple-300: #9575CD;\n  --md-deep-purple-400: #7E57C2;\n  --md-deep-purple-500: #673AB7;\n  --md-deep-purple-600: #5E35B1;\n  --md-deep-purple-700: #512DA8;\n  --md-deep-purple-800: #4527A0;\n  --md-deep-purple-900: #311B92;\n  --md-deep-purple-A100: #B388FF;\n  --md-deep-purple-A200: #7C4DFF;\n  --md-deep-purple-A400: #651FFF;\n  --md-deep-purple-A700: #6200EA;\n\n  --md-indigo-50: #E8EAF6;\n  --md-indigo-100: #C5CAE9;\n  --md-indigo-200: #9FA8DA;\n  --md-indigo-300: #7986CB;\n  --md-indigo-400: #5C6BC0;\n  --md-indigo-500: #3F51B5;\n  --md-indigo-600: #3949AB;\n  --md-indigo-700: #303F9F;\n  --md-indigo-800: #283593;\n  --md-indigo-900: #1A237E;\n  --md-indigo-A100: #8C9EFF;\n  --md-indigo-A200: #536DFE;\n  --md-indigo-A400: #3D5AFE;\n  --md-indigo-A700: #304FFE;\n\n  --md-blue-50: #E3F2FD;\n  --md-blue-100: #BBDEFB;\n  --md-blue-200: #90CAF9;\n  --md-blue-300: #64B5F6;\n  --md-blue-400: #42A5F5;\n  --md-blue-500: #2196F3;\n  --md-blue-600: #1E88E5;\n  --md-blue-700: #1976D2;\n  --md-blue-800: #1565C0;\n  --md-blue-900: #0D47A1;\n  --md-blue-A100: #82B1FF;\n  --md-blue-A200: #448AFF;\n  --md-blue-A400: #2979FF;\n  --md-blue-A700: #2962FF;\n\n  --md-light-blue-50: #E1F5FE;\n  --md-light-blue-100: #B3E5FC;\n  --md-light-blue-200: #81D4FA;\n  --md-light-blue-300: #4FC3F7;\n  --md-light-blue-400: #29B6F6;\n  --md-light-blue-500: #03A9F4;\n  --md-light-blue-600: #039BE5;\n  --md-light-blue-700: #0288D1;\n  --md-light-blue-800: #0277BD;\n  --md-light-blue-900: #01579B;\n  --md-light-blue-A100: #80D8FF;\n  --md-light-blue-A200: #40C4FF;\n  --md-light-blue-A400: #00B0FF;\n  --md-light-blue-A700: #0091EA;\n\n  --md-cyan-50: #E0F7FA;\n  --md-cyan-100: #B2EBF2;\n  --md-cyan-200: #80DEEA;\n  --md-cyan-300: #4DD0E1;\n  --md-cyan-400: #26C6DA;\n  --md-cyan-500: #00BCD4;\n  --md-cyan-600: #00ACC1;\n  --md-cyan-700: #0097A7;\n  --md-cyan-800: #00838F;\n  --md-cyan-900: #006064;\n  --md-cyan-A100: #84FFFF;\n  --md-cyan-A200: #18FFFF;\n  --md-cyan-A400: #00E5FF;\n  --md-cyan-A700: #00B8D4;\n\n  --md-teal-50: #E0F2F1;\n  --md-teal-100: #B2DFDB;\n  --md-teal-200: #80CBC4;\n  --md-teal-300: #4DB6AC;\n  --md-teal-400: #26A69A;\n  --md-teal-500: #009688;\n  --md-teal-600: #00897B;\n  --md-teal-700: #00796B;\n  --md-teal-800: #00695C;\n  --md-teal-900: #004D40;\n  --md-teal-A100: #A7FFEB;\n  --md-teal-A200: #64FFDA;\n  --md-teal-A400: #1DE9B6;\n  --md-teal-A700: #00BFA5;\n\n  --md-green-50: #E8F5E9;\n  --md-green-100: #C8E6C9;\n  --md-green-200: #A5D6A7;\n  --md-green-300: #81C784;\n  --md-green-400: #66BB6A;\n  --md-green-500: #4CAF50;\n  --md-green-600: #43A047;\n  --md-green-700: #388E3C;\n  --md-green-800: #2E7D32;\n  --md-green-900: #1B5E20;\n  --md-green-A100: #B9F6CA;\n  --md-green-A200: #69F0AE;\n  --md-green-A400: #00E676;\n  --md-green-A700: #00C853;\n\n  --md-light-green-50: #F1F8E9;\n  --md-light-green-100: #DCEDC8;\n  --md-light-green-200: #C5E1A5;\n  --md-light-green-300: #AED581;\n  --md-light-green-400: #9CCC65;\n  --md-light-green-500: #8BC34A;\n  --md-light-green-600: #7CB342;\n  --md-light-green-700: #689F38;\n  --md-light-green-800: #558B2F;\n  --md-light-green-900: #33691E;\n  --md-light-green-A100: #CCFF90;\n  --md-light-green-A200: #B2FF59;\n  --md-light-green-A400: #76FF03;\n  --md-light-green-A700: #64DD17;\n\n  --md-lime-50: #F9FBE7;\n  --md-lime-100: #F0F4C3;\n  --md-lime-200: #E6EE9C;\n  --md-lime-300: #DCE775;\n  --md-lime-400: #D4E157;\n  --md-lime-500: #CDDC39;\n  --md-lime-600: #C0CA33;\n  --md-lime-700: #AFB42B;\n  --md-lime-800: #9E9D24;\n  --md-lime-900: #827717;\n  --md-lime-A100: #F4FF81;\n  --md-lime-A200: #EEFF41;\n  --md-lime-A400: #C6FF00;\n  --md-lime-A700: #AEEA00;\n\n  --md-yellow-50: #FFFDE7;\n  --md-yellow-100: #FFF9C4;\n  --md-yellow-200: #FFF59D;\n  --md-yellow-300: #FFF176;\n  --md-yellow-400: #FFEE58;\n  --md-yellow-500: #FFEB3B;\n  --md-yellow-600: #FDD835;\n  --md-yellow-700: #FBC02D;\n  --md-yellow-800: #F9A825;\n  --md-yellow-900: #F57F17;\n  --md-yellow-A100: #FFFF8D;\n  --md-yellow-A200: #FFFF00;\n  --md-yellow-A400: #FFEA00;\n  --md-yellow-A700: #FFD600;\n\n  --md-amber-50: #FFF8E1;\n  --md-amber-100: #FFECB3;\n  --md-amber-200: #FFE082;\n  --md-amber-300: #FFD54F;\n  --md-amber-400: #FFCA28;\n  --md-amber-500: #FFC107;\n  --md-amber-600: #FFB300;\n  --md-amber-700: #FFA000;\n  --md-amber-800: #FF8F00;\n  --md-amber-900: #FF6F00;\n  --md-amber-A100: #FFE57F;\n  --md-amber-A200: #FFD740;\n  --md-amber-A400: #FFC400;\n  --md-amber-A700: #FFAB00;\n\n  --md-orange-50: #FFF3E0;\n  --md-orange-100: #FFE0B2;\n  --md-orange-200: #FFCC80;\n  --md-orange-300: #FFB74D;\n  --md-orange-400: #FFA726;\n  --md-orange-500: #FF9800;\n  --md-orange-600: #FB8C00;\n  --md-orange-700: #F57C00;\n  --md-orange-800: #EF6C00;\n  --md-orange-900: #E65100;\n  --md-orange-A100: #FFD180;\n  --md-orange-A200: #FFAB40;\n  --md-orange-A400: #FF9100;\n  --md-orange-A700: #FF6D00;\n\n  --md-deep-orange-50: #FBE9E7;\n  --md-deep-orange-100: #FFCCBC;\n  --md-deep-orange-200: #FFAB91;\n  --md-deep-orange-300: #FF8A65;\n  --md-deep-orange-400: #FF7043;\n  --md-deep-orange-500: #FF5722;\n  --md-deep-orange-600: #F4511E;\n  --md-deep-orange-700: #E64A19;\n  --md-deep-orange-800: #D84315;\n  --md-deep-orange-900: #BF360C;\n  --md-deep-orange-A100: #FF9E80;\n  --md-deep-orange-A200: #FF6E40;\n  --md-deep-orange-A400: #FF3D00;\n  --md-deep-orange-A700: #DD2C00;\n\n  --md-brown-50: #EFEBE9;\n  --md-brown-100: #D7CCC8;\n  --md-brown-200: #BCAAA4;\n  --md-brown-300: #A1887F;\n  --md-brown-400: #8D6E63;\n  --md-brown-500: #795548;\n  --md-brown-600: #6D4C41;\n  --md-brown-700: #5D4037;\n  --md-brown-800: #4E342E;\n  --md-brown-900: #3E2723;\n\n  --md-grey-50: #FAFAFA;\n  --md-grey-100: #F5F5F5;\n  --md-grey-200: #EEEEEE;\n  --md-grey-300: #E0E0E0;\n  --md-grey-400: #BDBDBD;\n  --md-grey-500: #9E9E9E;\n  --md-grey-600: #757575;\n  --md-grey-700: #616161;\n  --md-grey-800: #424242;\n  --md-grey-900: #212121;\n\n  --md-blue-grey-50: #ECEFF1;\n  --md-blue-grey-100: #CFD8DC;\n  --md-blue-grey-200: #B0BEC5;\n  --md-blue-grey-300: #90A4AE;\n  --md-blue-grey-400: #78909C;\n  --md-blue-grey-500: #607D8B;\n  --md-blue-grey-600: #546E7A;\n  --md-blue-grey-700: #455A64;\n  --md-blue-grey-800: #37474F;\n  --md-blue-grey-900: #263238;\n}","/* Copyright (c) Jupyter Development Team.\n * Distributed under the terms of the Modified BSD License.\n */\n\n/*\n * We assume that the CSS variables in\n * https://github.com/jupyterlab/jupyterlab/blob/master/src/default-theme/variables.css\n * have been defined.\n */\n\n@import \"./phosphor.css\";\n\n:root {\n    --jp-widgets-color: var(--jp-content-font-color1);\n    --jp-widgets-label-color: var(--jp-widgets-color);\n    --jp-widgets-readout-color: var(--jp-widgets-color);\n    --jp-widgets-font-size: var(--jp-ui-font-size1);\n    --jp-widgets-margin: 2px;\n    --jp-widgets-inline-height: 28px;\n    --jp-widgets-inline-width: 300px;\n    --jp-widgets-inline-width-short: calc(var(--jp-widgets-inline-width) / 2 - var(--jp-widgets-margin));\n    --jp-widgets-inline-width-tiny: calc(var(--jp-widgets-inline-width-short) / 2 - var(--jp-widgets-margin));\n    --jp-widgets-inline-margin: 4px; /* margin between inline elements */\n    --jp-widgets-inline-label-width: 80px;\n    --jp-widgets-border-width: var(--jp-border-width);\n    --jp-widgets-vertical-height: 200px;\n    --jp-widgets-horizontal-tab-height: 24px;\n    --jp-widgets-horizontal-tab-width: 144px;\n    --jp-widgets-horizontal-tab-top-border: 2px;\n    --jp-widgets-progress-thickness: 20px;\n    --jp-widgets-container-padding: 15px;\n    --jp-widgets-input-padding: 4px;\n    --jp-widgets-radio-item-height-adjustment: 8px;\n    --jp-widgets-radio-item-height: calc(var(--jp-widgets-inline-height) - var(--jp-widgets-radio-item-height-adjustment));\n    --jp-widgets-slider-track-thickness: 4px;\n    --jp-widgets-slider-border-width: var(--jp-widgets-border-width);\n    --jp-widgets-slider-handle-size: 16px;\n    --jp-widgets-slider-handle-border-color: var(--jp-border-color1);\n    --jp-widgets-slider-handle-background-color: var(--jp-layout-color1);\n    --jp-widgets-slider-active-handle-color: var(--jp-brand-color1);\n    --jp-widgets-menu-item-height: 24px;\n    --jp-widgets-dropdown-arrow: url(\"data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz4KPCEtLSBHZW5lcmF0b3I6IEFkb2JlIElsbHVzdHJhdG9yIDE5LjIuMSwgU1ZHIEV4cG9ydCBQbHVnLUluIC4gU1ZHIFZlcnNpb246IDYuMDAgQnVpbGQgMCkgIC0tPgo8c3ZnIHZlcnNpb249IjEuMSIgaWQ9IkxheWVyXzEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgeG1sbnM6eGxpbms9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkveGxpbmsiIHg9IjBweCIgeT0iMHB4IgoJIHZpZXdCb3g9IjAgMCAxOCAxOCIgc3R5bGU9ImVuYWJsZS1iYWNrZ3JvdW5kOm5ldyAwIDAgMTggMTg7IiB4bWw6c3BhY2U9InByZXNlcnZlIj4KPHN0eWxlIHR5cGU9InRleHQvY3NzIj4KCS5zdDB7ZmlsbDpub25lO30KPC9zdHlsZT4KPHBhdGggZD0iTTUuMiw1LjlMOSw5LjdsMy44LTMuOGwxLjIsMS4ybC00LjksNWwtNC45LTVMNS4yLDUuOXoiLz4KPHBhdGggY2xhc3M9InN0MCIgZD0iTTAtMC42aDE4djE4SDBWLTAuNnoiLz4KPC9zdmc+Cg\");\n    --jp-widgets-input-color: var(--jp-ui-font-color1);\n    --jp-widgets-input-background-color: var(--jp-layout-color1);\n    --jp-widgets-input-border-color: var(--jp-border-color1);\n    --jp-widgets-input-focus-border-color: var(--jp-brand-color2);\n    --jp-widgets-input-border-width: var(--jp-widgets-border-width);\n    --jp-widgets-disabled-opacity: 0.6;\n\n    /* From Material Design Lite */\n    --md-shadow-key-umbra-opacity: 0.2;\n    --md-shadow-key-penumbra-opacity: 0.14;\n    --md-shadow-ambient-shadow-opacity: 0.12;\n}\n\n.jupyter-widgets {\n    margin: var(--jp-widgets-margin);\n    box-sizing: border-box;\n    color: var(--jp-widgets-color);\n    overflow: visible;\n}\n\n.jupyter-widgets.jupyter-widgets-disconnected::before {\n    line-height: var(--jp-widgets-inline-height);\n    height: var(--jp-widgets-inline-height);\n}\n\n.jp-Output-result > .jupyter-widgets {\n    margin-left: 0;\n    margin-right: 0;\n}\n\n/* vbox and hbox */\n\n.widget-inline-hbox {\n    /* Horizontal widgets */\n    box-sizing: border-box;\n    display: flex;\n    flex-direction: row;\n    align-items: baseline;\n}\n\n.widget-inline-vbox {\n    /* Vertical Widgets */\n    box-sizing: border-box;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n}\n\n.widget-box {\n    box-sizing: border-box;\n    display: flex;\n    margin: 0;\n    overflow: auto;\n}\n\n.widget-gridbox {\n    box-sizing: border-box;\n    display: grid;\n    margin: 0;\n    overflow: auto;\n}\n\n.widget-hbox {\n    flex-direction: row;\n}\n\n.widget-vbox {\n    flex-direction: column;\n}\n\n/* General Button Styling */\n\n.jupyter-button {\n    padding-left: 10px;\n    padding-right: 10px;\n    padding-top: 0px;\n    padding-bottom: 0px;\n    display: inline-block;\n    white-space: nowrap;\n    overflow: hidden;\n    text-overflow: ellipsis;\n    text-align: center;\n    font-size: var(--jp-widgets-font-size);\n    cursor: pointer;\n\n    height: var(--jp-widgets-inline-height);\n    border: 0px solid;\n    line-height: var(--jp-widgets-inline-height);\n    box-shadow: none;\n\n    color: var(--jp-ui-font-color1);\n    background-color: var(--jp-layout-color2);\n    border-color: var(--jp-border-color2);\n    border: none;\n}\n\n.jupyter-button i.fa {\n    margin-right: var(--jp-widgets-inline-margin);\n    pointer-events: none;\n}\n\n.jupyter-button:empty:before {\n    content: \"\\200b\"; /* zero-width space */\n}\n\n.jupyter-widgets.jupyter-button:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n.jupyter-button i.fa.center {\n    margin-right: 0;\n}\n\n.jupyter-button:hover:enabled, .jupyter-button:focus:enabled {\n    /* MD Lite 2dp shadow */\n    box-shadow: 0 2px 2px 0 rgba(0, 0, 0, var(--md-shadow-key-penumbra-opacity)),\n                0 3px 1px -2px rgba(0, 0, 0, var(--md-shadow-key-umbra-opacity)),\n                0 1px 5px 0 rgba(0, 0, 0, var(--md-shadow-ambient-shadow-opacity));\n}\n\n.jupyter-button:active, .jupyter-button.mod-active {\n    /* MD Lite 4dp shadow */\n    box-shadow: 0 4px 5px 0 rgba(0, 0, 0, var(--md-shadow-key-penumbra-opacity)),\n                0 1px 10px 0 rgba(0, 0, 0, var(--md-shadow-ambient-shadow-opacity)),\n                0 2px 4px -1px rgba(0, 0, 0, var(--md-shadow-key-umbra-opacity));\n    color: var(--jp-ui-font-color1);\n    background-color: var(--jp-layout-color3);\n}\n\n.jupyter-button:focus:enabled {\n    outline: 1px solid var(--jp-widgets-input-focus-border-color);\n}\n\n/* Button \"Primary\" Styling */\n\n.jupyter-button.mod-primary {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-brand-color1);\n}\n\n.jupyter-button.mod-primary.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-brand-color0);\n}\n\n.jupyter-button.mod-primary:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-brand-color0);\n}\n\n/* Button \"Success\" Styling */\n\n.jupyter-button.mod-success {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-success-color1);\n}\n\n.jupyter-button.mod-success.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-success-color0);\n }\n\n.jupyter-button.mod-success:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-success-color0);\n }\n\n /* Button \"Info\" Styling */\n\n.jupyter-button.mod-info {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-info-color1);\n}\n\n.jupyter-button.mod-info.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-info-color0);\n}\n\n.jupyter-button.mod-info:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-info-color0);\n}\n\n/* Button \"Warning\" Styling */\n\n.jupyter-button.mod-warning {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-warn-color1);\n}\n\n.jupyter-button.mod-warning.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-warn-color0);\n}\n\n.jupyter-button.mod-warning:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-warn-color0);\n}\n\n/* Button \"Danger\" Styling */\n\n.jupyter-button.mod-danger {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-error-color1);\n}\n\n.jupyter-button.mod-danger.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-error-color0);\n}\n\n.jupyter-button.mod-danger:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-error-color0);\n}\n\n/* Widget Button*/\n\n.widget-button, .widget-toggle-button {\n    width: var(--jp-widgets-inline-width-short);\n}\n\n/* Widget Label Styling */\n\n/* Override Bootstrap label css */\n.jupyter-widgets label {\n    margin-bottom: initial;\n}\n\n.widget-label-basic {\n    /* Basic Label */\n    color: var(--jp-widgets-label-color);\n    font-size: var(--jp-widgets-font-size);\n    overflow: hidden;\n    text-overflow: ellipsis;\n    white-space: nowrap;\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-label {\n    /* Label */\n    color: var(--jp-widgets-label-color);\n    font-size: var(--jp-widgets-font-size);\n    overflow: hidden;\n    text-overflow: ellipsis;\n    white-space: nowrap;\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-inline-hbox .widget-label {\n    /* Horizontal Widget Label */\n    color: var(--jp-widgets-label-color);\n    text-align: right;\n    margin-right: calc( var(--jp-widgets-inline-margin) * 2 );\n    width: var(--jp-widgets-inline-label-width);\n    flex-shrink: 0;\n}\n\n.widget-inline-vbox .widget-label {\n    /* Vertical Widget Label */\n    color: var(--jp-widgets-label-color);\n    text-align: center;\n    line-height: var(--jp-widgets-inline-height);\n}\n\n/* Widget Readout Styling */\n\n.widget-readout {\n    color: var(--jp-widgets-readout-color);\n    font-size: var(--jp-widgets-font-size);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    overflow: hidden;\n    white-space: nowrap;\n    text-align: center;\n}\n\n.widget-readout.overflow {\n    /* Overflowing Readout */\n\n    /* From Material Design Lite\n        shadow-key-umbra-opacity: 0.2;\n        shadow-key-penumbra-opacity: 0.14;\n        shadow-ambient-shadow-opacity: 0.12;\n     */\n    -webkit-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.2),\n                        0 3px 1px -2px rgba(0, 0, 0, 0.14),\n                        0 1px 5px 0 rgba(0, 0, 0, 0.12);\n\n    -moz-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.2),\n                     0 3px 1px -2px rgba(0, 0, 0, 0.14),\n                     0 1px 5px 0 rgba(0, 0, 0, 0.12);\n\n    box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.2),\n                0 3px 1px -2px rgba(0, 0, 0, 0.14),\n                0 1px 5px 0 rgba(0, 0, 0, 0.12);\n}\n\n.widget-inline-hbox .widget-readout {\n    /* Horizontal Readout */\n    text-align: center;\n    max-width: var(--jp-widgets-inline-width-short);\n    min-width: var(--jp-widgets-inline-width-tiny);\n    margin-left: var(--jp-widgets-inline-margin);\n}\n\n.widget-inline-vbox .widget-readout {\n    /* Vertical Readout */\n    margin-top: var(--jp-widgets-inline-margin);\n    /* as wide as the widget */\n    width: inherit;\n}\n\n/* Widget Checkbox Styling */\n\n.widget-checkbox {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-checkbox input[type=\"checkbox\"] {\n    margin: 0px calc( var(--jp-widgets-inline-margin) * 2 ) 0px 0px;\n    line-height: var(--jp-widgets-inline-height);\n    font-size: large;\n    flex-grow: 1;\n    flex-shrink: 0;\n    align-self: center;\n}\n\n/* Widget Valid Styling */\n\n.widget-valid {\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    width: var(--jp-widgets-inline-width-short);\n    font-size: var(--jp-widgets-font-size);\n}\n\n.widget-valid i:before {\n    line-height: var(--jp-widgets-inline-height);\n    margin-right: var(--jp-widgets-inline-margin);\n    margin-left: var(--jp-widgets-inline-margin);\n\n    /* from the fa class in FontAwesome: https://github.com/FortAwesome/Font-Awesome/blob/49100c7c3a7b58d50baa71efef11af41a66b03d3/css/font-awesome.css#L14 */\n    display: inline-block;\n    font: normal normal normal 14px/1 FontAwesome;\n    font-size: inherit;\n    text-rendering: auto;\n    -webkit-font-smoothing: antialiased;\n    -moz-osx-font-smoothing: grayscale;\n}\n\n.widget-valid.mod-valid i:before {\n    content: \"\\f00c\";\n    color: green;\n}\n\n.widget-valid.mod-invalid i:before {\n    content: \"\\f00d\";\n    color: red;\n}\n\n.widget-valid.mod-valid .widget-valid-readout {\n    display: none;\n}\n\n/* Widget Text and TextArea Stying */\n\n.widget-textarea, .widget-text {\n    width: var(--jp-widgets-inline-width);\n}\n\n.widget-text input[type=\"text\"], .widget-text input[type=\"number\"]{\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-text input[type=\"text\"]:disabled, .widget-text input[type=\"number\"]:disabled, .widget-textarea textarea:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n.widget-text input[type=\"text\"], .widget-text input[type=\"number\"], .widget-textarea textarea {\n    box-sizing: border-box;\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    padding: var(--jp-widgets-input-padding) calc( var(--jp-widgets-input-padding) *  2 );\n    flex-grow: 1;\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    flex-shrink: 1;\n    outline: none !important;\n}\n\n.widget-textarea textarea {\n    height: inherit;\n    width: inherit;\n}\n\n.widget-text input:focus, .widget-textarea textarea:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n/* Widget Slider */\n\n.widget-slider .ui-slider {\n    /* Slider Track */\n    border: var(--jp-widgets-slider-border-width) solid var(--jp-layout-color3);\n    background: var(--jp-layout-color3);\n    box-sizing: border-box;\n    position: relative;\n    border-radius: 0px;\n}\n\n.widget-slider .ui-slider .ui-slider-handle {\n    /* Slider Handle */\n    outline: none !important; /* focused slider handles are colored - see below */\n    position: absolute;\n    background-color: var(--jp-widgets-slider-handle-background-color);\n    border: var(--jp-widgets-slider-border-width) solid var(--jp-widgets-slider-handle-border-color);\n    box-sizing: border-box;\n    z-index: 1;\n    background-image: none; /* Override jquery-ui */\n}\n\n/* Override jquery-ui */\n.widget-slider .ui-slider .ui-slider-handle:hover, .widget-slider .ui-slider .ui-slider-handle:focus {\n    background-color: var(--jp-widgets-slider-active-handle-color);\n    border: var(--jp-widgets-slider-border-width) solid var(--jp-widgets-slider-active-handle-color);\n}\n\n.widget-slider .ui-slider .ui-slider-handle:active {\n    background-color: var(--jp-widgets-slider-active-handle-color);\n    border-color: var(--jp-widgets-slider-active-handle-color);\n    z-index: 2;\n    transform: scale(1.2);\n}\n\n.widget-slider  .ui-slider .ui-slider-range {\n    /* Interval between the two specified value of a double slider */\n    position: absolute;\n    background: var(--jp-widgets-slider-active-handle-color);\n    z-index: 0;\n}\n\n/* Shapes of Slider Handles */\n\n.widget-hslider .ui-slider .ui-slider-handle {\n    width: var(--jp-widgets-slider-handle-size);\n    height: var(--jp-widgets-slider-handle-size);\n    margin-top: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-handle-size)) / 2 - var(--jp-widgets-slider-border-width));\n    margin-left: calc(var(--jp-widgets-slider-handle-size) / -2 + var(--jp-widgets-slider-border-width));\n    border-radius: 50%;\n    top: 0;\n}\n\n.widget-vslider .ui-slider .ui-slider-handle {\n    width: var(--jp-widgets-slider-handle-size);\n    height: var(--jp-widgets-slider-handle-size);\n    margin-bottom: calc(var(--jp-widgets-slider-handle-size) / -2 + var(--jp-widgets-slider-border-width));\n    margin-left: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-handle-size)) / 2 - var(--jp-widgets-slider-border-width));\n    border-radius: 50%;\n    left: 0;\n}\n\n.widget-hslider .ui-slider .ui-slider-range {\n    height: calc( var(--jp-widgets-slider-track-thickness) * 2 );\n    margin-top: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-track-thickness) * 2 ) / 2 - var(--jp-widgets-slider-border-width));\n}\n\n.widget-vslider .ui-slider .ui-slider-range {\n    width: calc( var(--jp-widgets-slider-track-thickness) * 2 );\n    margin-left: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-track-thickness) * 2 ) / 2 - var(--jp-widgets-slider-border-width));\n}\n\n/* Horizontal Slider */\n\n.widget-hslider {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n\n    /* Override the align-items baseline. This way, the description and readout\n    still seem to align their baseline properly, and we don't have to have\n    align-self: stretch in the .slider-container. */\n    align-items: center;\n}\n\n.widgets-slider .slider-container {\n    overflow: visible;\n}\n\n.widget-hslider .slider-container {\n    height: var(--jp-widgets-inline-height);\n    margin-left: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    margin-right: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n}\n\n.widget-hslider .ui-slider {\n    /* Inner, invisible slide div */\n    height: var(--jp-widgets-slider-track-thickness);\n    margin-top: calc((var(--jp-widgets-inline-height) - var(--jp-widgets-slider-track-thickness)) / 2);\n    width: 100%;\n}\n\n/* Vertical Slider */\n\n.widget-vbox .widget-label {\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-vslider {\n    /* Vertical Slider */\n    height: var(--jp-widgets-vertical-height);\n    width: var(--jp-widgets-inline-width-tiny);\n}\n\n.widget-vslider .slider-container {\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n    margin-left: auto;\n    margin-right: auto;\n    margin-bottom: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    margin-top: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    display: flex;\n    flex-direction: column;\n}\n\n.widget-vslider .ui-slider-vertical {\n    /* Inner, invisible slide div */\n    width: var(--jp-widgets-slider-track-thickness);\n    flex-grow: 1;\n    margin-left: auto;\n    margin-right: auto;\n}\n\n/* Widget Progress Styling */\n\n.progress-bar {\n    -webkit-transition: none;\n    -moz-transition: none;\n    -ms-transition: none;\n    -o-transition: none;\n    transition: none;\n}\n\n.progress-bar {\n    height: var(--jp-widgets-inline-height);\n}\n\n.progress-bar {\n    background-color: var(--jp-brand-color1);\n}\n\n.progress-bar-success {\n    background-color: var(--jp-success-color1);\n}\n\n.progress-bar-info {\n    background-color: var(--jp-info-color1);\n}\n\n.progress-bar-warning {\n    background-color: var(--jp-warn-color1);\n}\n\n.progress-bar-danger {\n    background-color: var(--jp-error-color1);\n}\n\n.progress {\n    background-color: var(--jp-layout-color2);\n    border: none;\n    box-shadow: none;\n}\n\n/* Horisontal Progress */\n\n.widget-hprogress {\n    /* Progress Bar */\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    width: var(--jp-widgets-inline-width);\n    align-items: center;\n\n}\n\n.widget-hprogress .progress {\n    flex-grow: 1;\n    margin-top: var(--jp-widgets-input-padding);\n    margin-bottom: var(--jp-widgets-input-padding);\n    align-self: stretch;\n    /* Override bootstrap style */\n    height: initial;\n}\n\n/* Vertical Progress */\n\n.widget-vprogress {\n    height: var(--jp-widgets-vertical-height);\n    width: var(--jp-widgets-inline-width-tiny);\n}\n\n.widget-vprogress .progress {\n    flex-grow: 1;\n    width: var(--jp-widgets-progress-thickness);\n    margin-left: auto;\n    margin-right: auto;\n    margin-bottom: 0;\n}\n\n/* Select Widget Styling */\n\n.widget-dropdown {\n    height: var(--jp-widgets-inline-height);\n    width: var(--jp-widgets-inline-width);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-dropdown > select {\n    padding-right: 20px;\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    border-radius: 0;\n    height: inherit;\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    box-sizing: border-box;\n    outline: none !important;\n    box-shadow: none;\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    vertical-align: top;\n    padding-left: calc( var(--jp-widgets-input-padding) * 2);\n\tappearance: none;\n\t-webkit-appearance: none;\n\t-moz-appearance: none;\n    background-repeat: no-repeat;\n\tbackground-size: 20px;\n\tbackground-position: right center;\n    background-image: var(--jp-widgets-dropdown-arrow);\n}\n.widget-dropdown > select:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.widget-dropdown > select:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* To disable the dotted border in Firefox around select controls.\n   See http://stackoverflow.com/a/18853002 */\n.widget-dropdown > select:-moz-focusring {\n    color: transparent;\n    text-shadow: 0 0 0 #000;\n}\n\n/* Select and SelectMultiple */\n\n.widget-select {\n    width: var(--jp-widgets-inline-width);\n    line-height: var(--jp-widgets-inline-height);\n\n    /* Because Firefox defines the baseline of a select as the bottom of the\n    control, we align the entire control to the top and add padding to the\n    select to get an approximate first line baseline alignment. */\n    align-items: flex-start;\n}\n\n.widget-select > select {\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n    outline: none !important;\n    overflow: auto;\n    height: inherit;\n\n    /* Because Firefox defines the baseline of a select as the bottom of the\n    control, we align the entire control to the top and add padding to the\n    select to get an approximate first line baseline alignment. */\n    padding-top: 5px;\n}\n\n.widget-select > select:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.wiget-select > select > option {\n    padding-left: var(--jp-widgets-input-padding);\n    line-height: var(--jp-widgets-inline-height);\n    /* line-height doesn't work on some browsers for select options */\n    padding-top: calc(var(--jp-widgets-inline-height)-var(--jp-widgets-font-size)/2);\n    padding-bottom: calc(var(--jp-widgets-inline-height)-var(--jp-widgets-font-size)/2);\n}\n\n\n\n/* Toggle Buttons Styling */\n\n.widget-toggle-buttons {\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-toggle-buttons .widget-toggle-button {\n    margin-left: var(--jp-widgets-margin);\n    margin-right: var(--jp-widgets-margin);\n}\n\n.widget-toggle-buttons .jupyter-button:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Radio Buttons Styling */\n\n.widget-radio {\n    width: var(--jp-widgets-inline-width);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-radio-box {\n    display: flex;\n    flex-direction: column;\n    align-items: stretch;\n    box-sizing: border-box;\n    flex-grow: 1;\n    margin-bottom: var(--jp-widgets-radio-item-height-adjustment);\n}\n\n.widget-radio-box label {\n    height: var(--jp-widgets-radio-item-height);\n    line-height: var(--jp-widgets-radio-item-height);\n    font-size: var(--jp-widgets-font-size);\n}\n\n.widget-radio-box input {\n    height: var(--jp-widgets-radio-item-height);\n    line-height: var(--jp-widgets-radio-item-height);\n    margin: 0 calc( var(--jp-widgets-input-padding) * 2 ) 0 1px;\n    float: left;\n}\n\n/* Color Picker Styling */\n\n.widget-colorpicker {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-colorpicker > .widget-colorpicker-input {\n    flex-grow: 1;\n    flex-shrink: 1;\n    min-width: var(--jp-widgets-inline-width-tiny);\n}\n\n.widget-colorpicker input[type=\"color\"] {\n    width: var(--jp-widgets-inline-height);\n    height: var(--jp-widgets-inline-height);\n    padding: 0 2px; /* make the color square actually square on Chrome on OS X */\n    background: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    border-left: none;\n    flex-grow: 0;\n    flex-shrink: 0;\n    box-sizing: border-box;\n    align-self: stretch;\n    outline: none !important;\n}\n\n.widget-colorpicker.concise input[type=\"color\"] {\n    border-left: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n}\n\n.widget-colorpicker input[type=\"color\"]:focus, .widget-colorpicker input[type=\"text\"]:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.widget-colorpicker input[type=\"text\"] {\n    flex-grow: 1;\n    outline: none !important;\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    background: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    font-size: var(--jp-widgets-font-size);\n    padding: var(--jp-widgets-input-padding) calc( var(--jp-widgets-input-padding) *  2 );\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    flex-shrink: 1;\n    box-sizing: border-box;\n}\n\n.widget-colorpicker input[type=\"text\"]:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Date Picker Styling */\n\n.widget-datepicker {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-datepicker input[type=\"date\"] {\n    flex-grow: 1;\n    flex-shrink: 1;\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    outline: none !important;\n    height: var(--jp-widgets-inline-height);\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    padding: var(--jp-widgets-input-padding) calc( var(--jp-widgets-input-padding) *  2 );\n    box-sizing: border-box;\n}\n\n.widget-datepicker input[type=\"date\"]:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.widget-datepicker input[type=\"date\"]:invalid {\n    border-color: var(--jp-warn-color1);\n}\n\n.widget-datepicker input[type=\"date\"]:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Play Widget */\n\n.widget-play {\n    width: var(--jp-widgets-inline-width-short);\n    display: flex;\n    align-items: stretch;\n}\n\n.widget-play .jupyter-button {\n    flex-grow: 1;\n    height: auto;\n}\n\n.widget-play .jupyter-button:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Tab Widget */\n\n.jupyter-widgets.widget-tab {\n    display: flex;\n    flex-direction: column;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar {\n    /* Necessary so that a tab can be shifted down to overlay the border of the box below. */\n    overflow-x: visible;\n    overflow-y: visible;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {\n    /* Make sure that the tab grows from bottom up */\n    align-items: flex-end;\n    min-width: 0;\n    min-height: 0;\n}\n\n.jupyter-widgets.widget-tab > .widget-tab-contents {\n    width: 100%;\n    box-sizing: border-box;\n    margin: 0;\n    background: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n    border: var(--jp-border-width) solid var(--jp-border-color1);\n    padding: var(--jp-widgets-container-padding);\n    flex-grow: 1;\n    overflow: auto;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar {\n    font: var(--jp-widgets-font-size) Helvetica, Arial, sans-serif;\n    min-height: calc(var(--jp-widgets-horizontal-tab-height) + var(--jp-border-width));\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {\n    flex: 0 1 var(--jp-widgets-horizontal-tab-width);\n    min-width: 35px;\n    min-height: calc(var(--jp-widgets-horizontal-tab-height) + var(--jp-border-width));\n    line-height: var(--jp-widgets-horizontal-tab-height);\n    margin-left: calc(-1 * var(--jp-border-width));\n    padding: 0px 10px;\n    background: var(--jp-layout-color2);\n    color: var(--jp-ui-font-color2);\n    border: var(--jp-border-width) solid var(--jp-border-color1);\n    border-bottom: none;\n    position: relative;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current {\n    color: var(--jp-ui-font-color0);\n    /* We want the background to match the tab content background */\n    background: var(--jp-layout-color1);\n    min-height: calc(var(--jp-widgets-horizontal-tab-height) + 2 * var(--jp-border-width));\n    transform: translateY(var(--jp-border-width));\n    overflow: visible;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current:before {\n    position: absolute;\n    top: calc(-1 * var(--jp-border-width));\n    left: calc(-1 * var(--jp-border-width));\n    content: '';\n    height: var(--jp-widgets-horizontal-tab-top-border);\n    width: calc(100% + 2 * var(--jp-border-width));\n    background: var(--jp-brand-color1);\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:first-child {\n    margin-left: 0;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:hover:not(.p-mod-current) {\n    background: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon {\n    margin-left: 4px;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon:before {\n    font-family: FontAwesome;\n    content: '\\f00d'; /* close */\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel,\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {\n    line-height: var(--jp-widgets-horizontal-tab-height);\n}\n\n/* Accordion Widget */\n\n.p-Collapse {\n    display: flex;\n    flex-direction: column;\n    align-items: stretch;\n}\n\n.p-Collapse-header {\n    padding: var(--jp-widgets-input-padding);\n    cursor: pointer;\n    color: var(--jp-ui-font-color2);\n    background-color: var(--jp-layout-color2);\n    border: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    padding: calc(var(--jp-widgets-container-padding) * 2 / 3) var(--jp-widgets-container-padding);\n    font-weight: bold;\n}\n\n.p-Collapse-header:hover {\n    background-color: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n}\n\n.p-Collapse-open > .p-Collapse-header {\n    background-color: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color0);\n    cursor: default;\n    border-bottom: none;\n}\n\n.p-Collapse .p-Collapse-header::before {\n    content: '\\f0da\\00A0';  /* caret-right, non-breaking space */\n    display: inline-block;\n    font: normal normal normal 14px/1 FontAwesome;\n    font-size: inherit;\n    text-rendering: auto;\n    -webkit-font-smoothing: antialiased;\n    -moz-osx-font-smoothing: grayscale;\n}\n\n.p-Collapse-open > .p-Collapse-header::before {\n    content: '\\f0d7\\00A0'; /* caret-down, non-breaking space */\n}\n\n.p-Collapse-contents {\n    padding: var(--jp-widgets-container-padding);\n    background-color: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n    border-left: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    border-right: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    border-bottom: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    overflow: auto;\n}\n\n.p-Accordion {\n    display: flex;\n    flex-direction: column;\n    align-items: stretch;\n}\n\n.p-Accordion .p-Collapse {\n    margin-bottom: 0;\n}\n\n.p-Accordion .p-Collapse + .p-Collapse {\n    margin-top: 4px;\n}\n\n\n\n/* HTML widget */\n\n.widget-html, .widget-htmlmath {\n    font-size: var(--jp-widgets-font-size);\n}\n\n.widget-html > .widget-html-content, .widget-htmlmath > .widget-html-content {\n    /* Fill out the area in the HTML widget */\n    align-self: stretch;\n    flex-grow: 1;\n    flex-shrink: 1;\n    /* Makes sure the baseline is still aligned with other elements */\n    line-height: var(--jp-widgets-inline-height);\n    /* Make it possible to have absolutely-positioned elements in the html */\n    position: relative;\n}\n","/* This file has code derived from PhosphorJS CSS files, as noted below. The license for this PhosphorJS code is:\n\nCopyright (c) 2014-2017, PhosphorJS Contributors\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n*/\n\n/*\n * The following section is derived from https://github.com/phosphorjs/phosphor/blob/23b9d075ebc5b73ab148b6ebfc20af97f85714c4/packages/widgets/style/tabbar.css \n * We've scoped the rules so that they are consistent with exactly our code.\n */\n\n.jupyter-widgets.widget-tab > .p-TabBar {\n  display: flex;\n  -webkit-user-select: none;\n  -moz-user-select: none;\n  -ms-user-select: none;\n  user-select: none;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] {\n  flex-direction: row;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] {\n  flex-direction: column;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {\n  margin: 0;\n  padding: 0;\n  display: flex;\n  flex: 1 1 auto;\n  list-style-type: none;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] > .p-TabBar-content {\n  flex-direction: row;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] > .p-TabBar-content {\n  flex-direction: column;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {\n  display: flex;\n  flex-direction: row;\n  box-sizing: border-box;\n  overflow: hidden;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {\n  flex: 0 0 auto;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel {\n  flex: 1 1 auto;\n  overflow: hidden;\n  white-space: nowrap;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-hidden {\n  display: none !important;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab {\n  position: relative;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='horizontal'] .p-TabBar-tab {\n  left: 0;\n  transition: left 150ms ease;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='vertical'] .p-TabBar-tab {\n  top: 0;\n  transition: top 150ms ease;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab.p-mod-dragging {\n  transition: none;\n}\n\n/* End tabbar.css */\n"]} */",
+ "ok": true,
+ "headers": [
+ [
+ "content-type",
+ "text/css"
+ ]
+ ],
+ "status": 200,
+ "status_text": ""
+ }
+ },
+ "base_uri": "https://localhost:8080/",
+ "height": 270
+ }
+ },
+ "source": [
+ "# [ 0, 2, 1, 34, 4, 24]\n",
+ "def recon_check(original, recon):\n",
+ " fig = plt.figure()\n",
+ " ax0 = fig.add_subplot(121)\n",
+ " plt.imshow(original, cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " ax1 = fig.add_subplot(122)\n",
+ " plt.imshow(recon , cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " \n",
+ "def f(x):\n",
+ " fig = plt.figure()\n",
+ " ax0 = fig.add_subplot(121)\n",
+ " plt.imshow(originals[x], cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " ax1 = fig.add_subplot(122)\n",
+ " plt.imshow(recons[x], cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " \n",
+ "interact(f, x=widgets.IntSlider(min=0, max=xs.shape[0], step=1, value=0))"
+ ],
+ "execution_count": 341,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ef9cf802f2784126bef5615081093250",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "interactive(children=(IntSlider(value=0, description='x', max=256), Output()), _dom_classes=('widget-interact'…"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 341
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "EVCzRnuIzNpL",
+ "colab_type": "code",
+ "outputId": "319975aa-e56c-4814-c72f-15f1d8187c91",
+ "colab": {
+ "resources": {
+ "http://localhost:8080/nbextensions/google.colab/colabwidgets/controls.css": {
+ "data": "/* Copyright (c) Jupyter Development Team.
 * Distributed under the terms of the Modified BSD License.
 */

 /* We import all of these together in a single css file because the Webpack
loader sees only one file at a time. This allows postcss to see the variable
definitions when they are used. */

 /*-----------------------------------------------------------------------------
| Copyright (c) Jupyter Development Team.
| Distributed under the terms of the Modified BSD License.
|----------------------------------------------------------------------------*/

 /*
This file is copied from the JupyterLab project to define default styling for
when the widget styling is compiled down to eliminate CSS variables. We make one
change - we comment out the font import below.
*/

 /**
 * The material design colors are adapted from google-material-color v1.2.6
 * https://github.com/danlevan/google-material-color
 * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/dist/palette.var.css
 *
 * The license for the material design color CSS variables is as follows (see
 * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/LICENSE)
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2014 Dan Le Van
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

 /*
The following CSS variables define the main, public API for styling JupyterLab.
These variables should be used by all plugins wherever possible. In other
words, plugins should not define custom colors, sizes, etc unless absolutely
necessary. This enables users to change the visual theme of JupyterLab
by changing these variables.

Many variables appear in an ordered sequence (0,1,2,3). These sequences
are designed to work well together, so for example, `--jp-border-color1` should
be used with `--jp-layout-color1`. The numbers have the following meanings:

* 0: super-primary, reserved for special emphasis
* 1: primary, most important under normal situations
* 2: secondary, next most important under normal situations
* 3: tertiary, next most important under normal situations

Throughout JupyterLab, we are mostly following principles from Google's
Material Design when selecting colors. We are not, however, following
all of MD as it is not optimized for dense, information rich UIs.
*/

 /*
 * Optional monospace font for input/output prompt.
 */

 /* Commented out in ipywidgets since we don't need it. */

 /* @import url('https://fonts.googleapis.com/css?family=Roboto+Mono'); */

 /*
 * Added for compabitility with output area
 */

 :root {

  /* Borders

  The following variables, specify the visual styling of borders in JupyterLab.
   */

  /* UI Fonts

  The UI font CSS variables are used for the typography all of the JupyterLab
  user interface elements that are not directly user generated content.
  */ /* Base font size */ /* Ensures px perfect FontAwesome icons */

  /* Use these font colors against the corresponding main layout colors.
     In a light theme, these go from dark to light.
  */

  /* Use these against the brand/accent/warn/error colors.
     These will typically go from light to darker, in both a dark and light theme
   */

  /* Content Fonts

  Content font variables are used for typography of user generated content.
  */ /* Base font size */


  /* Layout

  The following are the main layout colors use in JupyterLab. In a light
  theme these would go from light to dark.
  */

  /* Brand/accent */

  /* State colors (warn, error, success, info) */

  /* Cell specific styles */
  /* A custom blend of MD grey and blue 600
   * See https://meyerweb.com/eric/tools/color-blend/#546E7A:1E88E5:5:hex */
  /* A custom blend of MD grey and orange 600
   * https://meyerweb.com/eric/tools/color-blend/#546E7A:F4511E:5:hex */

  /* Notebook specific styles */

  /* Console specific styles */

  /* Toolbar specific styles */
}

 /* Copyright (c) Jupyter Development Team.
 * Distributed under the terms of the Modified BSD License.
 */

 /*
 * We assume that the CSS variables in
 * https://github.com/jupyterlab/jupyterlab/blob/master/src/default-theme/variables.css
 * have been defined.
 */

 /* This file has code derived from PhosphorJS CSS files, as noted below. The license for this PhosphorJS code is:

Copyright (c) 2014-2017, PhosphorJS Contributors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

 /*
 * The following section is derived from https://github.com/phosphorjs/phosphor/blob/23b9d075ebc5b73ab148b6ebfc20af97f85714c4/packages/widgets/style/tabbar.css 
 * We've scoped the rules so that they are consistent with exactly our code.
 */

 .jupyter-widgets.widget-tab > .p-TabBar {
  display: -webkit-box;
  display: -ms-flexbox;
  display: flex;
  -webkit-user-select: none;
  -moz-user-select: none;
  -ms-user-select: none;
  user-select: none;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] {
  -webkit-box-orient: horizontal;
  -webkit-box-direction: normal;
      -ms-flex-direction: row;
          flex-direction: row;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] {
  -webkit-box-orient: vertical;
  -webkit-box-direction: normal;
      -ms-flex-direction: column;
          flex-direction: column;
}

 .jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {
  margin: 0;
  padding: 0;
  display: -webkit-box;
  display: -ms-flexbox;
  display: flex;
  -webkit-box-flex: 1;
      -ms-flex: 1 1 auto;
          flex: 1 1 auto;
  list-style-type: none;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] > .p-TabBar-content {
  -webkit-box-orient: horizontal;
  -webkit-box-direction: normal;
      -ms-flex-direction: row;
          flex-direction: row;
}

 .jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] > .p-TabBar-content {
  -webkit-box-orient: vertical;
  -webkit-box-direction: normal;
      -ms-flex-direction: column;
          flex-direction: column;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {
  display: -webkit-box;
  display: -ms-flexbox;
  display: flex;
  -webkit-box-orient: horizontal;
  -webkit-box-direction: normal;
      -ms-flex-direction: row;
          flex-direction: row;
  -webkit-box-sizing: border-box;
          box-sizing: border-box;
  overflow: hidden;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,
.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {
  -webkit-box-flex: 0;
      -ms-flex: 0 0 auto;
          flex: 0 0 auto;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel {
  -webkit-box-flex: 1;
      -ms-flex: 1 1 auto;
          flex: 1 1 auto;
  overflow: hidden;
  white-space: nowrap;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-hidden {
  display: none !important;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab {
  position: relative;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='horizontal'] .p-TabBar-tab {
  left: 0;
  -webkit-transition: left 150ms ease;
  transition: left 150ms ease;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='vertical'] .p-TabBar-tab {
  top: 0;
  -webkit-transition: top 150ms ease;
  transition: top 150ms ease;
}

 .jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab.p-mod-dragging {
  -webkit-transition: none;
  transition: none;
}

 /* End tabbar.css */

 :root { /* margin between inline elements */

    /* From Material Design Lite */
}

 .jupyter-widgets {
    margin: 2px;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    color: black;
    overflow: visible;
}

 .jupyter-widgets.jupyter-widgets-disconnected::before {
    line-height: 28px;
    height: 28px;
}

 .jp-Output-result > .jupyter-widgets {
    margin-left: 0;
    margin-right: 0;
}

 /* vbox and hbox */

 .widget-inline-hbox {
    /* Horizontal widgets */
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: horizontal;
    -webkit-box-direction: normal;
        -ms-flex-direction: row;
            flex-direction: row;
    -webkit-box-align: baseline;
        -ms-flex-align: baseline;
            align-items: baseline;
}

 .widget-inline-vbox {
    /* Vertical Widgets */
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: center;
        -ms-flex-align: center;
            align-items: center;
}

 .widget-box {
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    margin: 0;
    overflow: auto;
}

 .widget-gridbox {
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    display: grid;
    margin: 0;
    overflow: auto;
}

 .widget-hbox {
    -webkit-box-orient: horizontal;
    -webkit-box-direction: normal;
        -ms-flex-direction: row;
            flex-direction: row;
}

 .widget-vbox {
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
}

 /* General Button Styling */

 .jupyter-button {
    padding-left: 10px;
    padding-right: 10px;
    padding-top: 0px;
    padding-bottom: 0px;
    display: inline-block;
    white-space: nowrap;
    overflow: hidden;
    text-overflow: ellipsis;
    text-align: center;
    font-size: 13px;
    cursor: pointer;

    height: 28px;
    border: 0px solid;
    line-height: 28px;
    -webkit-box-shadow: none;
            box-shadow: none;

    color: rgba(0, 0, 0, .8);
    background-color: #EEEEEE;
    border-color: #E0E0E0;
    border: none;
}

 .jupyter-button i.fa {
    margin-right: 4px;
    pointer-events: none;
}

 .jupyter-button:empty:before {
    content: "\200b"; /* zero-width space */
}

 .jupyter-widgets.jupyter-button:disabled {
    opacity: 0.6;
}

 .jupyter-button i.fa.center {
    margin-right: 0;
}

 .jupyter-button:hover:enabled, .jupyter-button:focus:enabled {
    /* MD Lite 2dp shadow */
    -webkit-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .14),
                0 3px 1px -2px rgba(0, 0, 0, .2),
                0 1px 5px 0 rgba(0, 0, 0, .12);
            box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .14),
                0 3px 1px -2px rgba(0, 0, 0, .2),
                0 1px 5px 0 rgba(0, 0, 0, .12);
}

 .jupyter-button:active, .jupyter-button.mod-active {
    /* MD Lite 4dp shadow */
    -webkit-box-shadow: 0 4px 5px 0 rgba(0, 0, 0, .14),
                0 1px 10px 0 rgba(0, 0, 0, .12),
                0 2px 4px -1px rgba(0, 0, 0, .2);
            box-shadow: 0 4px 5px 0 rgba(0, 0, 0, .14),
                0 1px 10px 0 rgba(0, 0, 0, .12),
                0 2px 4px -1px rgba(0, 0, 0, .2);
    color: rgba(0, 0, 0, .8);
    background-color: #BDBDBD;
}

 .jupyter-button:focus:enabled {
    outline: 1px solid #64B5F6;
}

 /* Button "Primary" Styling */

 .jupyter-button.mod-primary {
    color: rgba(255, 255, 255, 1.0);
    background-color: #2196F3;
}

 .jupyter-button.mod-primary.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #1976D2;
}

 .jupyter-button.mod-primary:active {
    color: rgba(255, 255, 255, 1);
    background-color: #1976D2;
}

 /* Button "Success" Styling */

 .jupyter-button.mod-success {
    color: rgba(255, 255, 255, 1.0);
    background-color: #4CAF50;
}

 .jupyter-button.mod-success.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #388E3C;
 }

 .jupyter-button.mod-success:active {
    color: rgba(255, 255, 255, 1);
    background-color: #388E3C;
 }

 /* Button "Info" Styling */

 .jupyter-button.mod-info {
    color: rgba(255, 255, 255, 1.0);
    background-color: #00BCD4;
}

 .jupyter-button.mod-info.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #0097A7;
}

 .jupyter-button.mod-info:active {
    color: rgba(255, 255, 255, 1);
    background-color: #0097A7;
}

 /* Button "Warning" Styling */

 .jupyter-button.mod-warning {
    color: rgba(255, 255, 255, 1.0);
    background-color: #FF9800;
}

 .jupyter-button.mod-warning.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #F57C00;
}

 .jupyter-button.mod-warning:active {
    color: rgba(255, 255, 255, 1);
    background-color: #F57C00;
}

 /* Button "Danger" Styling */

 .jupyter-button.mod-danger {
    color: rgba(255, 255, 255, 1.0);
    background-color: #F44336;
}

 .jupyter-button.mod-danger.mod-active {
    color: rgba(255, 255, 255, 1);
    background-color: #D32F2F;
}

 .jupyter-button.mod-danger:active {
    color: rgba(255, 255, 255, 1);
    background-color: #D32F2F;
}

 /* Widget Button*/

 .widget-button, .widget-toggle-button {
    width: 148px;
}

 /* Widget Label Styling */

 /* Override Bootstrap label css */

 .jupyter-widgets label {
    margin-bottom: 0;
    margin-bottom: initial;
}

 .widget-label-basic {
    /* Basic Label */
    color: black;
    font-size: 13px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    line-height: 28px;
}

 .widget-label {
    /* Label */
    color: black;
    font-size: 13px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    line-height: 28px;
}

 .widget-inline-hbox .widget-label {
    /* Horizontal Widget Label */
    color: black;
    text-align: right;
    margin-right: 8px;
    width: 80px;
    -ms-flex-negative: 0;
        flex-shrink: 0;
}

 .widget-inline-vbox .widget-label {
    /* Vertical Widget Label */
    color: black;
    text-align: center;
    line-height: 28px;
}

 /* Widget Readout Styling */

 .widget-readout {
    color: black;
    font-size: 13px;
    height: 28px;
    line-height: 28px;
    overflow: hidden;
    white-space: nowrap;
    text-align: center;
}

 .widget-readout.overflow {
    /* Overflowing Readout */

    /* From Material Design Lite
        shadow-key-umbra-opacity: 0.2;
        shadow-key-penumbra-opacity: 0.14;
        shadow-ambient-shadow-opacity: 0.12;
     */
    -webkit-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .2),
                        0 3px 1px -2px rgba(0, 0, 0, .14),
                        0 1px 5px 0 rgba(0, 0, 0, .12);

    box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .2),
                0 3px 1px -2px rgba(0, 0, 0, .14),
                0 1px 5px 0 rgba(0, 0, 0, .12);
}

 .widget-inline-hbox .widget-readout {
    /* Horizontal Readout */
    text-align: center;
    max-width: 148px;
    min-width: 72px;
    margin-left: 4px;
}

 .widget-inline-vbox .widget-readout {
    /* Vertical Readout */
    margin-top: 4px;
    /* as wide as the widget */
    width: inherit;
}

 /* Widget Checkbox Styling */

 .widget-checkbox {
    width: 300px;
    height: 28px;
    line-height: 28px;
}

 .widget-checkbox input[type="checkbox"] {
    margin: 0px 8px 0px 0px;
    line-height: 28px;
    font-size: large;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 0;
        flex-shrink: 0;
    -ms-flex-item-align: center;
        align-self: center;
}

 /* Widget Valid Styling */

 .widget-valid {
    height: 28px;
    line-height: 28px;
    width: 148px;
    font-size: 13px;
}

 .widget-valid i:before {
    line-height: 28px;
    margin-right: 4px;
    margin-left: 4px;

    /* from the fa class in FontAwesome: https://github.com/FortAwesome/Font-Awesome/blob/49100c7c3a7b58d50baa71efef11af41a66b03d3/css/font-awesome.css#L14 */
    display: inline-block;
    font: normal normal normal 14px/1 FontAwesome;
    font-size: inherit;
    text-rendering: auto;
    -webkit-font-smoothing: antialiased;
    -moz-osx-font-smoothing: grayscale;
}

 .widget-valid.mod-valid i:before {
    content: "\f00c";
    color: green;
}

 .widget-valid.mod-invalid i:before {
    content: "\f00d";
    color: red;
}

 .widget-valid.mod-valid .widget-valid-readout {
    display: none;
}

 /* Widget Text and TextArea Stying */

 .widget-textarea, .widget-text {
    width: 300px;
}

 .widget-text input[type="text"], .widget-text input[type="number"]{
    height: 28px;
    line-height: 28px;
}

 .widget-text input[type="text"]:disabled, .widget-text input[type="number"]:disabled, .widget-textarea textarea:disabled {
    opacity: 0.6;
}

 .widget-text input[type="text"], .widget-text input[type="number"], .widget-textarea textarea {
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    border: 1px solid #9E9E9E;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    padding: 4px 8px;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    -ms-flex-negative: 1;
        flex-shrink: 1;
    outline: none !important;
}

 .widget-textarea textarea {
    height: inherit;
    width: inherit;
}

 .widget-text input:focus, .widget-textarea textarea:focus {
    border-color: #64B5F6;
}

 /* Widget Slider */

 .widget-slider .ui-slider {
    /* Slider Track */
    border: 1px solid #BDBDBD;
    background: #BDBDBD;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    position: relative;
    border-radius: 0px;
}

 .widget-slider .ui-slider .ui-slider-handle {
    /* Slider Handle */
    outline: none !important; /* focused slider handles are colored - see below */
    position: absolute;
    background-color: white;
    border: 1px solid #9E9E9E;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    z-index: 1;
    background-image: none; /* Override jquery-ui */
}

 /* Override jquery-ui */

 .widget-slider .ui-slider .ui-slider-handle:hover, .widget-slider .ui-slider .ui-slider-handle:focus {
    background-color: #2196F3;
    border: 1px solid #2196F3;
}

 .widget-slider .ui-slider .ui-slider-handle:active {
    background-color: #2196F3;
    border-color: #2196F3;
    z-index: 2;
    -webkit-transform: scale(1.2);
            transform: scale(1.2);
}

 .widget-slider  .ui-slider .ui-slider-range {
    /* Interval between the two specified value of a double slider */
    position: absolute;
    background: #2196F3;
    z-index: 0;
}

 /* Shapes of Slider Handles */

 .widget-hslider .ui-slider .ui-slider-handle {
    width: 16px;
    height: 16px;
    margin-top: -7px;
    margin-left: -7px;
    border-radius: 50%;
    top: 0;
}

 .widget-vslider .ui-slider .ui-slider-handle {
    width: 16px;
    height: 16px;
    margin-bottom: -7px;
    margin-left: -7px;
    border-radius: 50%;
    left: 0;
}

 .widget-hslider .ui-slider .ui-slider-range {
    height: 8px;
    margin-top: -3px;
}

 .widget-vslider .ui-slider .ui-slider-range {
    width: 8px;
    margin-left: -3px;
}

 /* Horizontal Slider */

 .widget-hslider {
    width: 300px;
    height: 28px;
    line-height: 28px;

    /* Override the align-items baseline. This way, the description and readout
    still seem to align their baseline properly, and we don't have to have
    align-self: stretch in the .slider-container. */
    -webkit-box-align: center;
        -ms-flex-align: center;
            align-items: center;
}

 .widgets-slider .slider-container {
    overflow: visible;
}

 .widget-hslider .slider-container {
    height: 28px;
    margin-left: 6px;
    margin-right: 6px;
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
}

 .widget-hslider .ui-slider {
    /* Inner, invisible slide div */
    height: 4px;
    margin-top: 12px;
    width: 100%;
}

 /* Vertical Slider */

 .widget-vbox .widget-label {
    height: 28px;
    line-height: 28px;
}

 .widget-vslider {
    /* Vertical Slider */
    height: 200px;
    width: 72px;
}

 .widget-vslider .slider-container {
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
    margin-left: auto;
    margin-right: auto;
    margin-bottom: 6px;
    margin-top: 6px;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
}

 .widget-vslider .ui-slider-vertical {
    /* Inner, invisible slide div */
    width: 4px;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    margin-left: auto;
    margin-right: auto;
}

 /* Widget Progress Styling */

 .progress-bar {
    -webkit-transition: none;
    transition: none;
}

 .progress-bar {
    height: 28px;
}

 .progress-bar {
    background-color: #2196F3;
}

 .progress-bar-success {
    background-color: #4CAF50;
}

 .progress-bar-info {
    background-color: #00BCD4;
}

 .progress-bar-warning {
    background-color: #FF9800;
}

 .progress-bar-danger {
    background-color: #F44336;
}

 .progress {
    background-color: #EEEEEE;
    border: none;
    -webkit-box-shadow: none;
            box-shadow: none;
}

 /* Horisontal Progress */

 .widget-hprogress {
    /* Progress Bar */
    height: 28px;
    line-height: 28px;
    width: 300px;
    -webkit-box-align: center;
        -ms-flex-align: center;
            align-items: center;

}

 .widget-hprogress .progress {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    margin-top: 4px;
    margin-bottom: 4px;
    -ms-flex-item-align: stretch;
        align-self: stretch;
    /* Override bootstrap style */
    height: auto;
    height: initial;
}

 /* Vertical Progress */

 .widget-vprogress {
    height: 200px;
    width: 72px;
}

 .widget-vprogress .progress {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    width: 20px;
    margin-left: auto;
    margin-right: auto;
    margin-bottom: 0;
}

 /* Select Widget Styling */

 .widget-dropdown {
    height: 28px;
    width: 300px;
    line-height: 28px;
}

 .widget-dropdown > select {
    padding-right: 20px;
    border: 1px solid #9E9E9E;
    border-radius: 0;
    height: inherit;
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    outline: none !important;
    -webkit-box-shadow: none;
            box-shadow: none;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    vertical-align: top;
    padding-left: 8px;
	appearance: none;
	-webkit-appearance: none;
	-moz-appearance: none;
    background-repeat: no-repeat;
	background-size: 20px;
	background-position: right center;
    background-image: url("data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz4KPCEtLSBHZW5lcmF0b3I6IEFkb2JlIElsbHVzdHJhdG9yIDE5LjIuMSwgU1ZHIEV4cG9ydCBQbHVnLUluIC4gU1ZHIFZlcnNpb246IDYuMDAgQnVpbGQgMCkgIC0tPgo8c3ZnIHZlcnNpb249IjEuMSIgaWQ9IkxheWVyXzEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgeG1sbnM6eGxpbms9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkveGxpbmsiIHg9IjBweCIgeT0iMHB4IgoJIHZpZXdCb3g9IjAgMCAxOCAxOCIgc3R5bGU9ImVuYWJsZS1iYWNrZ3JvdW5kOm5ldyAwIDAgMTggMTg7IiB4bWw6c3BhY2U9InByZXNlcnZlIj4KPHN0eWxlIHR5cGU9InRleHQvY3NzIj4KCS5zdDB7ZmlsbDpub25lO30KPC9zdHlsZT4KPHBhdGggZD0iTTUuMiw1LjlMOSw5LjdsMy44LTMuOGwxLjIsMS4ybC00LjksNWwtNC45LTVMNS4yLDUuOXoiLz4KPHBhdGggY2xhc3M9InN0MCIgZD0iTTAtMC42aDE4djE4SDBWLTAuNnoiLz4KPC9zdmc+Cg");
}

 .widget-dropdown > select:focus {
    border-color: #64B5F6;
}

 .widget-dropdown > select:disabled {
    opacity: 0.6;
}

 /* To disable the dotted border in Firefox around select controls.
   See http://stackoverflow.com/a/18853002 */

 .widget-dropdown > select:-moz-focusring {
    color: transparent;
    text-shadow: 0 0 0 #000;
}

 /* Select and SelectMultiple */

 .widget-select {
    width: 300px;
    line-height: 28px;

    /* Because Firefox defines the baseline of a select as the bottom of the
    control, we align the entire control to the top and add padding to the
    select to get an approximate first line baseline alignment. */
    -webkit-box-align: start;
        -ms-flex-align: start;
            align-items: flex-start;
}

 .widget-select > select {
    border: 1px solid #9E9E9E;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    -webkit-box-flex: 1;
        -ms-flex: 1 1 148px;
            flex: 1 1 148px;
    outline: none !important;
    overflow: auto;
    height: inherit;

    /* Because Firefox defines the baseline of a select as the bottom of the
    control, we align the entire control to the top and add padding to the
    select to get an approximate first line baseline alignment. */
    padding-top: 5px;
}

 .widget-select > select:focus {
    border-color: #64B5F6;
}

 .wiget-select > select > option {
    padding-left: 4px;
    line-height: 28px;
    /* line-height doesn't work on some browsers for select options */
    padding-top: calc(28px - var(--jp-widgets-font-size) / 2);
    padding-bottom: calc(28px - var(--jp-widgets-font-size) / 2);
}

 /* Toggle Buttons Styling */

 .widget-toggle-buttons {
    line-height: 28px;
}

 .widget-toggle-buttons .widget-toggle-button {
    margin-left: 2px;
    margin-right: 2px;
}

 .widget-toggle-buttons .jupyter-button:disabled {
    opacity: 0.6;
}

 /* Radio Buttons Styling */

 .widget-radio {
    width: 300px;
    line-height: 28px;
}

 .widget-radio-box {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    margin-bottom: 8px;
}

 .widget-radio-box label {
    height: 20px;
    line-height: 20px;
    font-size: 13px;
}

 .widget-radio-box input {
    height: 20px;
    line-height: 20px;
    margin: 0 8px 0 1px;
    float: left;
}

 /* Color Picker Styling */

 .widget-colorpicker {
    width: 300px;
    height: 28px;
    line-height: 28px;
}

 .widget-colorpicker > .widget-colorpicker-input {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 1;
        flex-shrink: 1;
    min-width: 72px;
}

 .widget-colorpicker input[type="color"] {
    width: 28px;
    height: 28px;
    padding: 0 2px; /* make the color square actually square on Chrome on OS X */
    background: white;
    color: rgba(0, 0, 0, .8);
    border: 1px solid #9E9E9E;
    border-left: none;
    -webkit-box-flex: 0;
        -ms-flex-positive: 0;
            flex-grow: 0;
    -ms-flex-negative: 0;
        flex-shrink: 0;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    -ms-flex-item-align: stretch;
        align-self: stretch;
    outline: none !important;
}

 .widget-colorpicker.concise input[type="color"] {
    border-left: 1px solid #9E9E9E;
}

 .widget-colorpicker input[type="color"]:focus, .widget-colorpicker input[type="text"]:focus {
    border-color: #64B5F6;
}

 .widget-colorpicker input[type="text"] {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    outline: none !important;
    height: 28px;
    line-height: 28px;
    background: white;
    color: rgba(0, 0, 0, .8);
    border: 1px solid #9E9E9E;
    font-size: 13px;
    padding: 4px 8px;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    -ms-flex-negative: 1;
        flex-shrink: 1;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
}

 .widget-colorpicker input[type="text"]:disabled {
    opacity: 0.6;
}

 /* Date Picker Styling */

 .widget-datepicker {
    width: 300px;
    height: 28px;
    line-height: 28px;
}

 .widget-datepicker input[type="date"] {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 1;
        flex-shrink: 1;
    min-width: 0; /* This makes it possible for the flexbox to shrink this input */
    outline: none !important;
    height: 28px;
    border: 1px solid #9E9E9E;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    font-size: 13px;
    padding: 4px 8px;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
}

 .widget-datepicker input[type="date"]:focus {
    border-color: #64B5F6;
}

 .widget-datepicker input[type="date"]:invalid {
    border-color: #FF9800;
}

 .widget-datepicker input[type="date"]:disabled {
    opacity: 0.6;
}

 /* Play Widget */

 .widget-play {
    width: 148px;
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
}

 .widget-play .jupyter-button {
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    height: auto;
}

 .widget-play .jupyter-button:disabled {
    opacity: 0.6;
}

 /* Tab Widget */

 .jupyter-widgets.widget-tab {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
}

 .jupyter-widgets.widget-tab > .p-TabBar {
    /* Necessary so that a tab can be shifted down to overlay the border of the box below. */
    overflow-x: visible;
    overflow-y: visible;
}

 .jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {
    /* Make sure that the tab grows from bottom up */
    -webkit-box-align: end;
        -ms-flex-align: end;
            align-items: flex-end;
    min-width: 0;
    min-height: 0;
}

 .jupyter-widgets.widget-tab > .widget-tab-contents {
    width: 100%;
    -webkit-box-sizing: border-box;
            box-sizing: border-box;
    margin: 0;
    background: white;
    color: rgba(0, 0, 0, .8);
    border: 1px solid #9E9E9E;
    padding: 15px;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    overflow: auto;
}

 .jupyter-widgets.widget-tab > .p-TabBar {
    font: 13px Helvetica, Arial, sans-serif;
    min-height: 25px;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {
    -webkit-box-flex: 0;
        -ms-flex: 0 1 144px;
            flex: 0 1 144px;
    min-width: 35px;
    min-height: 25px;
    line-height: 24px;
    margin-left: -1px;
    padding: 0px 10px;
    background: #EEEEEE;
    color: rgba(0, 0, 0, .5);
    border: 1px solid #9E9E9E;
    border-bottom: none;
    position: relative;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current {
    color: rgba(0, 0, 0, 1.0);
    /* We want the background to match the tab content background */
    background: white;
    min-height: 26px;
    -webkit-transform: translateY(1px);
            transform: translateY(1px);
    overflow: visible;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current:before {
    position: absolute;
    top: -1px;
    left: -1px;
    content: '';
    height: 2px;
    width: calc(100% + 2px);
    background: #2196F3;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:first-child {
    margin-left: 0;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:hover:not(.p-mod-current) {
    background: white;
    color: rgba(0, 0, 0, .8);
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon {
    margin-left: 4px;
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon:before {
    font-family: FontAwesome;
    content: '\f00d'; /* close */
}

 .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,
.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel,
.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {
    line-height: 24px;
}

 /* Accordion Widget */

 .p-Collapse {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
}

 .p-Collapse-header {
    padding: 4px;
    cursor: pointer;
    color: rgba(0, 0, 0, .5);
    background-color: #EEEEEE;
    border: 1px solid #9E9E9E;
    padding: 10px 15px;
    font-weight: bold;
}

 .p-Collapse-header:hover {
    background-color: white;
    color: rgba(0, 0, 0, .8);
}

 .p-Collapse-open > .p-Collapse-header {
    background-color: white;
    color: rgba(0, 0, 0, 1.0);
    cursor: default;
    border-bottom: none;
}

 .p-Collapse .p-Collapse-header::before {
    content: '\f0da\00A0';  /* caret-right, non-breaking space */
    display: inline-block;
    font: normal normal normal 14px/1 FontAwesome;
    font-size: inherit;
    text-rendering: auto;
    -webkit-font-smoothing: antialiased;
    -moz-osx-font-smoothing: grayscale;
}

 .p-Collapse-open > .p-Collapse-header::before {
    content: '\f0d7\00A0'; /* caret-down, non-breaking space */
}

 .p-Collapse-contents {
    padding: 15px;
    background-color: white;
    color: rgba(0, 0, 0, .8);
    border-left: 1px solid #9E9E9E;
    border-right: 1px solid #9E9E9E;
    border-bottom: 1px solid #9E9E9E;
    overflow: auto;
}

 .p-Accordion {
    display: -webkit-box;
    display: -ms-flexbox;
    display: flex;
    -webkit-box-orient: vertical;
    -webkit-box-direction: normal;
        -ms-flex-direction: column;
            flex-direction: column;
    -webkit-box-align: stretch;
        -ms-flex-align: stretch;
            align-items: stretch;
}

 .p-Accordion .p-Collapse {
    margin-bottom: 0;
}

 .p-Accordion .p-Collapse + .p-Collapse {
    margin-top: 4px;
}

 /* HTML widget */

 .widget-html, .widget-htmlmath {
    font-size: 13px;
}

 .widget-html > .widget-html-content, .widget-htmlmath > .widget-html-content {
    /* Fill out the area in the HTML widget */
    -ms-flex-item-align: stretch;
        align-self: stretch;
    -webkit-box-flex: 1;
        -ms-flex-positive: 1;
            flex-grow: 1;
    -ms-flex-negative: 1;
        flex-shrink: 1;
    /* Makes sure the baseline is still aligned with other elements */
    line-height: 28px;
    /* Make it possible to have absolutely-positioned elements in the html */
    position: relative;
}

/*# sourceMappingURL=data:application/json;base64,{"version":3,"sources":["../node_modules/@jupyter-widgets/controls/css/widgets.css","../node_modules/@jupyter-widgets/controls/css/labvariables.css","../node_modules/@jupyter-widgets/controls/css/materialcolors.css","../node_modules/@jupyter-widgets/controls/css/widgets-base.css","../node_modules/@jupyter-widgets/controls/css/phosphor.css"],"names":[],"mappings":"AAAA;;GAEG;;CAEF;;kCAEiC;;CCNlC;;;+EAG+E;;CAE/E;;;;EAIE;;CCTF;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;;CDhBH;;;;;;;;;;;;;;;;;;;EAmBE;;CAGF;;GAEG;;CACF,yDAAyD;;CAC1D,yEAAyE;;CAEzE;;GAEG;;CAOH;;EAEE;;;KAGG;;EAQH;;;;IAIE,CAIwB,oBAAoB,CAGhB,0CAA0C;;EAGxE;;IAEE;;EAOF;;KAEG;;EAOH;;;IAGE,CAWwB,oBAAoB;;;EAU9C;;;;IAIE;;EAOF,kBAAkB;;EAYlB,+CAA+C;;EAsB/C,0BAA0B;EAa1B;4EAC0E;EAE1E;wEACsE;;EAGtE,8BAA8B;;EAK9B,6BAA6B;;EAI7B,6BAA6B;CAQ9B;;CEzMD;;GAEG;;CAEH;;;;GAIG;;CCRH;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;EA8BE;;CAEF;;;GAGG;;CAEH;EACE,qBAAc;EAAd,qBAAc;EAAd,cAAc;EACd,0BAA0B;EAC1B,uBAAuB;EACvB,sBAAsB;EACtB,kBAAkB;CACnB;;CAGD;EACE,+BAAoB;EAApB,8BAAoB;MAApB,wBAAoB;UAApB,oBAAoB;CACrB;;CAGD;EACE,6BAAuB;EAAvB,8BAAuB;MAAvB,2BAAuB;UAAvB,uBAAuB;CACxB;;CAGD;EACE,UAAU;EACV,WAAW;EACX,qBAAc;EAAd,qBAAc;EAAd,cAAc;EACd,oBAAe;MAAf,mBAAe;UAAf,eAAe;EACf,sBAAsB;CACvB;;CAGD;EACE,+BAAoB;EAApB,8BAAoB;MAApB,wBAAoB;UAApB,oBAAoB;CACrB;;CAGD;EACE,6BAAuB;EAAvB,8BAAuB;MAAvB,2BAAuB;UAAvB,uBAAuB;CACxB;;CAGD;EACE,qBAAc;EAAd,qBAAc;EAAd,cAAc;EACd,+BAAoB;EAApB,8BAAoB;MAApB,wBAAoB;UAApB,oBAAoB;EACpB,+BAAuB;UAAvB,uBAAuB;EACvB,iBAAiB;CAClB;;CAGD;;EAEE,oBAAe;MAAf,mBAAe;UAAf,eAAe;CAChB;;CAGD;EACE,oBAAe;MAAf,mBAAe;UAAf,eAAe;EACf,iBAAiB;EACjB,oBAAoB;CACrB;;CAGD;EACE,yBAAyB;CAC1B;;CAGD;EACE,mBAAmB;CACpB;;CAGD;EACE,QAAQ;EACR,oCAA4B;EAA5B,4BAA4B;CAC7B;;CAGD;EACE,OAAO;EACP,mCAA2B;EAA3B,2BAA2B;CAC5B;;CAGD;EACE,yBAAiB;EAAjB,iBAAiB;CAClB;;CAED,oBAAoB;;CD9GpB,QAUqC,oCAAoC;;IA2BrE,+BAA+B;CAIlC;;CAED;IACI,YAAiC;IACjC,+BAAuB;YAAvB,uBAAuB;IACvB,aAA+B;IAC/B,kBAAkB;CACrB;;CAED;IACI,kBAA6C;IAC7C,aAAwC;CAC3C;;CAED;IACI,eAAe;IACf,gBAAgB;CACnB;;CAED,mBAAmB;;CAEnB;IACI,wBAAwB;IACxB,+BAAuB;YAAvB,uBAAuB;IACvB,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,+BAAoB;IAApB,8BAAoB;QAApB,wBAAoB;YAApB,oBAAoB;IACpB,4BAAsB;QAAtB,yBAAsB;YAAtB,sBAAsB;CACzB;;CAED;IACI,sBAAsB;IACtB,+BAAuB;YAAvB,uBAAuB;IACvB,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,0BAAoB;QAApB,uBAAoB;YAApB,oBAAoB;CACvB;;CAED;IACI,+BAAuB;YAAvB,uBAAuB;IACvB,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,UAAU;IACV,eAAe;CAClB;;CAED;IACI,+BAAuB;YAAvB,uBAAuB;IACvB,cAAc;IACd,UAAU;IACV,eAAe;CAClB;;CAED;IACI,+BAAoB;IAApB,8BAAoB;QAApB,wBAAoB;YAApB,oBAAoB;CACvB;;CAED;IACI,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;CAC1B;;CAED,4BAA4B;;CAE5B;IACI,mBAAmB;IACnB,oBAAoB;IACpB,iBAAiB;IACjB,oBAAoB;IACpB,sBAAsB;IACtB,oBAAoB;IACpB,iBAAiB;IACjB,wBAAwB;IACxB,mBAAmB;IACnB,gBAAuC;IACvC,gBAAgB;;IAEhB,aAAwC;IACxC,kBAAkB;IAClB,kBAA6C;IAC7C,yBAAiB;YAAjB,iBAAiB;;IAEjB,yBAAgC;IAChC,0BAA0C;IAC1C,sBAAsC;IACtC,aAAa;CAChB;;CAED;IACI,kBAA8C;IAC9C,qBAAqB;CACxB;;CAED;IACI,iBAAiB,CAAC,sBAAsB;CAC3C;;CAED;IACI,aAA4C;CAC/C;;CAED;IACI,gBAAgB;CACnB;;CAED;IACI,wBAAwB;IACxB;;+CAE+E;YAF/E;;+CAE+E;CAClF;;CAED;IACI,wBAAwB;IACxB;;iDAE6E;YAF7E;;iDAE6E;IAC7E,yBAAgC;IAChC,0BAA0C;CAC7C;;CAED;IACI,2BAA8D;CACjE;;CAED,8BAA8B;;CAE9B;IACI,gCAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED,8BAA8B;;CAE9B;IACI,gCAAwC;IACxC,0BAA2C;CAC9C;;CAED;IACI,8BAAwC;IACxC,0BAA2C;EAC7C;;CAEF;IACI,8BAAwC;IACxC,0BAA2C;EAC7C;;CAED,2BAA2B;;CAE5B;IACI,gCAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED,8BAA8B;;CAE9B;IACI,gCAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED;IACI,8BAAwC;IACxC,0BAAwC;CAC3C;;CAED,6BAA6B;;CAE7B;IACI,gCAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED;IACI,8BAAwC;IACxC,0BAAyC;CAC5C;;CAED,kBAAkB;;CAElB;IACI,aAA4C;CAC/C;;CAED,0BAA0B;;CAE1B,kCAAkC;;CAClC;IACI,iBAAuB;IAAvB,uBAAuB;CAC1B;;CAED;IACI,iBAAiB;IACjB,aAAqC;IACrC,gBAAuC;IACvC,iBAAiB;IACjB,wBAAwB;IACxB,oBAAoB;IACpB,kBAA6C;CAChD;;CAED;IACI,WAAW;IACX,aAAqC;IACrC,gBAAuC;IACvC,iBAAiB;IACjB,wBAAwB;IACxB,oBAAoB;IACpB,kBAA6C;CAChD;;CAED;IACI,6BAA6B;IAC7B,aAAqC;IACrC,kBAAkB;IAClB,kBAA0D;IAC1D,YAA4C;IAC5C,qBAAe;QAAf,eAAe;CAClB;;CAED;IACI,2BAA2B;IAC3B,aAAqC;IACrC,mBAAmB;IACnB,kBAA6C;CAChD;;CAED,4BAA4B;;CAE5B;IACI,aAAuC;IACvC,gBAAuC;IACvC,aAAwC;IACxC,kBAA6C;IAC7C,iBAAiB;IACjB,oBAAoB;IACpB,mBAAmB;CACtB;;CAED;IACI,yBAAyB;;IAEzB;;;;OAIG;IACH;;uDAEoD;;IAMpD;;+CAE4C;CAC/C;;CAED;IACI,wBAAwB;IACxB,mBAAmB;IACnB,iBAAgD;IAChD,gBAA+C;IAC/C,iBAA6C;CAChD;;CAED;IACI,sBAAsB;IACtB,gBAA4C;IAC5C,2BAA2B;IAC3B,eAAe;CAClB;;CAED,6BAA6B;;CAE7B;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,wBAAgE;IAChE,kBAA6C;IAC7C,iBAAiB;IACjB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,4BAAmB;QAAnB,mBAAmB;CACtB;;CAED,0BAA0B;;CAE1B;IACI,aAAwC;IACxC,kBAA6C;IAC7C,aAA4C;IAC5C,gBAAuC;CAC1C;;CAED;IACI,kBAA6C;IAC7C,kBAA8C;IAC9C,iBAA6C;;IAE7C,0JAA0J;IAC1J,sBAAsB;IACtB,8CAA8C;IAC9C,mBAAmB;IACnB,qBAAqB;IACrB,oCAAoC;IACpC,mCAAmC;CACtC;;CAED;IACI,iBAAiB;IACjB,aAAa;CAChB;;CAED;IACI,iBAAiB;IACjB,WAAW;CACd;;CAED;IACI,cAAc;CACjB;;CAED,qCAAqC;;CAErC;IACI,aAAsC;CACzC;;CAED;IACI,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,aAA4C;CAC/C;;CAED;IACI,+BAAuB;YAAvB,uBAAuB;IACvB,0BAAwF;IACxF,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,iBAAsF;IACtF,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,aAAa,CAAC,iEAAiE;IAC/E,qBAAe;QAAf,eAAe;IACf,yBAAyB;CAC5B;;CAED;IACI,gBAAgB;IAChB,eAAe;CAClB;;CAED;IACI,sBAAyD;CAC5D;;CAED,mBAAmB;;CAEnB;IACI,kBAAkB;IAClB,0BAA4E;IAC5E,oBAAoC;IACpC,+BAAuB;YAAvB,uBAAuB;IACvB,mBAAmB;IACnB,mBAAmB;CACtB;;CAED;IACI,mBAAmB;IACnB,yBAAyB,CAAC,oDAAoD;IAC9E,mBAAmB;IACnB,wBAAmE;IACnE,0BAAiG;IACjG,+BAAuB;YAAvB,uBAAuB;IACvB,WAAW;IACX,uBAAuB,CAAC,wBAAwB;CACnD;;CAED,wBAAwB;;CACxB;IACI,0BAA+D;IAC/D,0BAAiG;CACpG;;CAED;IACI,0BAA+D;IAC/D,sBAA2D;IAC3D,WAAW;IACX,8BAAsB;YAAtB,sBAAsB;CACzB;;CAED;IACI,iEAAiE;IACjE,mBAAmB;IACnB,oBAAyD;IACzD,WAAW;CACd;;CAED,8BAA8B;;CAE9B;IACI,YAA4C;IAC5C,aAA6C;IAC7C,iBAAgJ;IAChJ,kBAAqG;IACrG,mBAAmB;IACnB,OAAO;CACV;;CAED;IACI,YAA4C;IAC5C,aAA6C;IAC7C,oBAAuG;IACvG,kBAAiJ;IACjJ,mBAAmB;IACnB,QAAQ;CACX;;CAED;IACI,YAA6D;IAC7D,iBAAyJ;CAC5J;;CAED;IACI,WAA4D;IAC5D,kBAA0J;CAC7J;;CAED,uBAAuB;;CAEvB;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;;IAE7C;;oDAEgD;IAChD,0BAAoB;QAApB,uBAAoB;YAApB,oBAAoB;CACvB;;CAED;IACI,kBAAkB;CACrB;;CAED;IACI,aAAwC;IACxC,iBAAwG;IACxG,kBAAyG;IACzG,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;CAClD;;CAED;IACI,gCAAgC;IAChC,YAAiD;IACjD,iBAAmG;IACnG,YAAY;CACf;;CAED,qBAAqB;;CAErB;IACI,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,qBAAqB;IACrB,cAA0C;IAC1C,YAA2C;CAC9C;;CAED;IACI,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;IAC/C,kBAAkB;IAClB,mBAAmB;IACnB,mBAA0G;IAC1G,gBAAuG;IACvG,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,gCAAgC;IAChC,WAAgD;IAChD,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,kBAAkB;IAClB,mBAAmB;CACtB;;CAED,6BAA6B;;CAE7B;IACI,yBAAyB;IAIzB,iBAAiB;CACpB;;CAED;IACI,aAAwC;CAC3C;;CAED;IACI,0BAAyC;CAC5C;;CAED;IACI,0BAA2C;CAC9C;;CAED;IACI,0BAAwC;CAC3C;;CAED;IACI,0BAAwC;CAC3C;;CAED;IACI,0BAAyC;CAC5C;;CAED;IACI,0BAA0C;IAC1C,aAAa;IACb,yBAAiB;YAAjB,iBAAiB;CACpB;;CAED,yBAAyB;;CAEzB;IACI,kBAAkB;IAClB,aAAwC;IACxC,kBAA6C;IAC7C,aAAsC;IACtC,0BAAoB;QAApB,uBAAoB;YAApB,oBAAoB;;CAEvB;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,gBAA4C;IAC5C,mBAA+C;IAC/C,6BAAoB;QAApB,oBAAoB;IACpB,8BAA8B;IAC9B,aAAgB;IAAhB,gBAAgB;CACnB;;CAED,uBAAuB;;CAEvB;IACI,cAA0C;IAC1C,YAA2C;CAC9C;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,YAA4C;IAC5C,kBAAkB;IAClB,mBAAmB;IACnB,iBAAiB;CACpB;;CAED,2BAA2B;;CAE3B;IACI,aAAwC;IACxC,aAAsC;IACtC,kBAA6C;CAChD;;CAED;IACI,oBAAoB;IACpB,0BAAwF;IACxF,iBAAiB;IACjB,gBAAgB;IAChB,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;IAC/C,aAAa,CAAC,iEAAiE;IAC/E,+BAAuB;YAAvB,uBAAuB;IACvB,yBAAyB;IACzB,yBAAiB;YAAjB,iBAAiB;IACjB,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,oBAAoB;IACpB,kBAAyD;CAC5D,iBAAiB;CACjB,yBAAyB;CACzB,sBAAsB;IACnB,6BAA6B;CAChC,sBAAsB;CACtB,kCAAkC;IAC/B,kuBAAmD;CACtD;;CACD;IACI,sBAAyD;CAC5D;;CAED;IACI,aAA4C;CAC/C;;CAED;6CAC6C;;CAC7C;IACI,mBAAmB;IACnB,wBAAwB;CAC3B;;CAED,+BAA+B;;CAE/B;IACI,aAAsC;IACtC,kBAA6C;;IAE7C;;kEAE8D;IAC9D,yBAAwB;QAAxB,sBAAwB;YAAxB,wBAAwB;CAC3B;;CAED;IACI,0BAAwF;IACxF,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,oBAA+C;QAA/C,oBAA+C;YAA/C,gBAA+C;IAC/C,yBAAyB;IACzB,eAAe;IACf,gBAAgB;;IAEhB;;kEAE8D;IAC9D,iBAAiB;CACpB;;CAED;IACI,sBAAyD;CAC5D;;CAED;IACI,kBAA8C;IAC9C,kBAA6C;IAC7C,kEAAkE;IAClE,0DAAiF;IACjF,6DAAoF;CACvF;;CAID,4BAA4B;;CAE5B;IACI,kBAA6C;CAChD;;CAED;IACI,iBAAsC;IACtC,kBAAuC;CAC1C;;CAED;IACI,aAA4C;CAC/C;;CAED,2BAA2B;;CAE3B;IACI,aAAsC;IACtC,kBAA6C;CAChD;;CAED;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;IACrB,+BAAuB;YAAvB,uBAAuB;IACvB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,mBAA8D;CACjE;;CAED;IACI,aAA4C;IAC5C,kBAAiD;IACjD,gBAAuC;CAC1C;;CAED;IACI,aAA4C;IAC5C,kBAAiD;IACjD,oBAA4D;IAC5D,YAAY;CACf;;CAED,0BAA0B;;CAE1B;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,gBAA+C;CAClD;;CAED;IACI,YAAuC;IACvC,aAAwC;IACxC,eAAe,CAAC,6DAA6D;IAC7E,kBAAqD;IACrD,yBAAqC;IACrC,0BAAwF;IACxF,kBAAkB;IAClB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,+BAAuB;YAAvB,uBAAuB;IACvB,6BAAoB;QAApB,oBAAoB;IACpB,yBAAyB;CAC5B;;CAED;IACI,+BAA6F;CAChG;;CAED;IACI,sBAAyD;CAC5D;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,yBAAyB;IACzB,aAAwC;IACxC,kBAA6C;IAC7C,kBAAqD;IACrD,yBAAqC;IACrC,0BAAwF;IACxF,gBAAuC;IACvC,iBAAsF;IACtF,aAAa,CAAC,iEAAiE;IAC/E,qBAAe;QAAf,eAAe;IACf,+BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,aAA4C;CAC/C;;CAED,yBAAyB;;CAEzB;IACI,aAAsC;IACtC,aAAwC;IACxC,kBAA6C;CAChD;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,aAAa,CAAC,iEAAiE;IAC/E,yBAAyB;IACzB,aAAwC;IACxC,0BAAwF;IACxF,wBAA2D;IAC3D,yBAAqC;IACrC,gBAAuC;IACvC,iBAAsF;IACtF,+BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,sBAAyD;CAC5D;;CAED;IACI,sBAAoC;CACvC;;CAED;IACI,aAA4C;CAC/C;;CAED,iBAAiB;;CAEjB;IACI,aAA4C;IAC5C,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;CACxB;;CAED;IACI,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,aAAa;CAChB;;CAED;IACI,aAA4C;CAC/C;;CAED,gBAAgB;;CAEhB;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;CAC1B;;CAED;IACI,yFAAyF;IACzF,oBAAoB;IACpB,oBAAoB;CACvB;;CAED;IACI,iDAAiD;IACjD,uBAAsB;QAAtB,oBAAsB;YAAtB,sBAAsB;IACtB,aAAa;IACb,cAAc;CACjB;;CAED;IACI,YAAY;IACZ,+BAAuB;YAAvB,uBAAuB;IACvB,UAAU;IACV,kBAAoC;IACpC,yBAAgC;IAChC,0BAA6D;IAC7D,cAA6C;IAC7C,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,eAAe;CAClB;;CAED;IACI,wCAA+D;IAC/D,iBAAmF;CACtF;;CAED;IACI,oBAAiD;QAAjD,oBAAiD;YAAjD,gBAAiD;IACjD,gBAAgB;IAChB,iBAAmF;IACnF,kBAAqD;IACrD,kBAA+C;IAC/C,kBAAkB;IAClB,oBAAoC;IACpC,yBAAgC;IAChC,0BAA6D;IAC7D,oBAAoB;IACpB,mBAAmB;CACtB;;CAED;IACI,0BAAgC;IAChC,gEAAgE;IAChE,kBAAoC;IACpC,iBAAuF;IACvF,mCAA8C;YAA9C,2BAA8C;IAC9C,kBAAkB;CACrB;;CAED;IACI,mBAAmB;IACnB,UAAuC;IACvC,WAAwC;IACxC,YAAY;IACZ,YAAoD;IACpD,wBAA+C;IAC/C,oBAAmC;CACtC;;CAED;IACI,eAAe;CAClB;;CAED;IACI,kBAAoC;IACpC,yBAAgC;CACnC;;CAED;IACI,iBAAiB;CACpB;;CAED;IACI,yBAAyB;IACzB,iBAAiB,CAAC,WAAW;CAChC;;CAED;;;IAGI,kBAAqD;CACxD;;CAED,sBAAsB;;CAEtB;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;CACxB;;CAED;IACI,aAAyC;IACzC,gBAAgB;IAChB,yBAAgC;IAChC,0BAA0C;IAC1C,0BAAqE;IACrE,mBAA+F;IAC/F,kBAAkB;CACrB;;CAED;IACI,wBAA0C;IAC1C,yBAAgC;CACnC;;CAED;IACI,wBAA0C;IAC1C,0BAAgC;IAChC,gBAAgB;IAChB,oBAAoB;CACvB;;CAED;IACI,sBAAsB,EAAE,qCAAqC;IAC7D,sBAAsB;IACtB,8CAA8C;IAC9C,mBAAmB;IACnB,qBAAqB;IACrB,oCAAoC;IACpC,mCAAmC;CACtC;;CAED;IACI,sBAAsB,CAAC,oCAAoC;CAC9D;;CAED;IACI,cAA6C;IAC7C,wBAA0C;IAC1C,yBAAgC;IAChC,+BAA0E;IAC1E,gCAA2E;IAC3E,iCAA4E;IAC5E,eAAe;CAClB;;CAED;IACI,qBAAc;IAAd,qBAAc;IAAd,cAAc;IACd,6BAAuB;IAAvB,8BAAuB;QAAvB,2BAAuB;YAAvB,uBAAuB;IACvB,2BAAqB;QAArB,wBAAqB;YAArB,qBAAqB;CACxB;;CAED;IACI,iBAAiB;CACpB;;CAED;IACI,gBAAgB;CACnB;;CAID,iBAAiB;;CAEjB;IACI,gBAAuC;CAC1C;;CAED;IACI,0CAA0C;IAC1C,6BAAoB;QAApB,oBAAoB;IACpB,oBAAa;QAAb,qBAAa;YAAb,aAAa;IACb,qBAAe;QAAf,eAAe;IACf,kEAAkE;IAClE,kBAA6C;IAC7C,yEAAyE;IACzE,mBAAmB;CACtB","file":"controls.css","sourcesContent":["/* Copyright (c) Jupyter Development Team.\n * Distributed under the terms of the Modified BSD License.\n */\n\n /* We import all of these together in a single css file because the Webpack\nloader sees only one file at a time. This allows postcss to see the variable\ndefinitions when they are used. */\n\n@import \"./labvariables.css\";\n@import \"./widgets-base.css\";\n","/*-----------------------------------------------------------------------------\n| Copyright (c) Jupyter Development Team.\n| Distributed under the terms of the Modified BSD License.\n|----------------------------------------------------------------------------*/\n\n/*\nThis file is copied from the JupyterLab project to define default styling for\nwhen the widget styling is compiled down to eliminate CSS variables. We make one\nchange - we comment out the font import below.\n*/\n\n@import \"./materialcolors.css\";\n\n/*\nThe following CSS variables define the main, public API for styling JupyterLab.\nThese variables should be used by all plugins wherever possible. In other\nwords, plugins should not define custom colors, sizes, etc unless absolutely\nnecessary. This enables users to change the visual theme of JupyterLab\nby changing these variables.\n\nMany variables appear in an ordered sequence (0,1,2,3). These sequences\nare designed to work well together, so for example, `--jp-border-color1` should\nbe used with `--jp-layout-color1`. The numbers have the following meanings:\n\n* 0: super-primary, reserved for special emphasis\n* 1: primary, most important under normal situations\n* 2: secondary, next most important under normal situations\n* 3: tertiary, next most important under normal situations\n\nThroughout JupyterLab, we are mostly following principles from Google's\nMaterial Design when selecting colors. We are not, however, following\nall of MD as it is not optimized for dense, information rich UIs.\n*/\n\n\n/*\n * Optional monospace font for input/output prompt.\n */\n /* Commented out in ipywidgets since we don't need it. */\n/* @import url('https://fonts.googleapis.com/css?family=Roboto+Mono'); */\n\n/*\n * Added for compabitility with output area\n */\n:root {\n  --jp-icon-search: none;\n  --jp-ui-select-caret: none;\n}\n\n\n:root {\n\n  /* Borders\n\n  The following variables, specify the visual styling of borders in JupyterLab.\n   */\n\n  --jp-border-width: 1px;\n  --jp-border-color0: var(--md-grey-700);\n  --jp-border-color1: var(--md-grey-500);\n  --jp-border-color2: var(--md-grey-300);\n  --jp-border-color3: var(--md-grey-100);\n\n  /* UI Fonts\n\n  The UI font CSS variables are used for the typography all of the JupyterLab\n  user interface elements that are not directly user generated content.\n  */\n\n  --jp-ui-font-scale-factor: 1.2;\n  --jp-ui-font-size0: calc(var(--jp-ui-font-size1)/var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size1: 13px; /* Base font size */\n  --jp-ui-font-size2: calc(var(--jp-ui-font-size1)*var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size3: calc(var(--jp-ui-font-size2)*var(--jp-ui-font-scale-factor));\n  --jp-ui-icon-font-size: 14px; /* Ensures px perfect FontAwesome icons */\n  --jp-ui-font-family: \"Helvetica Neue\", Helvetica, Arial, sans-serif;\n\n  /* Use these font colors against the corresponding main layout colors.\n     In a light theme, these go from dark to light.\n  */\n\n  --jp-ui-font-color0: rgba(0,0,0,1.0);\n  --jp-ui-font-color1: rgba(0,0,0,0.8);\n  --jp-ui-font-color2: rgba(0,0,0,0.5);\n  --jp-ui-font-color3: rgba(0,0,0,0.3);\n\n  /* Use these against the brand/accent/warn/error colors.\n     These will typically go from light to darker, in both a dark and light theme\n   */\n\n  --jp-inverse-ui-font-color0: rgba(255,255,255,1);\n  --jp-inverse-ui-font-color1: rgba(255,255,255,1.0);\n  --jp-inverse-ui-font-color2: rgba(255,255,255,0.7);\n  --jp-inverse-ui-font-color3: rgba(255,255,255,0.5);\n\n  /* Content Fonts\n\n  Content font variables are used for typography of user generated content.\n  */\n\n  --jp-content-font-size: 13px;\n  --jp-content-line-height: 1.5;\n  --jp-content-font-color0: black;\n  --jp-content-font-color1: black;\n  --jp-content-font-color2: var(--md-grey-700);\n  --jp-content-font-color3: var(--md-grey-500);\n\n  --jp-ui-font-scale-factor: 1.2;\n  --jp-ui-font-size0: calc(var(--jp-ui-font-size1)/var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size1: 13px; /* Base font size */\n  --jp-ui-font-size2: calc(var(--jp-ui-font-size1)*var(--jp-ui-font-scale-factor));\n  --jp-ui-font-size3: calc(var(--jp-ui-font-size2)*var(--jp-ui-font-scale-factor));\n\n  --jp-code-font-size: 13px;\n  --jp-code-line-height: 1.307;\n  --jp-code-padding: 5px;\n  --jp-code-font-family: monospace;\n\n\n  /* Layout\n\n  The following are the main layout colors use in JupyterLab. In a light\n  theme these would go from light to dark.\n  */\n\n  --jp-layout-color0: white;\n  --jp-layout-color1: white;\n  --jp-layout-color2: var(--md-grey-200);\n  --jp-layout-color3: var(--md-grey-400);\n\n  /* Brand/accent */\n\n  --jp-brand-color0: var(--md-blue-700);\n  --jp-brand-color1: var(--md-blue-500);\n  --jp-brand-color2: var(--md-blue-300);\n  --jp-brand-color3: var(--md-blue-100);\n\n  --jp-accent-color0: var(--md-green-700);\n  --jp-accent-color1: var(--md-green-500);\n  --jp-accent-color2: var(--md-green-300);\n  --jp-accent-color3: var(--md-green-100);\n\n  /* State colors (warn, error, success, info) */\n\n  --jp-warn-color0: var(--md-orange-700);\n  --jp-warn-color1: var(--md-orange-500);\n  --jp-warn-color2: var(--md-orange-300);\n  --jp-warn-color3: var(--md-orange-100);\n\n  --jp-error-color0: var(--md-red-700);\n  --jp-error-color1: var(--md-red-500);\n  --jp-error-color2: var(--md-red-300);\n  --jp-error-color3: var(--md-red-100);\n\n  --jp-success-color0: var(--md-green-700);\n  --jp-success-color1: var(--md-green-500);\n  --jp-success-color2: var(--md-green-300);\n  --jp-success-color3: var(--md-green-100);\n\n  --jp-info-color0: var(--md-cyan-700);\n  --jp-info-color1: var(--md-cyan-500);\n  --jp-info-color2: var(--md-cyan-300);\n  --jp-info-color3: var(--md-cyan-100);\n\n  /* Cell specific styles */\n\n  --jp-cell-padding: 5px;\n  --jp-cell-editor-background: #f7f7f7;\n  --jp-cell-editor-border-color: #cfcfcf;\n  --jp-cell-editor-background-edit: var(--jp-ui-layout-color1);\n  --jp-cell-editor-border-color-edit: var(--jp-brand-color1);\n  --jp-cell-prompt-width: 100px;\n  --jp-cell-prompt-font-family: 'Roboto Mono', monospace;\n  --jp-cell-prompt-letter-spacing: 0px;\n  --jp-cell-prompt-opacity: 1.0;\n  --jp-cell-prompt-opacity-not-active: 0.4;\n  --jp-cell-prompt-font-color-not-active: var(--md-grey-700);\n  /* A custom blend of MD grey and blue 600\n   * See https://meyerweb.com/eric/tools/color-blend/#546E7A:1E88E5:5:hex */\n  --jp-cell-inprompt-font-color: #307FC1;\n  /* A custom blend of MD grey and orange 600\n   * https://meyerweb.com/eric/tools/color-blend/#546E7A:F4511E:5:hex */\n  --jp-cell-outprompt-font-color: #BF5B3D;\n\n  /* Notebook specific styles */\n\n  --jp-notebook-padding: 10px;\n  --jp-notebook-scroll-padding: 100px;\n\n  /* Console specific styles */\n\n  --jp-console-background: var(--md-grey-100);\n\n  /* Toolbar specific styles */\n\n  --jp-toolbar-border-color: var(--md-grey-400);\n  --jp-toolbar-micro-height: 8px;\n  --jp-toolbar-background: var(--jp-layout-color0);\n  --jp-toolbar-box-shadow: 0px 0px 2px 0px rgba(0,0,0,0.24);\n  --jp-toolbar-header-margin: 4px 4px 0px 4px;\n  --jp-toolbar-active-background: var(--md-grey-300);\n}\n","/**\n * The material design colors are adapted from google-material-color v1.2.6\n * https://github.com/danlevan/google-material-color\n * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/dist/palette.var.css\n *\n * The license for the material design color CSS variables is as follows (see\n * https://github.com/danlevan/google-material-color/blob/f67ca5f4028b2f1b34862f64b0ca67323f91b088/LICENSE)\n *\n * The MIT License (MIT)\n *\n * Copyright (c) 2014 Dan Le Van\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to deal\n * in the Software without restriction, including without limitation the rights\n * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n * copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n * SOFTWARE.\n */\n:root {\n  --md-red-50: #FFEBEE;\n  --md-red-100: #FFCDD2;\n  --md-red-200: #EF9A9A;\n  --md-red-300: #E57373;\n  --md-red-400: #EF5350;\n  --md-red-500: #F44336;\n  --md-red-600: #E53935;\n  --md-red-700: #D32F2F;\n  --md-red-800: #C62828;\n  --md-red-900: #B71C1C;\n  --md-red-A100: #FF8A80;\n  --md-red-A200: #FF5252;\n  --md-red-A400: #FF1744;\n  --md-red-A700: #D50000;\n\n  --md-pink-50: #FCE4EC;\n  --md-pink-100: #F8BBD0;\n  --md-pink-200: #F48FB1;\n  --md-pink-300: #F06292;\n  --md-pink-400: #EC407A;\n  --md-pink-500: #E91E63;\n  --md-pink-600: #D81B60;\n  --md-pink-700: #C2185B;\n  --md-pink-800: #AD1457;\n  --md-pink-900: #880E4F;\n  --md-pink-A100: #FF80AB;\n  --md-pink-A200: #FF4081;\n  --md-pink-A400: #F50057;\n  --md-pink-A700: #C51162;\n\n  --md-purple-50: #F3E5F5;\n  --md-purple-100: #E1BEE7;\n  --md-purple-200: #CE93D8;\n  --md-purple-300: #BA68C8;\n  --md-purple-400: #AB47BC;\n  --md-purple-500: #9C27B0;\n  --md-purple-600: #8E24AA;\n  --md-purple-700: #7B1FA2;\n  --md-purple-800: #6A1B9A;\n  --md-purple-900: #4A148C;\n  --md-purple-A100: #EA80FC;\n  --md-purple-A200: #E040FB;\n  --md-purple-A400: #D500F9;\n  --md-purple-A700: #AA00FF;\n\n  --md-deep-purple-50: #EDE7F6;\n  --md-deep-purple-100: #D1C4E9;\n  --md-deep-purple-200: #B39DDB;\n  --md-deep-purple-300: #9575CD;\n  --md-deep-purple-400: #7E57C2;\n  --md-deep-purple-500: #673AB7;\n  --md-deep-purple-600: #5E35B1;\n  --md-deep-purple-700: #512DA8;\n  --md-deep-purple-800: #4527A0;\n  --md-deep-purple-900: #311B92;\n  --md-deep-purple-A100: #B388FF;\n  --md-deep-purple-A200: #7C4DFF;\n  --md-deep-purple-A400: #651FFF;\n  --md-deep-purple-A700: #6200EA;\n\n  --md-indigo-50: #E8EAF6;\n  --md-indigo-100: #C5CAE9;\n  --md-indigo-200: #9FA8DA;\n  --md-indigo-300: #7986CB;\n  --md-indigo-400: #5C6BC0;\n  --md-indigo-500: #3F51B5;\n  --md-indigo-600: #3949AB;\n  --md-indigo-700: #303F9F;\n  --md-indigo-800: #283593;\n  --md-indigo-900: #1A237E;\n  --md-indigo-A100: #8C9EFF;\n  --md-indigo-A200: #536DFE;\n  --md-indigo-A400: #3D5AFE;\n  --md-indigo-A700: #304FFE;\n\n  --md-blue-50: #E3F2FD;\n  --md-blue-100: #BBDEFB;\n  --md-blue-200: #90CAF9;\n  --md-blue-300: #64B5F6;\n  --md-blue-400: #42A5F5;\n  --md-blue-500: #2196F3;\n  --md-blue-600: #1E88E5;\n  --md-blue-700: #1976D2;\n  --md-blue-800: #1565C0;\n  --md-blue-900: #0D47A1;\n  --md-blue-A100: #82B1FF;\n  --md-blue-A200: #448AFF;\n  --md-blue-A400: #2979FF;\n  --md-blue-A700: #2962FF;\n\n  --md-light-blue-50: #E1F5FE;\n  --md-light-blue-100: #B3E5FC;\n  --md-light-blue-200: #81D4FA;\n  --md-light-blue-300: #4FC3F7;\n  --md-light-blue-400: #29B6F6;\n  --md-light-blue-500: #03A9F4;\n  --md-light-blue-600: #039BE5;\n  --md-light-blue-700: #0288D1;\n  --md-light-blue-800: #0277BD;\n  --md-light-blue-900: #01579B;\n  --md-light-blue-A100: #80D8FF;\n  --md-light-blue-A200: #40C4FF;\n  --md-light-blue-A400: #00B0FF;\n  --md-light-blue-A700: #0091EA;\n\n  --md-cyan-50: #E0F7FA;\n  --md-cyan-100: #B2EBF2;\n  --md-cyan-200: #80DEEA;\n  --md-cyan-300: #4DD0E1;\n  --md-cyan-400: #26C6DA;\n  --md-cyan-500: #00BCD4;\n  --md-cyan-600: #00ACC1;\n  --md-cyan-700: #0097A7;\n  --md-cyan-800: #00838F;\n  --md-cyan-900: #006064;\n  --md-cyan-A100: #84FFFF;\n  --md-cyan-A200: #18FFFF;\n  --md-cyan-A400: #00E5FF;\n  --md-cyan-A700: #00B8D4;\n\n  --md-teal-50: #E0F2F1;\n  --md-teal-100: #B2DFDB;\n  --md-teal-200: #80CBC4;\n  --md-teal-300: #4DB6AC;\n  --md-teal-400: #26A69A;\n  --md-teal-500: #009688;\n  --md-teal-600: #00897B;\n  --md-teal-700: #00796B;\n  --md-teal-800: #00695C;\n  --md-teal-900: #004D40;\n  --md-teal-A100: #A7FFEB;\n  --md-teal-A200: #64FFDA;\n  --md-teal-A400: #1DE9B6;\n  --md-teal-A700: #00BFA5;\n\n  --md-green-50: #E8F5E9;\n  --md-green-100: #C8E6C9;\n  --md-green-200: #A5D6A7;\n  --md-green-300: #81C784;\n  --md-green-400: #66BB6A;\n  --md-green-500: #4CAF50;\n  --md-green-600: #43A047;\n  --md-green-700: #388E3C;\n  --md-green-800: #2E7D32;\n  --md-green-900: #1B5E20;\n  --md-green-A100: #B9F6CA;\n  --md-green-A200: #69F0AE;\n  --md-green-A400: #00E676;\n  --md-green-A700: #00C853;\n\n  --md-light-green-50: #F1F8E9;\n  --md-light-green-100: #DCEDC8;\n  --md-light-green-200: #C5E1A5;\n  --md-light-green-300: #AED581;\n  --md-light-green-400: #9CCC65;\n  --md-light-green-500: #8BC34A;\n  --md-light-green-600: #7CB342;\n  --md-light-green-700: #689F38;\n  --md-light-green-800: #558B2F;\n  --md-light-green-900: #33691E;\n  --md-light-green-A100: #CCFF90;\n  --md-light-green-A200: #B2FF59;\n  --md-light-green-A400: #76FF03;\n  --md-light-green-A700: #64DD17;\n\n  --md-lime-50: #F9FBE7;\n  --md-lime-100: #F0F4C3;\n  --md-lime-200: #E6EE9C;\n  --md-lime-300: #DCE775;\n  --md-lime-400: #D4E157;\n  --md-lime-500: #CDDC39;\n  --md-lime-600: #C0CA33;\n  --md-lime-700: #AFB42B;\n  --md-lime-800: #9E9D24;\n  --md-lime-900: #827717;\n  --md-lime-A100: #F4FF81;\n  --md-lime-A200: #EEFF41;\n  --md-lime-A400: #C6FF00;\n  --md-lime-A700: #AEEA00;\n\n  --md-yellow-50: #FFFDE7;\n  --md-yellow-100: #FFF9C4;\n  --md-yellow-200: #FFF59D;\n  --md-yellow-300: #FFF176;\n  --md-yellow-400: #FFEE58;\n  --md-yellow-500: #FFEB3B;\n  --md-yellow-600: #FDD835;\n  --md-yellow-700: #FBC02D;\n  --md-yellow-800: #F9A825;\n  --md-yellow-900: #F57F17;\n  --md-yellow-A100: #FFFF8D;\n  --md-yellow-A200: #FFFF00;\n  --md-yellow-A400: #FFEA00;\n  --md-yellow-A700: #FFD600;\n\n  --md-amber-50: #FFF8E1;\n  --md-amber-100: #FFECB3;\n  --md-amber-200: #FFE082;\n  --md-amber-300: #FFD54F;\n  --md-amber-400: #FFCA28;\n  --md-amber-500: #FFC107;\n  --md-amber-600: #FFB300;\n  --md-amber-700: #FFA000;\n  --md-amber-800: #FF8F00;\n  --md-amber-900: #FF6F00;\n  --md-amber-A100: #FFE57F;\n  --md-amber-A200: #FFD740;\n  --md-amber-A400: #FFC400;\n  --md-amber-A700: #FFAB00;\n\n  --md-orange-50: #FFF3E0;\n  --md-orange-100: #FFE0B2;\n  --md-orange-200: #FFCC80;\n  --md-orange-300: #FFB74D;\n  --md-orange-400: #FFA726;\n  --md-orange-500: #FF9800;\n  --md-orange-600: #FB8C00;\n  --md-orange-700: #F57C00;\n  --md-orange-800: #EF6C00;\n  --md-orange-900: #E65100;\n  --md-orange-A100: #FFD180;\n  --md-orange-A200: #FFAB40;\n  --md-orange-A400: #FF9100;\n  --md-orange-A700: #FF6D00;\n\n  --md-deep-orange-50: #FBE9E7;\n  --md-deep-orange-100: #FFCCBC;\n  --md-deep-orange-200: #FFAB91;\n  --md-deep-orange-300: #FF8A65;\n  --md-deep-orange-400: #FF7043;\n  --md-deep-orange-500: #FF5722;\n  --md-deep-orange-600: #F4511E;\n  --md-deep-orange-700: #E64A19;\n  --md-deep-orange-800: #D84315;\n  --md-deep-orange-900: #BF360C;\n  --md-deep-orange-A100: #FF9E80;\n  --md-deep-orange-A200: #FF6E40;\n  --md-deep-orange-A400: #FF3D00;\n  --md-deep-orange-A700: #DD2C00;\n\n  --md-brown-50: #EFEBE9;\n  --md-brown-100: #D7CCC8;\n  --md-brown-200: #BCAAA4;\n  --md-brown-300: #A1887F;\n  --md-brown-400: #8D6E63;\n  --md-brown-500: #795548;\n  --md-brown-600: #6D4C41;\n  --md-brown-700: #5D4037;\n  --md-brown-800: #4E342E;\n  --md-brown-900: #3E2723;\n\n  --md-grey-50: #FAFAFA;\n  --md-grey-100: #F5F5F5;\n  --md-grey-200: #EEEEEE;\n  --md-grey-300: #E0E0E0;\n  --md-grey-400: #BDBDBD;\n  --md-grey-500: #9E9E9E;\n  --md-grey-600: #757575;\n  --md-grey-700: #616161;\n  --md-grey-800: #424242;\n  --md-grey-900: #212121;\n\n  --md-blue-grey-50: #ECEFF1;\n  --md-blue-grey-100: #CFD8DC;\n  --md-blue-grey-200: #B0BEC5;\n  --md-blue-grey-300: #90A4AE;\n  --md-blue-grey-400: #78909C;\n  --md-blue-grey-500: #607D8B;\n  --md-blue-grey-600: #546E7A;\n  --md-blue-grey-700: #455A64;\n  --md-blue-grey-800: #37474F;\n  --md-blue-grey-900: #263238;\n}","/* Copyright (c) Jupyter Development Team.\n * Distributed under the terms of the Modified BSD License.\n */\n\n/*\n * We assume that the CSS variables in\n * https://github.com/jupyterlab/jupyterlab/blob/master/src/default-theme/variables.css\n * have been defined.\n */\n\n@import \"./phosphor.css\";\n\n:root {\n    --jp-widgets-color: var(--jp-content-font-color1);\n    --jp-widgets-label-color: var(--jp-widgets-color);\n    --jp-widgets-readout-color: var(--jp-widgets-color);\n    --jp-widgets-font-size: var(--jp-ui-font-size1);\n    --jp-widgets-margin: 2px;\n    --jp-widgets-inline-height: 28px;\n    --jp-widgets-inline-width: 300px;\n    --jp-widgets-inline-width-short: calc(var(--jp-widgets-inline-width) / 2 - var(--jp-widgets-margin));\n    --jp-widgets-inline-width-tiny: calc(var(--jp-widgets-inline-width-short) / 2 - var(--jp-widgets-margin));\n    --jp-widgets-inline-margin: 4px; /* margin between inline elements */\n    --jp-widgets-inline-label-width: 80px;\n    --jp-widgets-border-width: var(--jp-border-width);\n    --jp-widgets-vertical-height: 200px;\n    --jp-widgets-horizontal-tab-height: 24px;\n    --jp-widgets-horizontal-tab-width: 144px;\n    --jp-widgets-horizontal-tab-top-border: 2px;\n    --jp-widgets-progress-thickness: 20px;\n    --jp-widgets-container-padding: 15px;\n    --jp-widgets-input-padding: 4px;\n    --jp-widgets-radio-item-height-adjustment: 8px;\n    --jp-widgets-radio-item-height: calc(var(--jp-widgets-inline-height) - var(--jp-widgets-radio-item-height-adjustment));\n    --jp-widgets-slider-track-thickness: 4px;\n    --jp-widgets-slider-border-width: var(--jp-widgets-border-width);\n    --jp-widgets-slider-handle-size: 16px;\n    --jp-widgets-slider-handle-border-color: var(--jp-border-color1);\n    --jp-widgets-slider-handle-background-color: var(--jp-layout-color1);\n    --jp-widgets-slider-active-handle-color: var(--jp-brand-color1);\n    --jp-widgets-menu-item-height: 24px;\n    --jp-widgets-dropdown-arrow: url(\"data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz4KPCEtLSBHZW5lcmF0b3I6IEFkb2JlIElsbHVzdHJhdG9yIDE5LjIuMSwgU1ZHIEV4cG9ydCBQbHVnLUluIC4gU1ZHIFZlcnNpb246IDYuMDAgQnVpbGQgMCkgIC0tPgo8c3ZnIHZlcnNpb249IjEuMSIgaWQ9IkxheWVyXzEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgeG1sbnM6eGxpbms9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkveGxpbmsiIHg9IjBweCIgeT0iMHB4IgoJIHZpZXdCb3g9IjAgMCAxOCAxOCIgc3R5bGU9ImVuYWJsZS1iYWNrZ3JvdW5kOm5ldyAwIDAgMTggMTg7IiB4bWw6c3BhY2U9InByZXNlcnZlIj4KPHN0eWxlIHR5cGU9InRleHQvY3NzIj4KCS5zdDB7ZmlsbDpub25lO30KPC9zdHlsZT4KPHBhdGggZD0iTTUuMiw1LjlMOSw5LjdsMy44LTMuOGwxLjIsMS4ybC00LjksNWwtNC45LTVMNS4yLDUuOXoiLz4KPHBhdGggY2xhc3M9InN0MCIgZD0iTTAtMC42aDE4djE4SDBWLTAuNnoiLz4KPC9zdmc+Cg\");\n    --jp-widgets-input-color: var(--jp-ui-font-color1);\n    --jp-widgets-input-background-color: var(--jp-layout-color1);\n    --jp-widgets-input-border-color: var(--jp-border-color1);\n    --jp-widgets-input-focus-border-color: var(--jp-brand-color2);\n    --jp-widgets-input-border-width: var(--jp-widgets-border-width);\n    --jp-widgets-disabled-opacity: 0.6;\n\n    /* From Material Design Lite */\n    --md-shadow-key-umbra-opacity: 0.2;\n    --md-shadow-key-penumbra-opacity: 0.14;\n    --md-shadow-ambient-shadow-opacity: 0.12;\n}\n\n.jupyter-widgets {\n    margin: var(--jp-widgets-margin);\n    box-sizing: border-box;\n    color: var(--jp-widgets-color);\n    overflow: visible;\n}\n\n.jupyter-widgets.jupyter-widgets-disconnected::before {\n    line-height: var(--jp-widgets-inline-height);\n    height: var(--jp-widgets-inline-height);\n}\n\n.jp-Output-result > .jupyter-widgets {\n    margin-left: 0;\n    margin-right: 0;\n}\n\n/* vbox and hbox */\n\n.widget-inline-hbox {\n    /* Horizontal widgets */\n    box-sizing: border-box;\n    display: flex;\n    flex-direction: row;\n    align-items: baseline;\n}\n\n.widget-inline-vbox {\n    /* Vertical Widgets */\n    box-sizing: border-box;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n}\n\n.widget-box {\n    box-sizing: border-box;\n    display: flex;\n    margin: 0;\n    overflow: auto;\n}\n\n.widget-gridbox {\n    box-sizing: border-box;\n    display: grid;\n    margin: 0;\n    overflow: auto;\n}\n\n.widget-hbox {\n    flex-direction: row;\n}\n\n.widget-vbox {\n    flex-direction: column;\n}\n\n/* General Button Styling */\n\n.jupyter-button {\n    padding-left: 10px;\n    padding-right: 10px;\n    padding-top: 0px;\n    padding-bottom: 0px;\n    display: inline-block;\n    white-space: nowrap;\n    overflow: hidden;\n    text-overflow: ellipsis;\n    text-align: center;\n    font-size: var(--jp-widgets-font-size);\n    cursor: pointer;\n\n    height: var(--jp-widgets-inline-height);\n    border: 0px solid;\n    line-height: var(--jp-widgets-inline-height);\n    box-shadow: none;\n\n    color: var(--jp-ui-font-color1);\n    background-color: var(--jp-layout-color2);\n    border-color: var(--jp-border-color2);\n    border: none;\n}\n\n.jupyter-button i.fa {\n    margin-right: var(--jp-widgets-inline-margin);\n    pointer-events: none;\n}\n\n.jupyter-button:empty:before {\n    content: \"\\200b\"; /* zero-width space */\n}\n\n.jupyter-widgets.jupyter-button:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n.jupyter-button i.fa.center {\n    margin-right: 0;\n}\n\n.jupyter-button:hover:enabled, .jupyter-button:focus:enabled {\n    /* MD Lite 2dp shadow */\n    box-shadow: 0 2px 2px 0 rgba(0, 0, 0, var(--md-shadow-key-penumbra-opacity)),\n                0 3px 1px -2px rgba(0, 0, 0, var(--md-shadow-key-umbra-opacity)),\n                0 1px 5px 0 rgba(0, 0, 0, var(--md-shadow-ambient-shadow-opacity));\n}\n\n.jupyter-button:active, .jupyter-button.mod-active {\n    /* MD Lite 4dp shadow */\n    box-shadow: 0 4px 5px 0 rgba(0, 0, 0, var(--md-shadow-key-penumbra-opacity)),\n                0 1px 10px 0 rgba(0, 0, 0, var(--md-shadow-ambient-shadow-opacity)),\n                0 2px 4px -1px rgba(0, 0, 0, var(--md-shadow-key-umbra-opacity));\n    color: var(--jp-ui-font-color1);\n    background-color: var(--jp-layout-color3);\n}\n\n.jupyter-button:focus:enabled {\n    outline: 1px solid var(--jp-widgets-input-focus-border-color);\n}\n\n/* Button \"Primary\" Styling */\n\n.jupyter-button.mod-primary {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-brand-color1);\n}\n\n.jupyter-button.mod-primary.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-brand-color0);\n}\n\n.jupyter-button.mod-primary:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-brand-color0);\n}\n\n/* Button \"Success\" Styling */\n\n.jupyter-button.mod-success {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-success-color1);\n}\n\n.jupyter-button.mod-success.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-success-color0);\n }\n\n.jupyter-button.mod-success:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-success-color0);\n }\n\n /* Button \"Info\" Styling */\n\n.jupyter-button.mod-info {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-info-color1);\n}\n\n.jupyter-button.mod-info.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-info-color0);\n}\n\n.jupyter-button.mod-info:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-info-color0);\n}\n\n/* Button \"Warning\" Styling */\n\n.jupyter-button.mod-warning {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-warn-color1);\n}\n\n.jupyter-button.mod-warning.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-warn-color0);\n}\n\n.jupyter-button.mod-warning:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-warn-color0);\n}\n\n/* Button \"Danger\" Styling */\n\n.jupyter-button.mod-danger {\n    color: var(--jp-inverse-ui-font-color1);\n    background-color: var(--jp-error-color1);\n}\n\n.jupyter-button.mod-danger.mod-active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-error-color0);\n}\n\n.jupyter-button.mod-danger:active {\n    color: var(--jp-inverse-ui-font-color0);\n    background-color: var(--jp-error-color0);\n}\n\n/* Widget Button*/\n\n.widget-button, .widget-toggle-button {\n    width: var(--jp-widgets-inline-width-short);\n}\n\n/* Widget Label Styling */\n\n/* Override Bootstrap label css */\n.jupyter-widgets label {\n    margin-bottom: initial;\n}\n\n.widget-label-basic {\n    /* Basic Label */\n    color: var(--jp-widgets-label-color);\n    font-size: var(--jp-widgets-font-size);\n    overflow: hidden;\n    text-overflow: ellipsis;\n    white-space: nowrap;\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-label {\n    /* Label */\n    color: var(--jp-widgets-label-color);\n    font-size: var(--jp-widgets-font-size);\n    overflow: hidden;\n    text-overflow: ellipsis;\n    white-space: nowrap;\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-inline-hbox .widget-label {\n    /* Horizontal Widget Label */\n    color: var(--jp-widgets-label-color);\n    text-align: right;\n    margin-right: calc( var(--jp-widgets-inline-margin) * 2 );\n    width: var(--jp-widgets-inline-label-width);\n    flex-shrink: 0;\n}\n\n.widget-inline-vbox .widget-label {\n    /* Vertical Widget Label */\n    color: var(--jp-widgets-label-color);\n    text-align: center;\n    line-height: var(--jp-widgets-inline-height);\n}\n\n/* Widget Readout Styling */\n\n.widget-readout {\n    color: var(--jp-widgets-readout-color);\n    font-size: var(--jp-widgets-font-size);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    overflow: hidden;\n    white-space: nowrap;\n    text-align: center;\n}\n\n.widget-readout.overflow {\n    /* Overflowing Readout */\n\n    /* From Material Design Lite\n        shadow-key-umbra-opacity: 0.2;\n        shadow-key-penumbra-opacity: 0.14;\n        shadow-ambient-shadow-opacity: 0.12;\n     */\n    -webkit-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.2),\n                        0 3px 1px -2px rgba(0, 0, 0, 0.14),\n                        0 1px 5px 0 rgba(0, 0, 0, 0.12);\n\n    -moz-box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.2),\n                     0 3px 1px -2px rgba(0, 0, 0, 0.14),\n                     0 1px 5px 0 rgba(0, 0, 0, 0.12);\n\n    box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.2),\n                0 3px 1px -2px rgba(0, 0, 0, 0.14),\n                0 1px 5px 0 rgba(0, 0, 0, 0.12);\n}\n\n.widget-inline-hbox .widget-readout {\n    /* Horizontal Readout */\n    text-align: center;\n    max-width: var(--jp-widgets-inline-width-short);\n    min-width: var(--jp-widgets-inline-width-tiny);\n    margin-left: var(--jp-widgets-inline-margin);\n}\n\n.widget-inline-vbox .widget-readout {\n    /* Vertical Readout */\n    margin-top: var(--jp-widgets-inline-margin);\n    /* as wide as the widget */\n    width: inherit;\n}\n\n/* Widget Checkbox Styling */\n\n.widget-checkbox {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-checkbox input[type=\"checkbox\"] {\n    margin: 0px calc( var(--jp-widgets-inline-margin) * 2 ) 0px 0px;\n    line-height: var(--jp-widgets-inline-height);\n    font-size: large;\n    flex-grow: 1;\n    flex-shrink: 0;\n    align-self: center;\n}\n\n/* Widget Valid Styling */\n\n.widget-valid {\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    width: var(--jp-widgets-inline-width-short);\n    font-size: var(--jp-widgets-font-size);\n}\n\n.widget-valid i:before {\n    line-height: var(--jp-widgets-inline-height);\n    margin-right: var(--jp-widgets-inline-margin);\n    margin-left: var(--jp-widgets-inline-margin);\n\n    /* from the fa class in FontAwesome: https://github.com/FortAwesome/Font-Awesome/blob/49100c7c3a7b58d50baa71efef11af41a66b03d3/css/font-awesome.css#L14 */\n    display: inline-block;\n    font: normal normal normal 14px/1 FontAwesome;\n    font-size: inherit;\n    text-rendering: auto;\n    -webkit-font-smoothing: antialiased;\n    -moz-osx-font-smoothing: grayscale;\n}\n\n.widget-valid.mod-valid i:before {\n    content: \"\\f00c\";\n    color: green;\n}\n\n.widget-valid.mod-invalid i:before {\n    content: \"\\f00d\";\n    color: red;\n}\n\n.widget-valid.mod-valid .widget-valid-readout {\n    display: none;\n}\n\n/* Widget Text and TextArea Stying */\n\n.widget-textarea, .widget-text {\n    width: var(--jp-widgets-inline-width);\n}\n\n.widget-text input[type=\"text\"], .widget-text input[type=\"number\"]{\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-text input[type=\"text\"]:disabled, .widget-text input[type=\"number\"]:disabled, .widget-textarea textarea:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n.widget-text input[type=\"text\"], .widget-text input[type=\"number\"], .widget-textarea textarea {\n    box-sizing: border-box;\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    padding: var(--jp-widgets-input-padding) calc( var(--jp-widgets-input-padding) *  2 );\n    flex-grow: 1;\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    flex-shrink: 1;\n    outline: none !important;\n}\n\n.widget-textarea textarea {\n    height: inherit;\n    width: inherit;\n}\n\n.widget-text input:focus, .widget-textarea textarea:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n/* Widget Slider */\n\n.widget-slider .ui-slider {\n    /* Slider Track */\n    border: var(--jp-widgets-slider-border-width) solid var(--jp-layout-color3);\n    background: var(--jp-layout-color3);\n    box-sizing: border-box;\n    position: relative;\n    border-radius: 0px;\n}\n\n.widget-slider .ui-slider .ui-slider-handle {\n    /* Slider Handle */\n    outline: none !important; /* focused slider handles are colored - see below */\n    position: absolute;\n    background-color: var(--jp-widgets-slider-handle-background-color);\n    border: var(--jp-widgets-slider-border-width) solid var(--jp-widgets-slider-handle-border-color);\n    box-sizing: border-box;\n    z-index: 1;\n    background-image: none; /* Override jquery-ui */\n}\n\n/* Override jquery-ui */\n.widget-slider .ui-slider .ui-slider-handle:hover, .widget-slider .ui-slider .ui-slider-handle:focus {\n    background-color: var(--jp-widgets-slider-active-handle-color);\n    border: var(--jp-widgets-slider-border-width) solid var(--jp-widgets-slider-active-handle-color);\n}\n\n.widget-slider .ui-slider .ui-slider-handle:active {\n    background-color: var(--jp-widgets-slider-active-handle-color);\n    border-color: var(--jp-widgets-slider-active-handle-color);\n    z-index: 2;\n    transform: scale(1.2);\n}\n\n.widget-slider  .ui-slider .ui-slider-range {\n    /* Interval between the two specified value of a double slider */\n    position: absolute;\n    background: var(--jp-widgets-slider-active-handle-color);\n    z-index: 0;\n}\n\n/* Shapes of Slider Handles */\n\n.widget-hslider .ui-slider .ui-slider-handle {\n    width: var(--jp-widgets-slider-handle-size);\n    height: var(--jp-widgets-slider-handle-size);\n    margin-top: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-handle-size)) / 2 - var(--jp-widgets-slider-border-width));\n    margin-left: calc(var(--jp-widgets-slider-handle-size) / -2 + var(--jp-widgets-slider-border-width));\n    border-radius: 50%;\n    top: 0;\n}\n\n.widget-vslider .ui-slider .ui-slider-handle {\n    width: var(--jp-widgets-slider-handle-size);\n    height: var(--jp-widgets-slider-handle-size);\n    margin-bottom: calc(var(--jp-widgets-slider-handle-size) / -2 + var(--jp-widgets-slider-border-width));\n    margin-left: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-handle-size)) / 2 - var(--jp-widgets-slider-border-width));\n    border-radius: 50%;\n    left: 0;\n}\n\n.widget-hslider .ui-slider .ui-slider-range {\n    height: calc( var(--jp-widgets-slider-track-thickness) * 2 );\n    margin-top: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-track-thickness) * 2 ) / 2 - var(--jp-widgets-slider-border-width));\n}\n\n.widget-vslider .ui-slider .ui-slider-range {\n    width: calc( var(--jp-widgets-slider-track-thickness) * 2 );\n    margin-left: calc((var(--jp-widgets-slider-track-thickness) - var(--jp-widgets-slider-track-thickness) * 2 ) / 2 - var(--jp-widgets-slider-border-width));\n}\n\n/* Horizontal Slider */\n\n.widget-hslider {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n\n    /* Override the align-items baseline. This way, the description and readout\n    still seem to align their baseline properly, and we don't have to have\n    align-self: stretch in the .slider-container. */\n    align-items: center;\n}\n\n.widgets-slider .slider-container {\n    overflow: visible;\n}\n\n.widget-hslider .slider-container {\n    height: var(--jp-widgets-inline-height);\n    margin-left: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    margin-right: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n}\n\n.widget-hslider .ui-slider {\n    /* Inner, invisible slide div */\n    height: var(--jp-widgets-slider-track-thickness);\n    margin-top: calc((var(--jp-widgets-inline-height) - var(--jp-widgets-slider-track-thickness)) / 2);\n    width: 100%;\n}\n\n/* Vertical Slider */\n\n.widget-vbox .widget-label {\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-vslider {\n    /* Vertical Slider */\n    height: var(--jp-widgets-vertical-height);\n    width: var(--jp-widgets-inline-width-tiny);\n}\n\n.widget-vslider .slider-container {\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n    margin-left: auto;\n    margin-right: auto;\n    margin-bottom: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    margin-top: calc(var(--jp-widgets-slider-handle-size) / 2 - 2 * var(--jp-widgets-slider-border-width));\n    display: flex;\n    flex-direction: column;\n}\n\n.widget-vslider .ui-slider-vertical {\n    /* Inner, invisible slide div */\n    width: var(--jp-widgets-slider-track-thickness);\n    flex-grow: 1;\n    margin-left: auto;\n    margin-right: auto;\n}\n\n/* Widget Progress Styling */\n\n.progress-bar {\n    -webkit-transition: none;\n    -moz-transition: none;\n    -ms-transition: none;\n    -o-transition: none;\n    transition: none;\n}\n\n.progress-bar {\n    height: var(--jp-widgets-inline-height);\n}\n\n.progress-bar {\n    background-color: var(--jp-brand-color1);\n}\n\n.progress-bar-success {\n    background-color: var(--jp-success-color1);\n}\n\n.progress-bar-info {\n    background-color: var(--jp-info-color1);\n}\n\n.progress-bar-warning {\n    background-color: var(--jp-warn-color1);\n}\n\n.progress-bar-danger {\n    background-color: var(--jp-error-color1);\n}\n\n.progress {\n    background-color: var(--jp-layout-color2);\n    border: none;\n    box-shadow: none;\n}\n\n/* Horisontal Progress */\n\n.widget-hprogress {\n    /* Progress Bar */\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    width: var(--jp-widgets-inline-width);\n    align-items: center;\n\n}\n\n.widget-hprogress .progress {\n    flex-grow: 1;\n    margin-top: var(--jp-widgets-input-padding);\n    margin-bottom: var(--jp-widgets-input-padding);\n    align-self: stretch;\n    /* Override bootstrap style */\n    height: initial;\n}\n\n/* Vertical Progress */\n\n.widget-vprogress {\n    height: var(--jp-widgets-vertical-height);\n    width: var(--jp-widgets-inline-width-tiny);\n}\n\n.widget-vprogress .progress {\n    flex-grow: 1;\n    width: var(--jp-widgets-progress-thickness);\n    margin-left: auto;\n    margin-right: auto;\n    margin-bottom: 0;\n}\n\n/* Select Widget Styling */\n\n.widget-dropdown {\n    height: var(--jp-widgets-inline-height);\n    width: var(--jp-widgets-inline-width);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-dropdown > select {\n    padding-right: 20px;\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    border-radius: 0;\n    height: inherit;\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    box-sizing: border-box;\n    outline: none !important;\n    box-shadow: none;\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    vertical-align: top;\n    padding-left: calc( var(--jp-widgets-input-padding) * 2);\n\tappearance: none;\n\t-webkit-appearance: none;\n\t-moz-appearance: none;\n    background-repeat: no-repeat;\n\tbackground-size: 20px;\n\tbackground-position: right center;\n    background-image: var(--jp-widgets-dropdown-arrow);\n}\n.widget-dropdown > select:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.widget-dropdown > select:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* To disable the dotted border in Firefox around select controls.\n   See http://stackoverflow.com/a/18853002 */\n.widget-dropdown > select:-moz-focusring {\n    color: transparent;\n    text-shadow: 0 0 0 #000;\n}\n\n/* Select and SelectMultiple */\n\n.widget-select {\n    width: var(--jp-widgets-inline-width);\n    line-height: var(--jp-widgets-inline-height);\n\n    /* Because Firefox defines the baseline of a select as the bottom of the\n    control, we align the entire control to the top and add padding to the\n    select to get an approximate first line baseline alignment. */\n    align-items: flex-start;\n}\n\n.widget-select > select {\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    flex: 1 1 var(--jp-widgets-inline-width-short);\n    outline: none !important;\n    overflow: auto;\n    height: inherit;\n\n    /* Because Firefox defines the baseline of a select as the bottom of the\n    control, we align the entire control to the top and add padding to the\n    select to get an approximate first line baseline alignment. */\n    padding-top: 5px;\n}\n\n.widget-select > select:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.wiget-select > select > option {\n    padding-left: var(--jp-widgets-input-padding);\n    line-height: var(--jp-widgets-inline-height);\n    /* line-height doesn't work on some browsers for select options */\n    padding-top: calc(var(--jp-widgets-inline-height)-var(--jp-widgets-font-size)/2);\n    padding-bottom: calc(var(--jp-widgets-inline-height)-var(--jp-widgets-font-size)/2);\n}\n\n\n\n/* Toggle Buttons Styling */\n\n.widget-toggle-buttons {\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-toggle-buttons .widget-toggle-button {\n    margin-left: var(--jp-widgets-margin);\n    margin-right: var(--jp-widgets-margin);\n}\n\n.widget-toggle-buttons .jupyter-button:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Radio Buttons Styling */\n\n.widget-radio {\n    width: var(--jp-widgets-inline-width);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-radio-box {\n    display: flex;\n    flex-direction: column;\n    align-items: stretch;\n    box-sizing: border-box;\n    flex-grow: 1;\n    margin-bottom: var(--jp-widgets-radio-item-height-adjustment);\n}\n\n.widget-radio-box label {\n    height: var(--jp-widgets-radio-item-height);\n    line-height: var(--jp-widgets-radio-item-height);\n    font-size: var(--jp-widgets-font-size);\n}\n\n.widget-radio-box input {\n    height: var(--jp-widgets-radio-item-height);\n    line-height: var(--jp-widgets-radio-item-height);\n    margin: 0 calc( var(--jp-widgets-input-padding) * 2 ) 0 1px;\n    float: left;\n}\n\n/* Color Picker Styling */\n\n.widget-colorpicker {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-colorpicker > .widget-colorpicker-input {\n    flex-grow: 1;\n    flex-shrink: 1;\n    min-width: var(--jp-widgets-inline-width-tiny);\n}\n\n.widget-colorpicker input[type=\"color\"] {\n    width: var(--jp-widgets-inline-height);\n    height: var(--jp-widgets-inline-height);\n    padding: 0 2px; /* make the color square actually square on Chrome on OS X */\n    background: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    border-left: none;\n    flex-grow: 0;\n    flex-shrink: 0;\n    box-sizing: border-box;\n    align-self: stretch;\n    outline: none !important;\n}\n\n.widget-colorpicker.concise input[type=\"color\"] {\n    border-left: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n}\n\n.widget-colorpicker input[type=\"color\"]:focus, .widget-colorpicker input[type=\"text\"]:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.widget-colorpicker input[type=\"text\"] {\n    flex-grow: 1;\n    outline: none !important;\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n    background: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    font-size: var(--jp-widgets-font-size);\n    padding: var(--jp-widgets-input-padding) calc( var(--jp-widgets-input-padding) *  2 );\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    flex-shrink: 1;\n    box-sizing: border-box;\n}\n\n.widget-colorpicker input[type=\"text\"]:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Date Picker Styling */\n\n.widget-datepicker {\n    width: var(--jp-widgets-inline-width);\n    height: var(--jp-widgets-inline-height);\n    line-height: var(--jp-widgets-inline-height);\n}\n\n.widget-datepicker input[type=\"date\"] {\n    flex-grow: 1;\n    flex-shrink: 1;\n    min-width: 0; /* This makes it possible for the flexbox to shrink this input */\n    outline: none !important;\n    height: var(--jp-widgets-inline-height);\n    border: var(--jp-widgets-input-border-width) solid var(--jp-widgets-input-border-color);\n    background-color: var(--jp-widgets-input-background-color);\n    color: var(--jp-widgets-input-color);\n    font-size: var(--jp-widgets-font-size);\n    padding: var(--jp-widgets-input-padding) calc( var(--jp-widgets-input-padding) *  2 );\n    box-sizing: border-box;\n}\n\n.widget-datepicker input[type=\"date\"]:focus {\n    border-color: var(--jp-widgets-input-focus-border-color);\n}\n\n.widget-datepicker input[type=\"date\"]:invalid {\n    border-color: var(--jp-warn-color1);\n}\n\n.widget-datepicker input[type=\"date\"]:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Play Widget */\n\n.widget-play {\n    width: var(--jp-widgets-inline-width-short);\n    display: flex;\n    align-items: stretch;\n}\n\n.widget-play .jupyter-button {\n    flex-grow: 1;\n    height: auto;\n}\n\n.widget-play .jupyter-button:disabled {\n    opacity: var(--jp-widgets-disabled-opacity);\n}\n\n/* Tab Widget */\n\n.jupyter-widgets.widget-tab {\n    display: flex;\n    flex-direction: column;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar {\n    /* Necessary so that a tab can be shifted down to overlay the border of the box below. */\n    overflow-x: visible;\n    overflow-y: visible;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {\n    /* Make sure that the tab grows from bottom up */\n    align-items: flex-end;\n    min-width: 0;\n    min-height: 0;\n}\n\n.jupyter-widgets.widget-tab > .widget-tab-contents {\n    width: 100%;\n    box-sizing: border-box;\n    margin: 0;\n    background: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n    border: var(--jp-border-width) solid var(--jp-border-color1);\n    padding: var(--jp-widgets-container-padding);\n    flex-grow: 1;\n    overflow: auto;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar {\n    font: var(--jp-widgets-font-size) Helvetica, Arial, sans-serif;\n    min-height: calc(var(--jp-widgets-horizontal-tab-height) + var(--jp-border-width));\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {\n    flex: 0 1 var(--jp-widgets-horizontal-tab-width);\n    min-width: 35px;\n    min-height: calc(var(--jp-widgets-horizontal-tab-height) + var(--jp-border-width));\n    line-height: var(--jp-widgets-horizontal-tab-height);\n    margin-left: calc(-1 * var(--jp-border-width));\n    padding: 0px 10px;\n    background: var(--jp-layout-color2);\n    color: var(--jp-ui-font-color2);\n    border: var(--jp-border-width) solid var(--jp-border-color1);\n    border-bottom: none;\n    position: relative;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current {\n    color: var(--jp-ui-font-color0);\n    /* We want the background to match the tab content background */\n    background: var(--jp-layout-color1);\n    min-height: calc(var(--jp-widgets-horizontal-tab-height) + 2 * var(--jp-border-width));\n    transform: translateY(var(--jp-border-width));\n    overflow: visible;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-current:before {\n    position: absolute;\n    top: calc(-1 * var(--jp-border-width));\n    left: calc(-1 * var(--jp-border-width));\n    content: '';\n    height: var(--jp-widgets-horizontal-tab-top-border);\n    width: calc(100% + 2 * var(--jp-border-width));\n    background: var(--jp-brand-color1);\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:first-child {\n    margin-left: 0;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab:hover:not(.p-mod-current) {\n    background: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon {\n    margin-left: 4px;\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-mod-closable > .p-TabBar-tabCloseIcon:before {\n    font-family: FontAwesome;\n    content: '\\f00d'; /* close */\n}\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel,\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {\n    line-height: var(--jp-widgets-horizontal-tab-height);\n}\n\n/* Accordion Widget */\n\n.p-Collapse {\n    display: flex;\n    flex-direction: column;\n    align-items: stretch;\n}\n\n.p-Collapse-header {\n    padding: var(--jp-widgets-input-padding);\n    cursor: pointer;\n    color: var(--jp-ui-font-color2);\n    background-color: var(--jp-layout-color2);\n    border: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    padding: calc(var(--jp-widgets-container-padding) * 2 / 3) var(--jp-widgets-container-padding);\n    font-weight: bold;\n}\n\n.p-Collapse-header:hover {\n    background-color: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n}\n\n.p-Collapse-open > .p-Collapse-header {\n    background-color: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color0);\n    cursor: default;\n    border-bottom: none;\n}\n\n.p-Collapse .p-Collapse-header::before {\n    content: '\\f0da\\00A0';  /* caret-right, non-breaking space */\n    display: inline-block;\n    font: normal normal normal 14px/1 FontAwesome;\n    font-size: inherit;\n    text-rendering: auto;\n    -webkit-font-smoothing: antialiased;\n    -moz-osx-font-smoothing: grayscale;\n}\n\n.p-Collapse-open > .p-Collapse-header::before {\n    content: '\\f0d7\\00A0'; /* caret-down, non-breaking space */\n}\n\n.p-Collapse-contents {\n    padding: var(--jp-widgets-container-padding);\n    background-color: var(--jp-layout-color1);\n    color: var(--jp-ui-font-color1);\n    border-left: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    border-right: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    border-bottom: var(--jp-widgets-border-width) solid var(--jp-border-color1);\n    overflow: auto;\n}\n\n.p-Accordion {\n    display: flex;\n    flex-direction: column;\n    align-items: stretch;\n}\n\n.p-Accordion .p-Collapse {\n    margin-bottom: 0;\n}\n\n.p-Accordion .p-Collapse + .p-Collapse {\n    margin-top: 4px;\n}\n\n\n\n/* HTML widget */\n\n.widget-html, .widget-htmlmath {\n    font-size: var(--jp-widgets-font-size);\n}\n\n.widget-html > .widget-html-content, .widget-htmlmath > .widget-html-content {\n    /* Fill out the area in the HTML widget */\n    align-self: stretch;\n    flex-grow: 1;\n    flex-shrink: 1;\n    /* Makes sure the baseline is still aligned with other elements */\n    line-height: var(--jp-widgets-inline-height);\n    /* Make it possible to have absolutely-positioned elements in the html */\n    position: relative;\n}\n","/* This file has code derived from PhosphorJS CSS files, as noted below. The license for this PhosphorJS code is:\n\nCopyright (c) 2014-2017, PhosphorJS Contributors\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n*/\n\n/*\n * The following section is derived from https://github.com/phosphorjs/phosphor/blob/23b9d075ebc5b73ab148b6ebfc20af97f85714c4/packages/widgets/style/tabbar.css \n * We've scoped the rules so that they are consistent with exactly our code.\n */\n\n.jupyter-widgets.widget-tab > .p-TabBar {\n  display: flex;\n  -webkit-user-select: none;\n  -moz-user-select: none;\n  -ms-user-select: none;\n  user-select: none;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] {\n  flex-direction: row;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] {\n  flex-direction: column;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar > .p-TabBar-content {\n  margin: 0;\n  padding: 0;\n  display: flex;\n  flex: 1 1 auto;\n  list-style-type: none;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='horizontal'] > .p-TabBar-content {\n  flex-direction: row;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar[data-orientation='vertical'] > .p-TabBar-content {\n  flex-direction: column;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {\n  display: flex;\n  flex-direction: row;\n  box-sizing: border-box;\n  overflow: hidden;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabIcon,\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabCloseIcon {\n  flex: 0 0 auto;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tabLabel {\n  flex: 1 1 auto;\n  overflow: hidden;\n  white-space: nowrap;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab.p-mod-hidden {\n  display: none !important;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab {\n  position: relative;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='horizontal'] .p-TabBar-tab {\n  left: 0;\n  transition: left 150ms ease;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging[data-orientation='vertical'] .p-TabBar-tab {\n  top: 0;\n  transition: top 150ms ease;\n}\n\n\n.jupyter-widgets.widget-tab > .p-TabBar.p-mod-dragging .p-TabBar-tab.p-mod-dragging {\n  transition: none;\n}\n\n/* End tabbar.css */\n"]} */",
+ "ok": true,
+ "headers": [
+ [
+ "content-type",
+ "text/css"
+ ]
+ ],
+ "status": 200,
+ "status_text": ""
+ }
+ },
+ "base_uri": "https://localhost:8080/",
+ "height": 464
+ }
+ },
+ "source": [
+ "y_names = ['shape', 'scale', 'orientation', 'posX', 'posY']\n",
+ "y_shapes = np.array((3,6,40,32,32))\n",
+ "img_dict = {}\n",
+ "\n",
+ "for i, img in enumerate(imgs_sampled):\n",
+ " img_dict[tuple(labels_sampled[i])] = img\n",
+ "\n",
+ "def find_in_dataset(shape, scale, orient, posX, posY):\n",
+ " fig = plt.figure()\n",
+ " img = img_dict[(0, shape, scale, orient, posX, posY)]\n",
+ " plt.imshow(img.reshape(64,64), cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ "\n",
+ "interact(find_in_dataset, \n",
+ " shape=widgets.IntSlider(min=0, max=2, step=1, value=0),\n",
+ " scale=widgets.IntSlider(min=0, max=5, step=1, value=0),\n",
+ " orient=widgets.IntSlider(min=0, max=39, step=1, value=0),\n",
+ " posX=widgets.IntSlider(min=0, max=31, step=1, value=0),\n",
+ " posY=widgets.IntSlider(min=0, max=31, step=1, value=0))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e787b532e7d84bf7a7d614eba2f930ff",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "interactive(children=(IntSlider(value=0, description='shape', max=2), IntSlider(value=0, description='scale', …"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 15
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1EueLb0OzSDm",
+ "colab_type": "code",
+ "outputId": "e6365256-b1f9-4fc6-c1b3-959c4dc381cf",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 311
+ }
+ },
+ "source": [
+ "def get_specific_data(args=dict(), cuda=False):\n",
+ " '''\n",
+ " use this function to get examples of data with specific class labels\n",
+ " inputs: \n",
+ " args - dictionary whose keys can include {shape, scale, orientation,\n",
+ " posX, posY} and values can include any integers less than the \n",
+ " corresponding size of that label dimension\n",
+ " cuda - bool to indicate whether the output should be placed on GPU\n",
+ " '''\n",
+ " names_dict = {'shape': 1, 'scale': 2, 'orientation': 3, 'posX': 4, 'posY': 5}\n",
+ " selected_ind = np.ones(imgs.shape[0], dtype=bool)\n",
+ " for k,v in args.items():\n",
+ " col_id = names_dict[k]\n",
+ " selected_ind = np.bitwise_and(selected_ind, labels[:, col_id] == v)\n",
+ " ind = np.random.choice(np.arange(imgs.shape[0])[selected_ind])\n",
+ " x = torch.from_numpy(imgs[ind].reshape(1,64**2).astype(np.float32))\n",
+ " y = torch.from_numpy(labels[ind].reshape(1,6).astype(np.float32))\n",
+ " if not cuda:\n",
+ " return x,y\n",
+ " x = x.cuda()\n",
+ " y = y.cuda()\n",
+ " return x,y\n",
+ "\n",
+ "def plot_image(x):\n",
+ " \"\"\"\n",
+ " helper to plot dSprites images\n",
+ " \"\"\"\n",
+ " x = x.cpu()\n",
+ " plt.figure()\n",
+ " plt.imshow(x.reshape(64,64), interpolation='nearest', cmap='Greys_r')\n",
+ " plt.axis('off')\n",
+ "\n",
+ "def see_specific_image(args=dict(), verbose=True):\n",
+ " '''\n",
+ " use this function to get examples of data with specific class labels\n",
+ " inputs: \n",
+ " args - dictionary whose keys can include {shape, scale, orientation,\n",
+ " posX, posY} and values can include any integers less than the \n",
+ " corresponding size of that label dimension\n",
+ " verbose - bool to indicate whether the full class label should be written \n",
+ " as the title of the plot\n",
+ " '''\n",
+ " x,y = get_specific_data(args, cuda=False)\n",
+ " plot_image(x)\n",
+ " if verbose:\n",
+ " string = ''\n",
+ " for i, s in enumerate(['Shape', 'Scale', 'Orientation', 'PosX', 'PosY']):\n",
+ " string += '%s: %d, ' % (s, int(y[0][i+1]))\n",
+ " if i == 2:\n",
+ " string = string[:-2] + '\\n'\n",
+ " plt.title(string[:-2], fontsize=12)\n",
+ " \n",
+ "def compare_reconstruction(original, recon):\n",
+ " \"\"\"\n",
+ " compare two images side by side\n",
+ " inputs:\n",
+ " original - array for original image\n",
+ " recon - array for recon image\n",
+ " \"\"\"\n",
+ " fig = plt.figure()\n",
+ " ax0 = fig.add_subplot(121)\n",
+ " plt.imshow(original.cpu().reshape(64,64), cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " plt.title('original')\n",
+ " ax1 = fig.add_subplot(122)\n",
+ " plt.imshow(recon.cpu().reshape(64,64), cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " plt.title('reconstruction')\n",
+ " \n",
+ "def compare_to_density(original, recons):\n",
+ " \"\"\"\n",
+ " compare two images side by side\n",
+ " inputs:\n",
+ " original - array for original image\n",
+ " recon - array of multiple recon images\n",
+ " \"\"\"\n",
+ " fig = plt.figure()\n",
+ " ax0 = fig.add_subplot(121)\n",
+ " plt.imshow(original.cpu().reshape(64,64), cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " plt.title('original')\n",
+ " ax1 = fig.add_subplot(122)\n",
+ " plt.imshow(torch.mean(recons.cpu(), 0).reshape(64,64), cmap='Greys_r', interpolation='nearest')\n",
+ " plt.axis('off')\n",
+ " plt.title('reconstructions')\n",
+ "\n",
+ " \n",
+ "see_specific_image()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ0AAAEmCAYAAAB4ecX9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEYFJREFUeJzt3Xu0HWV5x/HvQ0KAYMiFUCoQwkIU\nBBbFG8hSaorYNipKW61FBKJ4iS4WlVLRskChFFl1VRAviOIlUKJ4Q6FRFFEjFQQvy9qaEjVows1Q\nQhLwBLxg3v7xvgeGvc4+2U9Ccsjh+1lrr5w9756Zd2bP/Gbm3fNOopSCJA1qm7GugKSti6EhKcXQ\nkJRiaEhKMTQkpRgaklKeMKEREfMi4jtjXY+tQUQsjojXj3U9NlZEXBwRZ451PQAi4vCI+OlY1+Ox\nNK5CIyKeHxE3RsR9EbE6Im6IiOeMdb02JCJeHRErImJdRHwpImYkxj09In4ZEUMRcUdEfGZz1jUr\nIg6OiB9GxAPt34MT40ZEvC0ifh4RD0bEbRFxXkRsN9p4pZT5pZRzHoO6z4mIO5LjlIjYp1OX/yyl\n7LupdRlgvk+LiKsi4p627X8tIvbt+cwpEbEyIu6PiE9saD32M25CIyJ2AhYBHwBmALsDZwO/Hct6\nbUhEHAB8BDgO2BV4ALhowHFPaOMdWUp5EvBs4BubqappETEJuAq4HJgOXApc1YYP4v3AG4HjgSnA\nXOCFwGdHmeeETanzVmwacDWwL3U7+h513QMQEX8BvIO6/mYDe1P3j7xSyrh4UXeYtaOUzwO+A/wb\nsAb4JTC3U/5a4Bbg18AvgDd1yuYAdwCnA6uA5cCxnfLt2nRvA+4GLgZ2GLDe7wY+1Xn/FOB3wJQB\nxv0g8L5RymcAnwTuasv8pTZ8OjVg72nDFwF7dMZbDLy+8/51bd2sAb4GzB5w2f4cuBOIzrDbgL8c\nYNynAn8ADukZPot6IDiivV8AfBj4CrAOOLIN+5fOOC8F/gtYC9wIHNQpWw78I/DfwH3AZ4DtgR2B\nB4H1wFB77QYcAny3TetX7TuY1KZ1PVBaPYaAVw1vO535Pb2t37XAEuBlnbIFwIeAL7ft8GbgKRu5\nP8xoddm5vf8U8O5O+QuBlRs17c29M2+pF7ATcC/1aDYXmN5TPg/4PfAGYALw5rYzRSt/CXWHDeAF\n1CP+M1vZHOAh4HxqQLygbRj7tvILqCk/g3pE/A/gvM681wLP71Pvq4C39wwbAp41wDK/BlgNvI0a\nmhN6yr/cdoLpwLbAC9rwnYG/ASa3+n6OFiitfDEtNICXA8vaxj4ROAO4sfPZRcA7+tTvFOCanmGL\ngFMHWLb5wIo+Zd8eXr9tR7sPeB71zHl7OqEBPAP4P+DQ9r2fQA2K7Vr5cupRebf2/d0CzO9873f0\nzPtZwHPbutirff6tnfIC7NN5//A02newjHrwmQQcQQ2HfTvLci81mCYCC4ErBlnXI6yjo4Ffdd7/\nGHhV5/1MOqGS2tfGemd/LF9tw15APSt4iLoj79rK5gHLOp+d3FbaH/eZ1peAv+988Q8BO3bKPwuc\nSQ2ZdXSOCMBhwC8HrPM3hjfSzrA7gTkDjn8scF2rw720AAKeTD1KTh9gGgcDazrvF/NIaFwDnNgp\n24YaqLMHmO6Z3Y2+DVsInDXAuGcAN/UpuwK4pP29ALisp3wBj4TGh4Fzesp/yiMBuhx4TafsPcDF\nne/9jg3U863AFzvvRwuNw4GVwDad8k8Pr49W7491yl4MLN2I/WCPtg0d0xl2K50zPGqAFWCv7PTH\nTZsGQCnlllLKvFLKHsCB1KPH+zofWdn57APtzycBRMTciLipNSKtpX5hMzvjrimlrOu8X9Gmvws1\ngH4YEWvbuF9twwcxRD1L6tqJegTaoFLKwlLKkdRr2vnAOe36dRawupSypneciJgcER9pja/3U0+r\np/VpD5gNXNhZttXUoNx9My/bKmrwjeTJrXzY7aNMZzZw6nD92zLMon53w1Z2/n6Atk2MpDU4Lhpu\nUKReXs7s9/keuwG3l1LWd4at4NHrcuC69KnfLsC1wEWllE93inq/i+G/B9rOusZVaHSVUpZSk/vA\nDX22tSJ/gdousWspZRr1Gjk6H5seETt23u9JvbxZRb32PaCUMq29ppbaMDmIJcCfdOqyN/US6GcD\njg9AKeX3pZTPUa/ND6TuSDMiYtoIHz+V2mB2aCllJ+BPh2c/wmdvp7bvTOu8diil3DhAtZYAB0VE\nd7oHteEb8k1gVkQc0h0YEbOolwfdBt/RumrfDpzbU//JPTtUPyNN98PAUuCpbd2dzsjrbSR3UZep\nu9/tST0r2GQRMZ0aGFeXUs7tKX7Udtb+vruUcm92PuMmNCJiv4g4NSL2aO9nAccANw0w+iTqjnoP\n8FBEzKU24vU6OyImRcTh1Ma1z7WjxiXABRHxR23eu7ej/SAWAke13/N3BP4ZuLKU8us2rQURsaDP\nMs+LiJdExJSI2KbV+wDg5lLKr6iXFhdFxPSI2DYihsNhCjXo1rafd981Sv0uBv6p/cpDREyNiFcO\nuGyLqY2ZJ0fEdhFxUhv+zU79l480YinlZ23eCyPiuRExodXhC8B1pZTrBqzDJcD8iDi0/YS74/A6\nG2Dcu4GdI2JqZ9gU4H5gKCL2o7aN9Y6zd5/p3Uw9ezitfR9zgKOol1ubpP16+DXghlLKO0b4yGXA\niRGxfzuQnEE9qKaNm9CgnmYdCtwcEeuoYfET6lF1VG0HPZnaTrEGeDW1PaRrZSu7i7qjz29nMwBv\npzZw3dROWa+jHskBaPdQHN5n3kuolxULqQ12U4C3dD4yC7ihT9Xvpx7pbqM2tr4HeHMpZfgmtuOo\njb9L27Tf2oa/D9iBepZ0E/VyakSllC8C/wpc0ZbtJ9SG5uFluyYiTu8z7u+oDXLHt/q9Dji6Dd/Q\nsgGcBHyM+pPtUKvnYmoj7kBKKT+gNn5/kPr9LaO2bw0y7lJqm8Mv2qXNbtRfWl5N3d4uoTY0d50F\nXNo+/7c90/sdNSTmUtf9RcDxne1oVKOta+CvgOcAr23b2/Brzzbvr1K3j29Rt5cVjH6w6F+P1iii\nUbQjwuWtrWRLzncStdX7oFLK77fkvLeEiLiW2th8y1jXRYObONYVUH/tyPT0sa7H5lJKGekSUI9z\n4+nyRNIW4OWJpBTPNCSlGBqSUgyNx7GIWB61S/hQRNzd7tlI3SHYM71ntG7R+3SGPav9PLjXgNM4\nJyL+JyIeioizkvNfHBG/acuzKiKujIh+d31uaFoREddHxLt6hh8fEbdGxOQBpnFBRCyLiF9HxC0R\ncWynbNeoj1m4t62fGyPisI2p63hjaDz+HdXuLn0mtVPaGRs7oVLKj6j3K1zSdrptgU8A7yylLB9w\nMsuA06id4TbGSW15nka99f2CjZlIqY1xrwdO6dx4tgvwXmq/mQdGG78ZonZUnEq9h+RDnTtQ76f2\nfN6F2uHvvcDV8cTtev8wQ2MrUUq5k3qH54EAEbFbRFwdta/Msoh4w/BnI+KQiPhBO6u4OyLO70zq\nbGrfjTdSbwwbogbJoPW4tJRyDRvRZ6FnOqupd3cOL8/UiLgs6kNkVkTEGdFut46IfSLi21EfrrQq\n2oOG2l2j5wIfb599P/CFUsq3BqzDmaWUn5ZS1pdSvkvtNn9YK3twuIx6m/h6ah+Tqf2n+MTgfRpb\niai3xb8YuLINuoJ6d+ZuwH7A1yPi1lLKN4ELgQtLKf/eLmce7n9TSvltRJxIPVPYhvq8ivWd+VzU\nPte9K3VzLM9M6p2dP2qDPkDdIfemdt2/lvq8io8D57T3f0a95f/ZnUmdD7wC+HwbfkBnHscBp5RS\nnjlAfSa38c/vGb6EelY0kdr7dXVyUcefbLdYX1vuRe22PUS9BXsF9bbjHai3X/+BzoN6gPOABe3v\n66lnFDP7THcq9TbmGzahbpczQBf3nnEWU/terKV20lpIPf2fQH3w0P6dz74JWNz+vgz4KJ0HBfVM\n9wBq57KXb+SyRFueRX3Kt6c+guC4sd4mHg8vL08e/44utWfm7FLKW0opD1LPLlaX1qmt6XaxPpF6\ndFwaEd+PiJf2TPO91AfZ7BERf7e5F6DHyW15di+lHFtKuYd62r8tdRmGdZfnNOqO/b2IWBIRr+tO\nsNT+OzBY79mRnE9dX8eMVFhK+U0pZSFw5nD7yROZlydbp7uo3d6ndILj4S7WpZSfA8e06/y/Bj4f\nETuXUtZFxJHAy4D9qU+I+mREXFvG9rR7FbVj3Wzgf9uw7vKspHY6IyKeD1wXEdeXUpZt6owj4lzq\no+/m9ITwSCZRL582NpzGBc80tkKllNupjXbnRcT2EXEQ9ezicoCIeE1E7FJqW8XaNtr6qF3vP0q9\nzl9VSvkK8HUSv2BE7dK9PXXbmdjmP6GV7RX1adx7JZfnD9QexudG7eY/G/iHzvK8MtojD6g9VQu1\nYXKTRP1vDl4BvKg3NCPisIh4XlveHVrv0hnA9zd1vlu9sb4+8tX/RW3TOLJP2R7UZ0aupj7KbX6n\n7HJqV/gh6lHx6Db8QuArPdOZ2T77ovb+Ytrj7vrMdwF1p+2+5rWyw1udt+0z7mI6DyzuKZve6n0P\n9cE576Q9Fo/apfvOtjy3Am8cYfxHPWavDTsB+HGf+U1o4/yWRx4cPASc1sqPoD7QaKit42/R5zmv\nT7SXfU/0mImIM4B7SikfGeu6aPMxNCSl2KYhKcXQkJRiaEhKMTQkpWyVN3dFhK230mZWShnx/3Px\nTENSiqEhKcXQkJRiaEhKMTQkpRgaklIMDUkphoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1J\nKYaGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlGBqSUgwNSSmGhqQUQ0NSiqEhKcXQkJRiaEhKMTQkpRga\nklIMDUkphoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRDQ1KKoSEpxdCQlGJoSEox\nNCSlGBqSUgwNSSmGhqQUQ0NSiqEhKcXQkJRiaEhKMTQkpRgaklIMDUkphoakFENDUoqhISnF0JCU\nYmhISjE0JKUYGpJSDA1JKYaGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlGBqSUgwNSSmGhqQUQ0NSiqEh\nKcXQkJRiaEhKMTQkpRgaklIMDUkphoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRD\nQ1KKoSEpxdCQlGJoSEoxNCSlGBqSUgwNSSmGhqQUQ0NSiqEhKcXQkJRiaEhKMTQkpRgaklIMDUkp\nhoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlTBzr\nCoxXpZSxrgIAETHWVdA445mGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlGBqSUgwNSSmGhqQUQ0NSiqEh\nKcXQkJRiaEhKMTQkpRgaklIMDUkphoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRD\nQ1KK/y3jZtL97xAfL/9Fo/RY8ExDUoqhISnF0JCUYpvGVqLbRiKNJc80JKUYGpJSvDzZAry00Hji\nmYakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlGBqS\nUgwNSSmGhqQUQ0NSiqEhKcXQkJRiaEhKMTQkpRgaklIMDUkphoakFENDUoqhISnF0JCUYmhISjE0\nJKUYGpJSDA1JKYaGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlGBqSUgwNSSmGhqQUQ0NSiqEhKcXQkJRi\naEhKMTQkpRgaklIMDUkphoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRDQ1KKoSEp\nxdCQlGJoSEoxNCSlGBqSUgwNSSlRShnrOkjainimISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRD\nQ1KKoSEpxdCQlGJoSEoxNCSlGBqSUgwNSSmGhqQUQ0NSiqEhKcXQkJRiaEhKMTQkpRgaklIMDUkp\nhoakFENDUoqhISnF0JCUYmhISjE0JKUYGpJSDA1JKYaGpBRDQ1KKoSEpxdCQlGJoSEoxNCSlGBqS\nUgwNSSn/D1nKrt3W//HwAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "OkB6V4LjESxV",
+ "colab_type": "code",
+ "outputId": "a923d9d6-233e-4066-e9ba-f60200c1a05a",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ }
+ },
+ "source": [
+ "label_dims = vae.label_shape\n",
+ "label_dim_offsets = np.cumsum(label_dims)\n",
+ "label_dim_offsets"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "array([ 1, 4, 10, 50, 82, 114])"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "VXhaIdcBEHnS",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class SCM():\n",
+ " \"\"\"\n",
+ " Structural causal model\n",
+ " \n",
+ " args: \n",
+ " vae: instance of vae\n",
+ " mu: loc of q(z|x) given by the vae encoder\n",
+ " sigma: scale of q(z|x) given by the vae encoder\n",
+ " \n",
+ " \"\"\"\n",
+ " def __init__(self, vae, mu, sigma):\n",
+ " \"\"\"\n",
+ " Constructor\n",
+ " \n",
+ " Intializes :\n",
+ " image dimensions - 4096(64*64), \n",
+ " z dimensions: size of the tensor representing the latent random variable z, \n",
+ " label dimensions: 114 labels y that correspond to an image(one hot encoded)\n",
+ " f(x) = p(x|y,z)\n",
+ " Noise variables in the model N_#\n",
+ " \"\"\"\n",
+ " self.vae = vae\n",
+ " self.image_dim = vae.image_dim\n",
+ " self.z_dim = vae.z_dim\n",
+ " # these are used for f_X\n",
+ " self.label_dims = vae.label_shape\n",
+ " \n",
+ " def f_X(Y, Z, N):\n",
+ " \"\"\"\n",
+ " Generating one hots for the factors\n",
+ " \"\"\" \n",
+ " zs = Z.cuda()\n",
+ " # convert the labels to one hot\n",
+ " ys = [torch.tensor([0])]\n",
+ " ys.append(torch.nn.functional.one_hot(torch.round(Y[0]).to(torch.long), int(self.label_dims[1])))\n",
+ " ys.append(torch.nn.functional.one_hot(torch.round(Y[1]).to(torch.long), int(self.label_dims[2])))\n",
+ " ys.append(torch.nn.functional.one_hot(torch.round(Y[2]).to(torch.long), int(self.label_dims[3])))\n",
+ " ys.append(torch.nn.functional.one_hot(torch.round(Y[3]).to(torch.long), int(self.label_dims[4])))\n",
+ " ys.append(torch.nn.functional.one_hot(torch.round(Y[4]).to(torch.long), int(self.label_dims[5])))\n",
+ " ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()\n",
+ " p = vae.decoder.forward(zs, ys)\n",
+ " return (N < p.cpu()).type(torch.float)\n",
+ " \n",
+ " def f_Y(N):\n",
+ " \"\"\"\n",
+ " Gumbel distribution - to model the distribution of the maximum of a number of samples\n",
+ " m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0])).sample() # sample from Gumbel distribution with loc=1, scale=2\n",
+ " tensor([ 1.0124])\n",
+ " \n",
+ " https://pytorch.org/docs/stable/_modules/torch/distributions/gumbel.html\n",
+ " \"\"\"\n",
+ "# m = torch.distributions.gumbel.Gumbel(torch.zeros(N.size(0)), torch.ones(N.size(0)))\n",
+ " beta = 12\n",
+ " indices = torch.tensor(np.arange(N.size(0))).to(torch.float32)\n",
+ " smax = nn.functional.softmax(beta*N)\n",
+ " argmax_ind = torch.sum(smax*indices)\n",
+ " return argmax_ind\n",
+ " \n",
+ " def f_Z(N):\n",
+ " \"\"\"\n",
+ " Z ~ Normal(mu, sigma) \n",
+ " \"\"\"\n",
+ " return N * sigma + mu\n",
+ " \n",
+ " def model(noise): \n",
+ " \"\"\"\n",
+ " The model corresponds to a generative process\n",
+ " \n",
+ " args: noise variables\n",
+ " return: X(image), Y(labels), Z(latents) \n",
+ " \"\"\"\n",
+ " N_X = pyro.sample( 'N_X', noise['N_X'].to_event(1) )\n",
+ " # denoted using the index in the sequence \n",
+ " # that they are stored in as vae.label_names:\n",
+ " # ['shape', 'scale', 'orientation', 'posX', 'posY']\n",
+ " N_Y_1 = pyro.sample( 'N_Y_1', noise['N_Y_1'].to_event(1) )\n",
+ " N_Y_2 = pyro.sample( 'N_Y_2', noise['N_Y_2'].to_event(1) )\n",
+ " N_Y_3 = pyro.sample( 'N_Y_3', noise['N_Y_3'].to_event(1) )\n",
+ " N_Y_4 = pyro.sample( 'N_Y_4', noise['N_Y_4'].to_event(1) )\n",
+ " N_Y_5 = pyro.sample( 'N_Y_5', noise['N_Y_5'].to_event(1) )\n",
+ " \n",
+ " # Z ~ Normal(Nx_mu, Nx_sigma) \n",
+ " N_Z = pyro.sample( 'N_Z', noise['N_Z'].to_event(1) )\n",
+ " Z = pyro.sample('Z', dist.Normal( f_Z( N_Z ), 1e-1).to_event(1) )\n",
+ " \n",
+ " # Y ~ Gumbel max of Ny \n",
+ "# Y_1_mu = f_Y(N_Y_1)\n",
+ "# Y_2_mu = f_Y(N_Y_2)\n",
+ "# Y_3_mu = f_Y(N_Y_3)\n",
+ "# Y_4_mu = f_Y(N_Y_4)\n",
+ "# Y_5_mu = f_Y(N_Y_5)\n",
+ " \n",
+ " Y_1 = pyro.sample('Y_1', dist.Normal( f_Y(N_Y_1), 1e-2) )\n",
+ " Y_2 = pyro.sample('Y_2', dist.Normal( f_Y(N_Y_2), 1e-1) )\n",
+ " Y_3 = pyro.sample('Y_3', dist.Normal( f_Y(N_Y_3), 1e-1) )\n",
+ " Y_4 = pyro.sample('Y_4', dist.Normal( f_Y(N_Y_4), 1e-1) )\n",
+ " Y_5 = pyro.sample('Y_5', dist.Normal( f_Y(N_Y_5), 1e-1) )\n",
+ " \n",
+ "# Y_mu = (Y_1_mu, Y_2_mu, Y_3_mu, Y_4_mu, Y_5_mu)\n",
+ " \n",
+ " # X ~ p(x|y,z) = bernoulli(loc(y,z)) \n",
+ " X = pyro.sample('X', dist.Normal( f_X( (Y_1, Y_2, Y_3,Y_4,Y_5), Z, N_X ), 1e-2).to_event(1))\n",
+ " \n",
+ " # return noise and variables\n",
+ " noise_samples = N_X, (N_Y_1, N_Y_2, N_Y_3, N_Y_4, N_Y_5), N_Z\n",
+ " variable_samples = X, (Y_1, Y_2, Y_3, Y_4, Y_5), Z\n",
+ " return variable_samples, noise_samples\n",
+ " \n",
+ " self.model = model\n",
+ " #Initialize all noise variables in the model \n",
+ " self.init_noise = {\n",
+ " 'N_X' : dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)),\n",
+ " 'N_Z' : dist.Normal(torch.zeros(vae.z_dim), torch.ones(vae.z_dim)),\n",
+ " 'N_Y_1' : dist.Uniform(torch.zeros(self.label_dims[1]),torch.ones(self.label_dims[1])),\n",
+ " 'N_Y_2' : dist.Uniform(torch.zeros(self.label_dims[2]),torch.ones(self.label_dims[2])),\n",
+ " 'N_Y_3' : dist.Uniform(torch.zeros(self.label_dims[3]),torch.ones(self.label_dims[3])),\n",
+ " 'N_Y_4' : dist.Uniform(torch.zeros(self.label_dims[4]),torch.ones(self.label_dims[4])),\n",
+ " 'N_Y_5' : dist.Uniform(torch.zeros(self.label_dims[5]),torch.ones(self.label_dims[5])) \n",
+ " }\n",
+ " \n",
+ " def update_noise_svi(self, obs_data, intervened_model=None):\n",
+ " \"\"\"\n",
+ " Use svi to find out the mu, sigma of the distributionsfor the \n",
+ " condition outlined in obs_data\n",
+ " \"\"\"\n",
+ " \n",
+ " def guide(noise):\n",
+ " \"\"\"\n",
+ " The guide serves as an approximation to the posterior p(z|x). \n",
+ " The guide provides a valid joint probability density over all the \n",
+ " latent random variables in the model.\n",
+ " \n",
+ " https://pyro.ai/examples/svi_part_i.html\n",
+ " \"\"\"\n",
+ " # create params with constraints\n",
+ " mu = {\n",
+ " 'N_X': pyro.param('N_X_mu', 0.5*torch.ones(self.image_dim),constraint = constraints.interval(0., 1.)),\n",
+ " 'N_Z': pyro.param('N_Z_mu', torch.zeros(self.z_dim),constraint = constraints.interval(-3., 3.)),\n",
+ " 'N_Y_1': pyro.param('N_Y_1_mu', 0.5*torch.ones(self.label_dims[1]),constraint = constraints.interval(0., 1.)),\n",
+ " 'N_Y_2': pyro.param('N_Y_2_mu', 0.5*torch.ones(self.label_dims[2]),constraint = constraints.interval(0., 1.)),\n",
+ " 'N_Y_3': pyro.param('N_Y_3_mu', 0.5*torch.ones(self.label_dims[3]),constraint = constraints.interval(0., 1.)),\n",
+ " 'N_Y_4': pyro.param('N_Y_4_mu', 0.5*torch.ones(self.label_dims[4]),constraint = constraints.interval(0., 1.)),\n",
+ " 'N_Y_5': pyro.param('N_Y_5_mu', 0.5*torch.ones(self.label_dims[5]),constraint = constraints.interval(0., 1.))\n",
+ " }\n",
+ " sigma = {\n",
+ " 'N_X': pyro.param('N_X_sigma', 0.1*torch.ones(self.image_dim),constraint = constraints.interval(0.0001, 0.5)),\n",
+ " 'N_Z': pyro.param('N_Z_sigma', torch.ones(self.z_dim),constraint = constraints.interval(0.0001, 3.)),\n",
+ " 'N_Y_1': pyro.param('N_Y_1_sigma', 0.1*torch.ones(self.label_dims[1]),constraint = constraints.interval(0.0001, 0.5)),\n",
+ " 'N_Y_2': pyro.param('N_Y_2_sigma', 0.1*torch.ones(self.label_dims[2]),constraint = constraints.interval(0.0001, 0.5)),\n",
+ " 'N_Y_3': pyro.param('N_Y_3_sigma', 0.1*torch.ones(self.label_dims[3]),constraint = constraints.interval(0.0001, 0.5)),\n",
+ " 'N_Y_4': pyro.param('N_Y_4_sigma', 0.1*torch.ones(self.label_dims[4]),constraint = constraints.interval(0.0001, 0.5)),\n",
+ " 'N_Y_5': pyro.param('N_Y_5_sigma', 0.1*torch.ones(self.label_dims[5]),constraint = constraints.interval(0.0001, 0.5))\n",
+ " }\n",
+ " for noise_term in noise.keys():\n",
+ " pyro.sample(noise_term, dist.Normal(mu[noise_term], sigma[noise_term]).to_event(1))\n",
+ " \n",
+ " # Condition the model\n",
+ " if intervened_model is not None:\n",
+ " obs_model = pyro.condition(intervened_model, obs_data)\n",
+ " else:\n",
+ " obs_model = pyro.condition(self.model, obs_data)\n",
+ " \n",
+ " pyro.clear_param_store()\n",
+ "\n",
+ " # Once we’ve specified a guide, we’re ready to proceed to inference. \n",
+ " # Now, this an optimization problem where each iteration of training takes \n",
+ " # a step that moves the guide closer to the exact posterior \n",
+ " \n",
+ " # https://arxiv.org/pdf/1601.00670.pdf\n",
+ " svi = SVI(\n",
+ " model= obs_model,\n",
+ " guide= guide,\n",
+ " optim= SGD({\"lr\": 1e-5, 'momentum': 0.1}),\n",
+ " loss=Trace_ELBO(retain_graph=True)\n",
+ " )\n",
+ " \n",
+ " num_steps = 1500\n",
+ " samples = defaultdict(list)\n",
+ " for t in range(num_steps):\n",
+ " loss = svi.step(self.init_noise)\n",
+ "# if t % 100 == 0:\n",
+ "# print(\"step %d: loss of %.2f\" % (t, loss))\n",
+ " for noise in self.init_noise.keys():\n",
+ " mu = '{}_mu'.format(noise)\n",
+ " sigma = '{}_sigma'.format(noise)\n",
+ " samples[mu].append(pyro.param(mu).detach().numpy())\n",
+ " samples[sigma].append(pyro.param(sigma).detach().numpy())\n",
+ " means = {k: torch.tensor(np.array(v).mean(axis=0)) for k, v in samples.items()}\n",
+ " \n",
+ " # update the inferred noise\n",
+ " updated_noise = {\n",
+ " 'N_X' : dist.Normal(means['N_X_mu'], means['N_X_sigma']),\n",
+ " 'N_Z' : dist.Normal(means['N_Z_mu'], means['N_Z_sigma']),\n",
+ " 'N_Y_1': dist.Normal(means['N_Y_1_mu'], means['N_Y_1_sigma']),\n",
+ " 'N_Y_2': dist.Normal(means['N_Y_2_mu'], means['N_Y_2_sigma']),\n",
+ " 'N_Y_3': dist.Normal(means['N_Y_3_mu'], means['N_Y_3_sigma']),\n",
+ " 'N_Y_4': dist.Normal(means['N_Y_4_mu'], means['N_Y_4_sigma']),\n",
+ " 'N_Y_5': dist.Normal(means['N_Y_5_mu'], means['N_Y_5_sigma']),\n",
+ " }\n",
+ " return updated_noise\n",
+ " \n",
+ " def __call__(self):\n",
+ " return self.model(self.init_noise)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lk7vAe8byP5G",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Sanity check: 1\n",
+ "### Making sure VAE works"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "TDI4enwfEPCY",
+ "colab_type": "code",
+ "outputId": "921d5efc-4d33-4d7f-e167-a60e70e848f8",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 303
+ }
+ },
+ "source": [
+ "# Generate an instance of dSprites image \n",
+ "ox, y = get_specific_data(cuda=True)\n",
+ "plot_image(ox)\n",
+ "# Pass it through VAE to get q(z|x) => N(mu, sigma)\n",
+ "mu, sigma = vae.encoder.forward(ox,vae.remap_y(y))\n",
+ "# Feed these params to our custom SCM\n",
+ "scm = SCM(vae, mu.cpu(), sigma.cpu())\n",
+ "print(y)\n",
+ "# Check for reconstruction\n"
+ ],
+ "execution_count": 392,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "tensor([[ 0., 2., 1., 19., 13., 10.]], device='cuda:0')\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ0AAAENCAYAAAAVEjAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAA8tJREFUeJzt28tqwzAQQFGr5P9/Wd2VUELjm0eF\n43NWpSFFq8uMao855waw19fqAwDHIhpAIhpAIhpAIhpAIhpAIhpAIhpAIhpAIhpAIhpAIhpAcll9\ngEeMMbxlB2825xy3fm/SABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLR\nABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLR\nABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLR\nABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRABLRAJLL6gNw\nbHPOn5/HGAtPwn8xaQCJaACJ9YTkeh2595l15TOZNIBENIDEesJdf60ke79nVfkcJg0gEQ0gEQ0g\ncafBTY/eY+z5e+43js2kASSiASTWE7Zte/06wucyaQCJaACJaACJO40TW3WP4W3YYzNpAIloAIn1\nhOU8LXosJg0gEQ0gsZ6ciKc+eQWTBpCIBpCIBpC40ziR3//OdMfBI0waQCIaQGI94SnXK4915xxM\nGkAiGkAiGkDiTuPE9t5HvPvNU2+2HotJA0hEA0isJ2zb9poVwZpxDiYNIBENIBENIBENIBENIBEN\nIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBEN\nIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBEN\nIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBEN\nIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBEN\nIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBENIBEN\nIBENIBENIBENIBENIBENIBENIBENIBlzztVnAA7EpAEkogEkogEkogEkogEkogEkogEkogEkogEk\nogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEkogEk\nogEkogEkogEkogEkogEkogEkogEkogEkogEkogEk30cpOSeIUv9IAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B1rkIEyL-357",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Sanity check 2\n",
+ "\n",
+ "### To check if the decoder is able to generate the image if the latents are changed:\n",
+ "#### To achieve this we manually change the labels in the code and run it through the decoder and check for reconstruction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1CcO36iqEf_T",
+ "colab_type": "code",
+ "outputId": "8dc36b9f-0ee7-44e2-ba86-8299afa4c6f4",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 492
+ }
+ },
+ "source": [
+ "original, y_original = get_specific_data(cuda=True)\n",
+ "print('top: ',y_original)\n",
+ "mu, sigma = vae.encoder.forward(original,vae.remap_y(y_original))\n",
+ "B = 100\n",
+ "zs = torch.cat([dist.Normal(mu.cpu(), sigma.cpu()).sample() for a in range(B)], 0)\n",
+ "ys = torch.cat([vae.remap_y(y_original) for a in range(B)], 0)\n",
+ "rs = vae.decoder.forward(zs.cuda(), ys).detach()\n",
+ "compare_to_density(original,rs)\n",
+ "\n",
+ "y_new = torch.tensor(y_original)\n",
+ "y_new[0,1] = (y_original[0,1] + 1) % 2\n",
+ "print('bottom: ', y_new)\n",
+ "zs = torch.cat([dist.Normal(mu.cpu(), sigma.cpu()).sample() for a in range(B)], 0)\n",
+ "ys = torch.cat([vae.remap_y(y_new) for a in range(B)], 0)\n",
+ "rs = vae.decoder.forward(zs.cuda(), ys).detach()\n",
+ "compare_to_density(original,rs)"
+ ],
+ "execution_count": 393,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "top: tensor([[ 0., 0., 2., 16., 31., 5.]], device='cuda:0')\n",
+ "bottom: tensor([[ 0., 1., 2., 16., 31., 5.]], device='cuda:0')\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAADcCAYAAACBHI1wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAD+9JREFUeJzt3X2wXVV9xvHnyRt5QaT1gpEKRhph\naB2qkqmMVMmIZUClQ6faaWmnBNo/zPQNR2dqddrGOi12Jq3tWJFioNeZhmqhxDpYUwRJMVQZSbW2\najAgKaEJEBGB6A03L6t/rHWSnZ197u/c83LPuTffz8yZc+9ea++9zj7rnGfvtfe+1yklAQAwlXnD\nbgAAYPQRFgCAEGEBAAgRFgCAEGEBAAgRFgCAEGEx4mxvsZ1sr+7T8naW5a3ox/JmaxuAmWJ7vPT3\nNcNuSy8ICwAjxfaa8uU6Puy2RGyvKG3dOey2DNqCYTcAod+QtFTSY31a3iWSFkr6vz4tD8DU/lDS\nhyXtGXZDekFYjLiUUr9CorW8R/q5PABTSynt0SwPColhqL6z/UrbN5Vx+RdsP23732y/vaHukfMR\ntt9i+y7b3y/TXlOv0zD/y2xvsL3H9n7b37b9B7bntzsv0Ml022+1/SXbz9t+zvZm269r83p/qYzJ\nfsv2s7YnbG+3vd72WNcbcg4r2zk5W2t7m+19tn9Qq3ey7ffb/s/yXvzI9tdtv9f2oimW/zbbn7X9\nhO1J27tt32v79xrqLrJ9ne0HK+v4hu0/sn1yQ/0jQ0S2X2z7b2zvKn39Edt/Yvu4nVDbS2z/ru2v\n2t5b+utu2/fZfn+l3hZJf19+vbqyrY4Zlqr111+2vbX0v2T71Op2brONphw+sn267T8v22Jf2Tbb\nbd9o+9WlzjpJj5ZZXlFr687Kstqesxi17T+llBKPPj0kvUHSs5KSpO9I+kdJ90o6WKZdX6u/pUy/\nUdJhSV+TdKukL0k6v1ZndW3elysPTSXlIaVPS/q8pAlJt0vaWcpW1OaLpl8v6VBpwz9JerhM3yfp\nnIbXfFDSc5K+Uur/q6QnyzyPSjqtYZ7GNpwoj/Lak6QbJB2Q9MXSV+6v1DlT0vZSb4+kz0m6U9L3\nyrR7JS2qLdeSPlHKD0n6culPd0t6In/cj6m/RNJ9pf5zkv6l9J3WOr4haaw2z5pS9hlJ3yrLvU3S\nFyTtL2U31eaZV9qbJD1TXsetZdqTkvZX6r5P0tZS92FJ45XHbzX0oRvK83+UZT4o6cXV7dzmPVhR\nync2lF1Q6cNPltd6m6RtZbuuK/WuLNur9fmotnV9ZXnjpc6aUd/+U/bbYX9w5spD0mJJu8ob8meS\nXCl7g6TnS9nllelbdPSLY02b5bbqrK5N/2yZfrukxZXp50jaXVnuitp8O4PpE5IurkxfKGlTKbul\noX3vlLSkYVtsKPPc2DBPYxtOlEflvfm+pNc2lFs5fJOk9ZJOqpSdKmlzKfvT2nzvKdMfqy9X0nxJ\nV9SmrS/1vy7p9Mr0U5QDLEn6dG2e1pdVknRHre+9Xnnn4XD1vZV0can/oKRlDe16c5t1jE+xDVt9\naFLSpVNt5zZlK9QQFpJepLzzlST9pY4P5DMlXRAtpzbPuJrDYiS3f9vXMewPzlx5KJ+ITsp7g/Ma\nyteV8rsr07aUaZunWG6rzupaBz2svCexvGGetZUOtaJWtjOY/uGG5a0qZY9OY3ssUd5r3ttQ1tiG\nE+VReW/e16b8raV8iyo7HZXyl0l6QXkP1GXaQh3dI31Th+/PvlL/oobyleWL55CksyrTW19Wz6n5\nqPHOUn51Zdo7y7S/7nD7tNYxPkWdVh86bmekvp3blK1Qc1i8u0y/p8O2Ni6nVmdctbAY5e3f7sE5\ni/55U3n+h5TS4YbyW8rzRbbn18o+M811vVF57/O+lNITDeW3TnN5VZ9vmPZQeT6jaQbb55Vx14/a\nvqWMLX9cea9vzPaP9dCeuazd+355eb49lU97VconTHdIeomkV5XJq8rvD6eU7utg3RdIWibpkZTS\n/Q3reFh5iGSecn+r25ZS2tswvamvfE35S+9a2++yfXoH7evUdD87kcvK8y1T1urdrNv+hEX//ER5\nfrRN+ePKX56LlT/UVf/b5boa50spPat87qQbuxqW93z58ZiTqrYX2L5Zeez0I5J+R9I1kq4uj6Wl\n6ildtmWua/e+n12eP1o7aXrkIemnS53TyvNZ5fkhdSbqr5L03VrdquP6SdHqKye1JpQvvt9XPvr5\nuKQnbe8oOxZvs+0O29xkup+dyHS3Y7dm3fbn0tnRMNHlfMftdVY0Hd10YjrzXSfpWuUx3ncrn1B9\nKqU0KUm2dysPmfTyZTBnpZTave+tI88vqv2XQsvTrcX1pVGdm1b/Sil9zPY/S3q78r0+b1TesbhG\n0j22L0spHeyiHd1+dtrtKM/0duzWjG9/wqJ/Wje5nd2m/OXKe+b7lU9s9mJ3eT6rqdD2KZJmYujn\nHeX5XSmlO2ttWCZp+Qy0YS5qBcStKaWbO5yndT/OOR3Wj/prtawvN3CWIdMN5SHbr1e+CuwSSb8p\n6e/6sZ6KA5IW2j45pbSvVnZmm3kek3Se8nZ8sM/tqZp1259hqP5pjRP/mu2m7XpNeb6/yz2oqq3l\n+WLbL20o/9Uel9+pHy/PTXu/vyKOKLq1uTy/Y8pax9qmfJTxKts/12H9H0o62/ZF9ULbP6m893lY\n+TLqvkspPaDyxSXp/ErRZHnudWe2tVN1bkPZpW3muas8X9vhOrpt6yhv/0aERf/cprwHcK6kD1bH\nAUuCv6f8+le9riil9F3l+xkWK49rHxmftL1S0h/3uo4ObS/Pa2uv9zXK92ugO5uUT0peZvsj5Ujx\nGOWmsl9v/Z5SOqD8JyUkaaPt82v159u+olJ/Qkf3JP/W9mmVui8qZQuUT7L39FcEbL/Z9uX1m8Wc\nbyz8+fJrdR2tPenzelmv8n0EkvSB6rptX6o8bNpkg/J9LZfY/gvXbn60fabtCyqT9ioHxkuncyHH\niG//to3m0aeHpIt09Ka87cpXJd2j+Ka81VMss7GO8hDU46XsceWb8j6nPIZ7h/KJvyTpjNp8OzWN\nS2or5cddhqh8/8hk5fV+Snmc/WB57V2ta64/mrZlQ52zJH2z1P2BpH+XtFH5xq3vlOlfqc1jHb1M\n85Ck+8v78AXFN+U9q6M3n+0t0/5b7W8KG2/T7nWlfF1l2nU6ekPY3ZXX8VSZ/pCkUyv1T1L+wm7d\nG/BJ5S/xa6bTh5R33FqXp+4or+2rynvr16v9TXk/W9kGT5TP03E35VXq39FaVnltG1S5BF2d3ZQ3\nMtu/7fYc9gdnrj2Uxxk/ofxlPal8fuIuSb/QUHeLugyLUnaGpJtLh95f3vQPlA/bC6VjL67N0/gh\niz58avMFJ+l1ypfbPqV8WP1fpXPO63Zdc/3Rbls21FuifBXL1vJBn1Te6/6ypA+p3OXfMN8vKg9l\nfa8yzz2Sfruh7iLlvextyl+sE5L+R/no9OSG+t18Wa2U9EHlPf1dpa8+pfzF/V5JpzQs52eU7xl4\nuvTjY9bZaR8q/XOz8n0JP1S+0/sKBfdHKF+YsV55J2iizP9tSR+T9FO1ui9RDohdyudJjlmu2oTF\nKG//pkfrhh7MIWUMdKukb6aUXj3s9gCY/ThnMUuVexxe2zD9XEk3lV8/ObOtAjBXcWQxS5W/SPm8\n8uH4duXD5Fco3xm6QHks9C0pn/gEgJ4QFrNUubLhQ8rXSL9S+Q/M/Uh5XPVTkm5I5eY4AOgVYQEA\nCHHOAgAQmpV/7sNt/vsV0C8ppaHcfU7fxqB127c5sgAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECI\nsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAA\nhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLzJiUklJKw24G0DdjY2MaGxvT\n5OSkJicnh92cgSIsAAAhz8Y9Pduzr9EnsKn6mO0ZbEnnUkpDaRh9e3aZqm9feOGFkqQHHnhgpprT\nkW77NkcWAIAQRxYYiF761SgcbXBkgXYOHz585Ofp9tWlS5dKkiYmJvrapungyAIAMDCEBQAgxDAU\nBqKf/WoYw1IMQ6GdfvTt1jLmzZv5/XWGoQAAA7Ng2A3A3DKII9XqMkfh5DdOTIM4Wq4uc9GiRZKk\nAwcO9G09/cSRBQAgRFgAAEIMQ6FnM3mRRH1dDEthkJ555pkZW1f9b0uNWt/myAIAEOLSWfRsVPpQ\nP/fEuHQW0uj07dYltn26bJdLZwEAg8E5C3RtVPa6WlrtGbWxXsw+o9a3d+zYIUlauXLl0NrAkQUA\nIERYAABCnODGtI1qn+EEN3q1Z88eSdLy5cuH3JJjjULf5sgCABDiyALTNqp9ZhT2vnpF3555ixcv\nPvLzMP8pUd3BgweP/Lxw4cK+LZcjCwDAwBAWAIAQw1DoyKj2k0HdU8Ew1Imjl/+pPUij1rc5sgAA\nhLiDG7PSKO0BYnYa1T60atWqYTehEUcWAIAQ5yzQtWH2nUHvFXLO4sS0ceNGSdJVV101tDaMat/m\nyAIAECIsAAAhhqHQs5nsQzN1UpJhKEj07SqOLAAAIS6dRc+a9ohm4xErUFft261/bXro0KG+Lb96\nQ+Co48gCABAiLAAAIYahMBDVw/deh6RG9U5bnFhaQ0bV/njllVdKkjZt2tTVMufPn997w2YIRxYA\ngBCXzmLGddrnhnlEwaWz6EanfXvJkiWSpP379w+yOY24dBYAMDCcs8CM41JbzFXVvr1v3z5J0rJl\ny45MW7t2raThHFH0iiMLAECIsAAAhDjBDTTgBDfmKk5wAwAGhrAAAIQICwBAiLAAAIQICwBAiLAA\nAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQI\nCwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBA\niLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAA\nAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQI\nCwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBA\niLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIScUhp2GwAA\nI44jCwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQI\nCwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBA\niLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAAAIQICwBAiLAA\nAIT+H7Ibd/Pt/SDNAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAADcCAYAAACBHI1wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAD0tJREFUeJzt3X2sXFW9xvHnaQu0BbFeAZUrWBEh\noEEtRhNRaUQJqBgNenNvrpGC/iHxDaOJb1GrRtGkvtx4BcTCPSZXfAHxBdReBKlYfIkU320RkEqx\nQCsqUD2llK77x1rT7u7uOb85c+ac2Wf6/SSToXutvWfNnjXz7LX23genlAQAwETmDLsBAID2IywA\nACHCAgAQIiwAACHCAgAQIiwAACHCouVsr7adbC8d0PY2lO0tHsT2ZmsbgJlie6z092XDbstUEBYA\nWsX2svLjOjbstkRsLy5t3TDstky3ecNuAEKvk7RQ0p0D2t4pkvaT9OcBbQ/AxN4j6eOS7h52Q6aC\nsGi5lNKgQqKzvdsHuT0AE0sp3a1ZHhQS01ADZ/vJti8u8/IP2b7P9v/ZfnlD3V3nI2y/2PY1tv9a\nlj2zXqdh/SfYXmn7btvbbK+z/S7bc7udF+hlue2X2v6R7QdtP2B7le0lXd7vmWVO9ve277c9bnu9\n7RW2D+l7R46wsp+Ts3Ntr7W91fbfa/UOsv1e2zeXz+Kftn9p+522959g+y+z/W3b99jebnuT7ett\nv7Wh7v62z7N9U+U1fm37/bYPaqi/a4rI9qNt/5ftjaWv3277g7b3Ogi1vcD2W2z/3PaW0l832b7B\n9nsr9VZL+p/yz7Mq+2qPaalaf/0322tK/0u2F1X3c5d9NOH0ke3DbH+s7IutZd+st32R7aeXOssl\n3VFWeVKtrRsq2+p6zqJt+39CKSUeA3pIep6k+yUlSX+Q9GVJ10vaUZadX6u/uiy/SNJOSb+QdJmk\nH0k6oVZnaW3dJypPTSXlKaWvSvqepHFJV0jaUMoW19aLlp8v6ZHShq9Juq0s3yrpmIb3vEPSA5J+\nWup/V9K9ZZ07JB3asE5jG/aVR3nvSdIFkh6W9IPSV26s1DlC0vpS725J35F0taS/lGXXS9q/tl1L\n+kIpf0TST0p/ulbSPfnrvkf9BZJuKPUfkPSt0nc6r/FrSYfU1llWyr4p6fdlu5dL+r6kbaXs4to6\nc0p7k6S/lfdxWVl2r6RtlbrvlrSm1L1N0ljl8YaGPnRBef5x2eZNkh5d3c9dPoPFpXxDQ9mJlT58\nb3mvl0taW/br8lLvlWV/db4f1bauqGxvrNRZ1vb9P2G/HfYXZ1QekuZL2lg+kI9KcqXseZIeLGWn\nV5av1u4fjmVdttups7S2/Ntl+RWS5leWHyNpU2W7i2vrbQiWj0s6ubJ8P0nfKGWXNrTvNZIWNOyL\nlWWdixrWaWzDvvKofDZ/lfSshnIrh2+StELSAZWyRZJWlbIP19Z7R1l+Z327kuZKOqO2bEWp/0tJ\nh1WWH6wcYEnSV2vrdH6skqQra33vucoHDzurn62kk0v9myQd2NCuF3V5jbEJ9mGnD22XdOpE+7lL\n2WI1hIWkRykffCVJn9TegXyEpBOj7dTWGVNzWLRy/3d9H8P+4ozKQ/lEdFI+GpzTUL68lF9bWba6\nLFs1wXY7dZbWOuhO5SOJxzesc26lQy2ulW0Iln+8YXvPLmV3TGJ/LFA+at7SUNbYhn3lUfls3t2l\n/KWlfLUqBx2V8idIekj5CNRl2X7afUT6wh4/n62l/kkN5UeXH55HJB1ZWd75sXpAzaPGq0v5WZVl\nrynLPtPj/um8xtgEdTp9aK+Dkfp+7lK2WM1h8fay/Loe29q4nVqdMdXCos37v9uDcxaD88Ly/L8p\npZ0N5ZeW55Nsz62VfXOSr/UC5aPPG1JK9zSUXzbJ7VV9r2HZLeX58KYVbB9X5l0/a/vSMrd8ofJR\n3yG2HzOF9oyybp/76eX5ilS+7VUpnzC9VdJjJT21LH52+fdtKaUbenjtEyUdKOn2lNKNDa9xm/IU\nyRzl/la3NqW0pWF5U1/5hfKP3jm232j7sB7a16vJfncip5XnSyesNXWzbv8TFoPzr+X5ji7ldyn/\neM5X/lJX/anP12pcL6V0v/K5k35sbNjeg+U/9zipanue7UuU504/LenNks6WdFZ5LCxVD+6zLaOu\n2+d+VHn+bO2k6a6HpKeVOoeW5yPL8y3qTdRfJemPtbpVe/WTotNXDugsKD98b1Me/Vwo6V7bt5YD\ni5fZdo9tbjLZ705ksvuxX7Nu/3PpbDuM97neXkedFU2jm15MZr3zJJ2jPMf7duUTqptTStslyfYm\n5SmTqfwYjKyUUrfPvTPy/IG6/yh03NfZ3EAa1btJ9a+U0udsf13Sy5Xv9XmB8oHF2ZKus31aSmlH\nH+3o97vT7UB5pvdjv2Z8/xMWg9O5ye2oLuVPVD4y36Z8YnMqNpXnI5sKbR8saSamfl5dnt+YUrq6\n1oYDJT1+BtowijoBcVlK6ZIe1+ncj3NMj/Wj/lotG8gNnGXKdGV5yPZzla8CO0XS6yV9fhCvU/Gw\npP1sH5RS2lorO6LLOndKOk55P9404PZUzbr9zzTU4HTmif/TdtN+Pbs839jnEVTVmvJ8su3HNZT/\nxxS336t/Kc9NR7//LkYU/VpVnl89Ya09rVUeZTzV9vN7rP8PSUfZPqleaPspykefO5Uvox64lNLP\nVH64JJ1QKdpenqd6MNs5qDq2oezULutcU57P6fE1+m1rm/d/I8JicC5XPgI4VtKHqvOAJcHfUf75\nqam+UErpj8r3M8xXntfeNT9p+2hJH5jqa/RofXk+t/Z+n6l8vwb68w3lk5Kn2f50GSnuodxU9trO\nv1NKDyv/SQlJ+pLtE2r159o+o1J/XLuPJP/b9qGVuo8qZfOUT7JP6a8I2H6R7dPrN4s531j4kvLP\n6mt0jqSPm8rrKt9HIEnvq7627VOVp02brFS+r+UU259w7eZH20fYPrGyaItyYDxuMhdytHz/d200\njwE9JJ2k3TflrVe+Kuk6xTflLZ1gm411lKeg7ipldynflPcd5TncK5VP/CVJh9fW26BJXFJbKd/r\nMkTl+0e2V97vV5Tn2XeU997Xa436o2lfNtQ5UtLvSt2/S/qhpC8p37j1h7L8p7V1rN2XaT4i6cby\nOXxf8U1592v3zWdbyrLfqPtNYWNd2r28lC+vLDtPu28Iu7byPjaX5bdIWlSpf4DyD3bn3oAvKv+I\nnz2ZPqR84Na5PPXW8t5+rny0fr6635T3nMo+uKd8n/a6Ka9S/8rOtsp7W6nKJejq7aa81uz/rvtz\n2F+cUXsozzN+QfnHervy+YlrJL2ioe5q9RkWpexwSZeUDr2tfOjvK1+2h0rHnl9bp/FLFn351OUH\nTtIS5cttNysPq39VOuecfl9r1B/d9mVDvQXKV7GsKV/07cpH3T+R9BGVu/wb1nuV8lTWXyrrXCfp\nTQ1191c+yl6r/MM6Lum3yqPTgxrq9/NjdbSkDykf6W8sfXWz8g/3OyUd3LCdZyjfM3Bf6cd7vGav\nfaj0z1XK9yX8Q/lO7zMU3B+hfGHGCuWDoPGy/jpJn5N0fK3uY5UDYqPyeZI9tqsuYdHm/d/06NzQ\ngxFS5kDXSPpdSunpw24PgNmPcxazVLnH4VkNy4+VdHH55xdntlUARhUji1mq/EXKB5WH4+uVh8lP\nUr4zdJ7yXOiLUz7xCQBTQljMUuXKho8oXyP9ZOU/MPdP5XnVr0i6IJWb4wBgqggLAECIcxYAgNCs\n/HMf7vJ/vwIGJaU0lLvP6duYbv32bUYWAIAQYQEACBEWAIAQYQEACBEWAIAQYQEACBEWAIAQYQEA\nCBEWAIAQYQEACBEWAIAQYQEACBEWAIAQYQEACBEWAIAQYQEACBEWAIAQYQEACBEWAIAQYQEACBEW\nAIAQYQEACBEWAIAQYQEACBEWAIAQYQEACBEWAIAQYYEZk1JSSmnYzQDQB8ICABCaN+wGYPTVRxPV\nf9ue6eYA6AMjCwBAiJEFpkWv5yaa6jHaANqHkQUAIERYAABCTEOhdepTU0xLYaYN4hLvUeu3jCwA\nACFGFhio6bjpjkttMVN27NgxsG2N2sUbjCwAACHCAgAQYhoKUzaTf++Jk9+YTnPnzp3W7Xf67/j4\nuCRp4cKF0/p6g8TIAgAQ8mz8K6C2Z1+jR1hb+tAgRxkppaEMWejbwzXMvjxTo+R++zYjCwBAiHMW\n6FtbRhQdnfZwHgOzUdv7LyMLAECIsAAAhJiGwqS1bfqpo63Dd8weS5YskSTdfPPNQ2tDW6ejGFkA\nAEJcOotJa2uf4dJZDEob+vjxxx+/67/XrVs3sO1y6SwAYNoQFgCAENNQ6Elb+8l0nQRkGgpSe/p9\nG6ZYGVkAAEJcOotZqW2XFWI0VftZW0YZw8LIAgAQYmSBnjQdye/rR1rYt3S+A/tqv2dkAQAIERYA\ngBDTUOjbMIblnNjGsFX74M6dO/daNqoYWQAAQowsMGWc/Ma+as6cPY+3B9nv2zZaYWQBAAgRFgCA\nENNQmBaDvPO1bcNxoJuory5atEiStHnzZknSVVddtavszDPPnL6GDQAjCwBAiL86ixnXa58b5oiC\nvzqLUcVfnQUATBvOWWDGcaktMPswsgAAhAgLAECIaSi0ApfHAu3GyAIAECIsAAAhwgIAECIsAAAh\nwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIA\nECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIs\nAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAh\nwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIA\nECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIs\nAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAhwgIAECIsAAAh\np5SG3QYAQMsxsgAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAA\nhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgL\nAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECIsAAAhAgLAECI\nsAAAhAgLAEDo/wF7mAhYszWNWQAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bwyk5S5a7knN",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Conditioning the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "tQu4lzTMp8QH",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "outputId": "6a85b4b4-9af1-4097-eb7f-6adc13f1431d"
+ },
+ "source": [
+ "cond_data = {}\n",
+ "for i in range(1, 6):\n",
+ " cond_data[\"Y_{}\".format(i)] = torch.tensor(y[0,i].cpu()).to(torch.float32)\n",
+ "print(cond_data)"
+ ],
+ "execution_count": 394,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "{'Y_1': tensor(2.), 'Y_2': tensor(1.), 'Y_3': tensor(19.), 'Y_4': tensor(13.), 'Y_5': tensor(10.)}\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "rtJRtQgJve2A",
+ "colab_type": "code",
+ "outputId": "93aa7f5c-97bc-434b-c564-18ffebbe3b5a",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ }
+ },
+ "source": [
+ "# cond_data['Y_1'] = torch.tensor(1.)\n",
+ "# cond_data['Y_2'] = torch.tensor(4.)\n",
+ "conditioned_model = pyro.condition(scm.model, data=cond_data)\n",
+ "cond_noise = scm.update_noise_svi(cond_data)\n",
+ "print(cond_data)"
+ ],
+ "execution_count": 395,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "{'Y_1': tensor(2.), 'Y_2': tensor(1.), 'Y_3': tensor(19.), 'Y_4': tensor(13.), 'Y_5': tensor(10.)}\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0r2HePhs9DnB",
+ "colab_type": "code",
+ "outputId": "a95abb00-a531-4b8c-ca29-501c340e38d2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 277
+ }
+ },
+ "source": [
+ "rxs = []\n",
+ "for i in range(100):\n",
+ " (rx,ry,_), _ = scm.model(cond_noise)\n",
+ " rxs.append(rx)\n",
+ "compare_to_density(ox, torch.cat(rxs))\n",
+ "_ =plt.suptitle(\"SCM Conditioned on Original\", fontsize=18, fontstyle='italic')\n"
+ ],
+ "execution_count": 396,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEDCAYAAADEAyg+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnXmYX0WV9z+nSUgIEED2JRBZVBCB\ngK846oijDuPKuD/46Ku4DOMy7uO4jYorrqPjKC4g4gIuIKjjriAibq+yKLJDiATZAoSwhQRIvX9U\n3e7q2/d3q+r+bqd/Hb6f5/k9t7vq3Kq6W1WdU6eqzDmHEEII0cbYTBdACCHE6KPGQgghRBI1FkII\nIZKosRBCCJFEjYUQQogkaiyEEEIkUWMhRhYzGzOzO83s61HYw8zMmdkbM9Mokh8FzGz7UOajZ7os\n6wszu8nMvjfE+R8P9+wBfZZrQF77hrzeNN15jRJqLHrAzLYys/eZ2flmdoeZrTazq83sx2Z25IBz\nDjWzr5nZVUH+DjM7z8zeb2bbR3ILzWxdeDn/34C0tjCzG4PMKjOzzHKbmT3XzE41s+VmdreZ3WZm\nvzOzt5rZwm53pDf2BBYA50VhS8Lxj1VAuP9HmdkhDWlMkZ8F7B+O57VKzSBm9kQz+3p4z9eY2S1m\n9lMze3qHtHYFtma4Z7QEuMo5d8sQaZTkBSP8fKaDOTNdgNmOmT0U+BmwJXAS8EXAAXsB/wz8E/CF\nSH4B8FXgWcCVwInAMmBz4PHA24FDgUeEUw4EDFgN7GNm5qbOpDwKqCr28xrim8q9LXAq8BjgT8Cx\nwN/wH+1TgKND3s/LuQ/TxH7heH4UdiLwLWBNFPZo4N3AXxrSaJIfdZqueyQws/nAl4DDgQuB4/Hv\nzW7Ai4HvmdkHnXPvKEh2ObAJcM8QRXvSEOeWUjUWI/d8phXnnH4df/hK/EJgFbBvQ/wYsGP0/xzg\nF/jG5L3AWMM5DwfeGf3/xiB/YjjuUZN/CLAW+EaI/3hGuReGcq8FXjFA5p+AV83w/X1fuKZtE3Lv\nD3K7zPQ70dN1fxW4ven9mOFyzcV3jO4D3twQvyVwUXgWh2SkN3+mr6njfTgDWDbT5Vjv1z3TBZjN\nP7y5wAEnZ8q/O7dCj86pGoknhOMzavE/ApYC/xriX5CR5peC7GsKr/fhwMnADfie+l+AFzXInQLc\nAuwAfBbf87wTOAvYp0F+W+AzwLVB7qd4E9R3gb/VZK8H/jf8vWW4jqbfnnX5WjrPA84EbgXuChXA\nIxvkbgJOw2tZ3wVWht+JwBYN8rsBxwBXAXcDVwAfBDZpkH1gaBhuCo3DKcB2eE3v15nPZGFI/9KQ\n3/Xhni+syT0t3JdnAq/Am1BWh3IemZnXB0Ma72uReU6QOT4K2xTfwHwGOAw4O9zzn4T4b+I7XFZL\n61Dgl+GduAH4CLCoXgbgySHs2V2vF3gUXku6FLgDuDm8E49tkL0FOK1LnTGbfzNegNn8w5s/HHAu\nsCAhu3V4Ca8HNi7I42Lgr8DGeE3gP6O46oN4FvCp8PdDEuntA6wDzi281iPwZoJz8NrOq8NH7+of\nH968tgy4OnyArwA+CtxbzxfYKciuxJu+jsSbx5YC1wDfr8k64D3h/62AF+Ir/PPC3y8EXoDX+ibJ\nh3MMb3Jz+EbgFcB/hLKuBvaLZKuK6Q+hsvpwKN+3Q/h/1a7l4FCWv+JNg0eGvO4FTqrJPixUSNcA\n7wrl+AXwu/CcP53xTHYN92klviJ9Ob5Cvjc8G4tk3xnK/JuQx5uB14bz1wEHJPLaGd9B+Cswp0Vu\nj5DP76KwR4WwPwE3Ah8A/oVQuQOXAb+spfPyUK4/Aq8J79zSUH4HPCuSfXsIe2DX6wW+B/wgPIuX\n4bXaFfhvNk53t5Duu2a6/lnfvxkvwGz+4e2sV4aXZyXw9fARLGqQfV2Qe39B+pvhe2Snhf/PB74Z\n/p6L7wWdHv4/mwzTBfCJUI6XF5TjUaEC+lZcUQDzQuVxdRS2RfgY7wEeX0vnayFuXhT2K3ylWTev\nnV6/X8BTqWlX+AbDAR9tKHeT/H+EsNfWZPcO4V+Jwg4LYcuBnaLwMXzj8usobGt8RXgGtY4DvpFx\nVRrAfLzGcQWwTSQ3JzzT5PMJz7+qfPesxX0opPH4KOzUEHZs/I4ATwzh/5LI7z+D3FsTcjsGud9H\nYa8OYUuBHWrym4d34pNR2H74hukUYG4Uvgf+e6g3DCcDK2vpFl0vsGnDtTy+/iyAZ4Sww3K/nw3l\nJ2+oIXDOrQb+Dt+rW4kf9PsC8Fcz+6GZ7RyJPzEcv12QxQH4iunc8P/5+B4p+MZnD+D1ZjaGN4md\n75xbl0izKsepBeV4H97E8a/OuXurQOfcGuC3wCIz2yQqswGfd86dUUtnDb7RuQfAzA7DD7B/yDl3\nZU32Z+HY5AkVhx0YjucylUnyZrYp8A7gN865T8WCzrmL8drDgxvOf4Nz7tpItmoMV0eyb8ebxd4I\nLDCzbaofEwPve4Xjkfhn9zbn3E1RuvfitYv6NTZxOL5SfYdz7opa3JnhWL+Wm4F/q70j1cB/fC1N\nPCEcT0vI7RmOf63lDV4Dvb4mvz/+fYmv9yi8dnWkc2580Du8I8uBVc65q2rp1webi67XOXcngJnN\nDd512wC3heh5kegB4Xi/8oQCuc4OjXPuRufcW5xzuwO7A68CLsHbUY+NRBeH46UFyR8UjnFjsVdw\nNXwnvkK+AHgQXgtpqjDrLAZWuEwXw+C3/g94jWZlkwi+t1d9hFXF8I0G2X2By6KP9/Bw3jENspWn\nXr2xWOmciyuiVGMRyx+Kt/F/rkEW/LXcWTv/TuB/Jwn5Rmcx/jkTXJWfj+/tn4c3X8S/r4RTbw3H\nw/GaySkNZZiDb1CbPLtinoO3+3+lIa5yna4qwC1Deb8TGviYfcPxkkR+u4VjvVGvU3klnR6FLcGP\nR50+VXxKg745XiM8teUdHW8Ygnv37kTvSen1Btfz95vZpfj38Rb8c/tDELm8Vt6bnXPLB5Rtg0Wu\nsz0SejufNbNv43upj42iNwrHuQVJ1ivC8/HP7Jv4CuWdNblzMtLcqLAM++Irn4ta4i+IGoADQtkm\n+cyb2Rx8T/hbUfCBwIVVr67Gwfie3VU1+XqP7kC8XfmyhjTq8pVWNuVazGw7/ODySVHwAXhtrV7h\nLMF3tKr7vR3e/PIV/ID1IKp8lwA/dMGuUeNg4OKGPOs8DFg6QK6qEP8UjlVv+HcNsgfhtaQLEvlV\nz3cBEz3uSZjZZngz7CpCZ8HM5obynDzgeg/EV9DVvXkofnxuyrscGoZFTNZuljBVM8m+XjPbKsjt\njO/c/QGvkdwHvBI/HljvsNzvtApQYzFdrMV/XKuisMvwdvH98V5BORwEXO+cuy78X/WoHon3ZLol\nkoM8zeIyYH8z29U5d3WGfPWBr61HmNnB+AHzd0XBS4CLnHN318QfirfVx2XcFN87rqf7EPxcj7Or\nCiZ81IuZ2htfgq/QXS2NJvmB1wK8JBxPjc7fDe8BVad+v7cIx2uccz9vkI/LNQdv1pgycdLM/hFf\nsTZpC3Uczc9kI/x8hyuBP4fgqvfe1Jk4CN9gpxqni/CmswMY/P5+Eu/Z9gbnXPXuD6z8o7JdEJk3\nF4Rjkzn1/+Ib6djkVF3buQ1hOdf7SrxmfqhzrjJ9VvfxOPwzXRHCHoBvrJq05g0emaE6YmaPaZnh\n/C78vT0xCqsqgKPDxLx6evPN7M1BDa8m7z2E6CNwzt0KvAVvd/9sdPqBeBvsxRlFr8rx8VBx1cux\nhZnFE6ouxn+4h9TktsZ7Oq0APh3C5uEbj0EfKUz+qJfiJxruEqU7F+/ZVa8UBtmKd2eyfbxN/sJw\nrF/Lw/Fa2s+dc78KwakKZ02U3jXh/2eGSWuTCGMXG8H4uMRy4NHBnFXJLAQ+Fv7Nmex1IbB30Ihi\nPoCvoI+KGtAlTLg6x+WaF2RzOhknhOP769doflmWD+G9iL4B/HcUPfA+mtnG+PclfkZLw/EfarIP\nwY+dwdSe/momm9FKrrcyr8VmKcM3fLvR3AhJsxBFvBd4uJl9F6+63oFXZZ+N1x5+gp9XAYBz7lQz\n+xzeRfJiM/sa3sSyAF+xHQbc6Zz7aDhlf7zJaNKH7Jz7SENZDgD+5Jy7L6Pcn8J/iM/BVzan4Cuv\nLfDzKA7Df9gfCPndFMr9qlDmX+JdUo/EawaHRmMZ++LfqaYK9kB8bziuCI/Fm+rOMLNj8L3tI/Dz\nMyA9uA2+cnmSmb0Obz5Y6pz7zQD57+F72x8xs0UELQtvOrkCP5ZQz2/QtYz3hp1zd5nZp/Dumeea\n2VfxjejOeNPbI51zsbPDsfiK70wzOwHvEVTdz6ZrbOKDeNv+GWZ2HN608mz8sz3aOfe12rVcEA8W\nBx6GN0kmG4vw/h4PvBQ43/x6XdfiK9Tn4+eMHIP3Mou1vCX45950TfsyMc5T5bPMzH4KPMvMTsQP\n1u+Od6W9Ff+9xJ2i6truawjLud6zCO7a4fo2x38bVUc6lr3fDm4Dcp3t+gOejl/a40J8JXUPfpzi\nR/iPxwac92x8Q3IT3oxwLd5m+gFgSSRXuRs+M1GOyq/9MwVlH8N/fGfhP8A1+AbjLLzW8qCa/Dy8\nO+byUOar8V5fi2pyLw9l+buGPH8LXNoQ/jp8Rb0Gbzr5MPDWkM4BkdxX8QO2Y7XzHwn8PsQ5JuZg\nDJLfGe/ivBLfI70Q36hvUpP7Wjh/o1r4AvyYzOcb7umLwrO8GW9euwrv/fbcmuwc/Kzzq/FeZhcB\nb8P3yB2wZeZzfBq+o7ImXM/PgKfVZOaHd/PzDedXEzmnPK+WPF+I7zCsCmVfBnx5UBp41+gpzz3E\nvSzk/8ha+Nb4saObwvv5feD/4McZYnfleeHaPjfM9Yb3rZpn8yf8t3d4kH16wzsxUjPr19fPwk0Q\nQoiRxcz2w1fkb3XOfXimy3N/RI2FEGJkCOMha1xUMZnZFnhz1G7AXs65m2eoePdrNGYhhBglDgde\nZ2anAdfhzawvwY+pPVsNxcyhxkIIMUpcjx8LeT1+sHkFfjLf0c5PQBUzhMxQQgghkmiehRBCiCRq\nLIQQQiRRYyGEECKJGgshhBBJ1FgIIYRIosZCCCFEEjUWQgghkqixEEIIkUSNhRBCiCRqLIQQQiRR\nYyGEECKJGgshhBBJ1FgIIYRIosZCCCFEEjUWQgghkqixEEIIkUSNhRBCiCRqLIQQQiRRYyGEECKJ\nGgshhBBJ1FgIIYRIosZixDGzM83MmdnjekpvWUhvcR/pzdYyCLG+MLMTwvt+xEyXZRjUWAghRgoz\nOyJUrifMdFlSmNniUNZlM12W6WbOTBdAJHkRsAC4uqf0ngDMBf7WU3pCiHbeBnwIuG6mCzIMaixG\nHOdcX41Eld6VfaYnhGjHOXcds7yhAJmhesfMHmhmXwh2+TVmdrOZ/cTMntYgOz4eYWZPNLOfmtkt\nIeyAukzD+Tua2XFmdp2Z3W1mF5vZW8xso0HjAjnhZvYUM/uVmd1uZreZ2Y/N7MAB1/vsYJO9yMxW\nmdlqM7vEzD5mZtt0vpEbMOE+O/O80szOMbM7zOzWmtxmZvZ2Mzs3PIu7zOx8M/t3M9u4Jf2nmtn3\nzOx6M1trZtea2S/M7LUNshub2evN7I9RHn82s3ea2WYN8uMmIjPbwsz+28yWh3f9SjN7t5lN6YSa\n2SZm9hoz+4OZrQjv67VmdpaZvT2SOxP4Uvj3xdG9mmSWqr2vzzOzs8P758xsy/g+D7hHreYjM9vO\nzD4Y7sUd4d5cYmafM7N9g8xRwFXhlN1qZV0WpTVwzGLU7n8rzjn9evoBjwJWAQ64DPg68Avg3hB2\ndE3+zBD+OWAdcB5wEvArYL+azONq5+6CN005vEnpm8CPgNXAKcCyELe4dl4q/GjgvlCGbwFXhPA7\ngAc1XPO9wG3A74L8D4EbwjlXAds2nNNYhvvLL1y7A44B7gHOCO/KryOZRcAlQe464AfA94GbQtgv\ngI1r6RpwbIi/D/hteJ9+DlzvP/dJ8psAZwX524DvhnenyuPPwDa1c44Icd8BLgrpngz8DLg7xH2h\nds5YKK8DVobrOCmE3QDcHcm+FTg7yF4BnBD9Xt7wDh0Tjr8Jaf4R2CK+zwOeweIQv6wh7qDoHb4h\nXOvJwDnhvh4V5J4R7lf1fcRl/ViU3glB5ohRv/+t7+1Mfzgbyg+YDywPD+QDgEVxjwJuD3FPjsLP\nZKLiOGJAupXM42rh3wvhpwDzo/AHAddG6S6unbcsEb4aOCQKnwucFuKObyjfc4FNGu7FceGczzWc\n01iG+8sveja3AEsa4g3f+DrgY8C8KG5L4Mch7r21894Uwq+upwtsBDy9FvaxIH8+sF0UvhDfgDng\nm7VzqsrKAafW3r2D8Z2HdfGzBQ4J8n8ENm0o1+MH5HFCyz2s3qG1wKFt93lA3GIaGgtgc3znywEf\nZ2qDvAg4KJVO7ZwTaG4sRvL+D7yOmf5wNpQffiDa4XuDYw3xR4X4n0dhZ4awH7ekW8k8rvaCrsP3\nJHZoOOeV0Qu1uBa3LBH+oYb0Hh7iriq4H5vge80rGuIay3B/+UXP5q0D4p8S4s8k6nRE8TsCa/A9\nUAthc5nokT428/ncEeQf3RC/Z6h47gN2jcKryuo2mrXG74f4F0dhzw1hn8y8P1UeJ7TIVO/QlM5I\n/T4PiFtMc2PxhhB+emZZG9OpyZxArbEY5fs/6Kcxi/54bDh+zTm3riH++HB8tJltVIv7TmFef4/v\nfZ7lnLu+If6kwvRiftQQdmk47tR0gpntHeyu/2Nmxwfb8mfxvb5tzGyrIcqzITPouT85HE9x4WuP\ncX7A9HJga2CvEPzw8P8VzrmzMvI+CNgUuNI59+uGPK7Am0jG8O9bnXOccysawpvelfPwld5LzewV\nZrZdRvlyKf12UjwpHI9vlRqeWXf/1Vj0x87heNWA+Gvwled8/Ecd89eOeTWe55xbhR876cLyhvRu\nD39OGlQ1szlm9kW87fQTwL8BLwFeHH4LgujCjmXZ0Bn03HcPx/+pDZqO/4CHBpltw3HXcLyUPFLv\nK8DSmmzMlPckUL0r86qAUPG9Dq/9fBa4wcwuDx2Lp5qZZZa5idJvJ0XpfezKrLv/cp0dDVZ3PG9K\nrzOiSbvJoeS81wMvxdt434AfUL3RObcWwMyuxZtMhqkMNlicc4Oee6V5nsHgSqHi5iq5XgqVT9H7\n5Zz7jJl9G3gafq7P3+M7Fi8BTjezJznn7u1Qjq7fzqCO8vq+j11Z7/dfjUV/VJPcdh8Qvwu+Z343\nfmBzGK4Nx12bIs1sIbA+TD/PCcdXOOe+XyvDpsAO66EMGyJVA3GSc+6LmedU83EelCmfel/juF4m\ncAaT6XHhh5kdjPcCewLwMuDzfeQTcQ8w18w2c87dUYtbNOCcq4G98ffxjz2XJ2bW3X+ZofqjshO/\nwMya7utLwvHXHXtQMWeH4yFmtn1D/POHTD+XB4RjU+/3cKRRdOXH4ficVqnJnIPXMvYys8dkyt8J\n7G5mj65Hmtke+N7nOrwbde84535PqLiA/aKoteE4bGe26lQ9uCHu0AHn/DQcX5qZR9eyjvL9b0SN\nRX+cjO8BPBh4T2wHDC34m8K//zVsRs65pfj5DPPxdu1x+6SZ7Qm8a9g8MrkkHF9Zu94D8PM1RDdO\nww9KPsnMPhE0xUmESWUvrP53zt2DX1IC4EQz268mv5GZPT2SX81ET/LTZrZtJLt5iJuDH2QfahUB\nM3u8mT25PlnM/MTCfwz/xnlUPem9h8kXP48A4B1x3mZ2KN5s2sRx+HktTzCzD1tt8qOZLTKzg6Kg\nFfgGY/sSR44Rv/8DC61fTz/g0UxMyrsE75V0OulJeY9rSbNRBm+CuibEXYOflPcDvA33VPzAnwN2\nqp23jAKX2ih+ihsifv7I2uh6v4G3s98brr1TXhv6r+leNsjsClwYZG8FfgmciJ+4dVkI/13tHGPC\nTfM+4NfhOfyM9KS8VUxMPlsRwi5g8KSwEwaU+6gQf1QU9nomJoT9PLqOG0P4pcCWkfw8fIVdzQ34\nMr4Sf0nJO4TvuFXuqZeHa/sDvrd+NIMn5T0iugfXh+9pyqS8SP7UKq1wbccRuaCTNylvZO7/wPs5\n0x/OhvbD2xmPxVfWa/HjEz8FDmuQPZOOjUWI2wn4Ynih7w4P/R3hY1sTXuz5tXMaP7LUx8eACg44\nEO9ueyNerf5TeDnHuua1of8G3csGuU3wXixnhw99Lb7X/VvgfYRZ/g3nPRNvyropOud04NUNshvj\ne9nn4CvW1cBf8NrpZg3yXSqrPYH34Hv6y8O7eiO+4v53YGFDOvvj5wzcHN7jSXnmvkPh/fwxfl7C\nnfiZ3k8nMT8C75jxMXwnaHU4/2LgM8A+Ndmt8Q3Ecvw4yaR0GdBYjPL9b/pVE3rEBkSwgZ4NXOic\n23emyyOEmP1ozGKWEuY4LGkIfzDwhfDvl9dvqYQQGyrSLGYpYUXK2/Hq+CV4NXk3/MzQOXhb6BOd\nH/gUQoihUGMxSwmeDe/D+0g/EL/A3F14u+o3gGNcmBwnhBDDosZCCCFEklk5g9sGbGgiRF8452Zk\nQmHTu920dE/VyRtuWaXJaQ0oT7ZMKs2ctAalV0+3a56p+5Uj15ZPaVlLn1+cd/3ctriYdevWdXpp\nZmVjIcSGSm4llVMJ5sqXVlhdK6mSNFNMR2XbFJYr38d9rTdC8fl9NDzDWpHkDSWEECKJGgshhBBJ\nZuUAt8YsxHQzU2MWY2NjWe/2sN9tqbmrtCy5JphhzUmlppcm004TXcs/zL1rM4F1pakMXccspFkI\nIYRIogFuIUac0sHT0t56Xb4Pj6Gc8weF5aRb6g3Vlk8qLCeuiab7VHqvu5arq6bThjQLIYQQSdRY\nCCGESCIzlBAjyHQ5nnRNt82ckTto3EabqaaPiWi5prwcE1tp3DBmq655t5nAuiLNQgghRBJpFkKM\nEH3OTG4KG3ZWcVMZU737HPfVYTSENup5x/93HRAvLVfqenPSz8mzD8eBNqRZCCGESCLNQogRIsdF\nMibHPbMprHTdqNJ1qdrKPV1aREkZBpWnTQtKndv0/6A0ShdozBlv6HN8oglpFkIIIZKosRBCCJFE\nZighRoiu5qfU4GZOWKkraddB5j7SL93XopQ+lopvG8zuw8W4NE25zgohhJh2pFkIMeLk9Myb6DrR\nLZVP193zSvJLlaN0TaX6+XFc7sBw1wH3PjYxynF80KQ8IYQQM44aCyGEEElkhhJihJiJZb9zBsub\nyJUvHZQtmZdROnu8j7g2hlkCPYe+N28qQZqFEEKIJNIshBghcmdktzHMAGyX85vOLZ0hHpNzvbnu\nqKWbH6XOT+WTKsN0uvwOo+HlIM1CCCFEEmkWQowQfWyHWbovQul6SIPOS6XRxxaiOeMrw1xbmybS\nln5XjWGY55yzRldqld0SpFkIIYRIosZCCCFEEpmhhBhBSk1BfQxK5+QTyw/rBjoo/Xo+uYPSOaav\n0k2H+nYqGGZtp0Hnlw7Cd0WahRBCiCTSLIQYIUpdZ0u3Nu2qDbSVp++ebY4rb6nLba4GVrJJVCq9\nUkeD+nkpuRy0NpQQQoj1ijQLIUaIrm6rqR5qn6vADuvi2RSXqw2UakjDaFCDypUqY9u5pa6/XfOe\nDqRZCCGESKLGQgghRBKZoYSYJZRuudo1/VzTSxula1vluua2mYlK12zKke97Xacc81Mf11Farhyk\nWQghhEgizUKIWUIfLqFNaQ271WeTfOmaRzM5qFuq8QyKz5UZpjxNcn1qhG1IsxBCCJFEjYUQQogk\nMkMJMUtoGwSO6bqmUu7mRDnlaoqv4sbGpvZR161blzwvDss1qZSardrmQfQxszqn3KWmv9I0uiLN\nQgghRBJpFkKMEG29+6Zebh8bFZW6hA7bW91oo43G/77vvvuAydpGpWXkak1tcaUbBOXMdM99DsP2\n7ktX1k1dh1adFUIIMe1IsxBihOjD9bKrppAbN6y7aKVNwGQto2KLLbYAYO3atcDk8Ywq7N577x2Y\ndxNd18ZK9dBztlWNKd2KNieuD60pB2kWQgghkqixEEIIkURmKCFmIaUDsTnrLfWxbHabfFO6c+b4\nKmizzTYbD9tkk02ACRPVPffcMx53ww03TIprIjZzdS1jH+atepq58n0sAT8dy5ZLsxBCCJFEmsUG\nwnT3KsT6ZZhJZ20rk3bt2ZZuJdqWXqwV7LLLLgA84AEPGA+bO3fupLBKm4CJAe41a9YAEwPdAKtX\nr56SfqVl9KEp5AxO9zlZsHRSZelKv6VIsxBCCJFEjYUQQogk1ufmGOsLM5t9hZ4muvpti3acczNy\n88bGxore7dINkaZjlnPKXFKFVQPXO+yww3jcdtttB8Bee+01HrbnnnsCEwPblXkJ4IorrgDg3HPP\nnVLGyjR1++23j4dVZqh4kLyNnDkkTfK5g8w5S8U30XV59CbWrVvX6d2WZiGEECKJBrhnKcOuXilG\nk9zeZcnzT6VVSul6SJV77M477zzpCPCIRzwCgO233348rPp7/vz5ANxxxx3jcdtuuy0ACxcuBOCC\nCy4Yj7vzzjuBiQFygJUrVwITGkbT6rZtDOM4MuymR6ktVHMG1TWDWwghxHpFmsUsouv4ktxqZx8p\nLaJ0raB6Grl7Y7TRVq64d7/jjjsCE2s+LV68eDxum222AWCPPfYYD6s0iy233BKATTfddDzu2muv\nBWD33Xefks+f//xnYEKTgYnVbKvzmsqfuyZTbg++LY2u53VdH6xPpFkIIYRIosZCCCFEEpmhZgGz\n0b1ZdKPPjXKa0i0dGC/dRKcyC2299dbjYdVg9E477QTAbrvtNh5XhcUD3JX5qTJfVQPdMDGru3KT\nPfjgg8fjqrDKHBWf27SWVDXo3TZoXLplbGpQup5W26zr3O9ea0MJIYQYGaRZjCjTpU3InXa06frc\nSyfS5bq95vaUKxYsWABMHmSuBqirwexFixaNx+2zzz7AxIQ9mNAyKo2kqTzVxL2mdaBuuumm8bDL\nLrsMgM033xxonrA3zEDydG/zWj9vJpFmIYQQIokaCyGEEElkhrqforkXo0npktVtcrnzLHLSbEsr\nNjlVcfH8h2pjo+oYD2ZvvPF34xq7AAAKPklEQVTGAGy11VZT0qvyjDczqtKt0moaSK9MVDAxv+Lu\nu++eck2VCattVvcwy36XLhXfZqLKNRG2xWkGtxBCiGlHmsWIITdZUafrRjlNlG4D2pZ+k3tppQXM\nmzdvPKwa9K56/tUGRjCxGmwcVslXGxvFZa7kqnWgmmZrx1R5Ll++fFL5YEKjyF1RtymsbfZ7qTYw\nqAyDwtriSh0TcpBmIYQQIok0ixGjdEKO2DApHVNK2be7ahQ5rqFNvfWYarJctQJsdYyJ96yo/q60\nk1jrqFagrfK+5pprxuOuu+46AFatWjUl76ZJefXriMl1bW3bU6LPyXI54xOlz7sUaRZCCCGSqLEQ\nQgiRRGaoDYRhZqGK2Usf26rmkDuIWpmQbr311vGwahC6MivFJqHKbLX//vuPh1WmposvvhiYvA1r\nNehdbYJUyQBcfvnlANx1113jYStWrJiUZx9bxralUWpy6sMxoWvepUizEEIIkcRmYy/UzGZfoYeg\n6+Y3bWgiXjvOuRm5Qbnvdk4PuHRNoq405RNPyqvWZaom4O29997jcdWmR7H2EGsG9bSqyXhVWhde\neOF43NKlSwH429/+Nh5WaTjVmlCVq25MqUty6eTIUnKeX1PeqXoi0kA6PXxpFkIIIZJozGIWUNoL\nlNYw++ljolXOkhKDwkrKkRovq9xXK22gmiAHcNtttwGw6667jofdcsstwOTVaetU7rSxZlFpD22u\ntk3lHsa6Uro1a/28lFybfMn5fSDNQgghRBI1FkIIIZLIDCXECNFmJipdMyjX1FQik8qzIp7VXZmF\nKrfaeD2nSq4yPcVh1Uzs7bbbbjzu5ptvBiZcYitXWpgYvI43OKrT5LbbRt8bI+VsQtWWd1t5Us9F\nq84KIYSYduQ6K0QDM+U6OzY25kL+42F9rDE0rMts6eB33IOv/q6O8Yq0lVttpTHAxIB4tdVq08qy\nVd6VLExoJPEAdxtNmkXX+5S7llTJPiSl7tApKvl169bJdVYIIcT0oMZCCCFEEpmhhGhgps1QtbJM\nkSsZsE6l1fW83MH1uhkqNivFA9QVlRmpkos3Naqv8RTn02Z+alpOva3cXU1B0zUzvp5fU3xqMF5m\nKCGEENOOXGeFGCHaeqgpV8qcdJto01LaXDZzy1ANJFfHWANoO7dpW9UqjSqubVOjmJRGMYhS19mu\ns+C75F26PtiwSLMQQgiRRJqFECNIyt1yOlad7XMb0D4mjzX1iuvaRs7Eujit0rxzJ0KWag99uMLW\nyziM5pmDNAshhBBJ1FgIIYRIIjOUECNE6YB11wHuPrb/bDOXdJ21HNOn+a1Jvklm2BncbWUdJJcT\n19Vs1edWq9IshBBCJJFmIcQIkuqplqxe2hTfdQXVVFiu++36Jnfwvm2Au3SDoxy5Ui1omLyHRZqF\nEEKIJGoshBBCJJEZSogRpG3GdD0+FTdIrh62vvz1U+Ssz9T0f9f9sJvybvq/9L7m3LPc8nedg1Fa\nnjakWQghhEgizUKIDYDSLVebaBvULR1sHcbFs608bbPBu7ro9sF0a159OCQMu06UNAshhBBJtJ+F\nEA3M1H4W1bvdx54STee29fLb8uy69lFpuVLpD7smVoo+3VdLNaNB56fk2s5tOk/7WQghhJg21FgI\nIYRIogFuIUaQlPmj1JyUM+M71yTU5GKbI99ETlycVpubb1OaOWafPkxNbfRh6s/ZWjflmCDXWSGE\nENOONAshRpDcyWNde4ulmkLM2JjvY7ZtVVqqdZROvOs68N51Y6FUGqVux6Vuvm15tqWRK5+DNAsh\nhBBJ1FgIIYRIIjOUECNE6VLaueSYrXLnTeTse91meildnyl3lnYb07VJVM48jlJngra8S+eQ9Dmz\nXJqFEEKIJNIshBhBct1XpyPP1EBoVy2ldAvVtrS6DmynVnltc0dtyjtnILxNvo9r66otliLNQggh\nRBJpFkKMENOxD0GcRk5vN5XfsBMCh6F0j4g+5HIovRfTsSbfdK/zJ81CCCFEEjUWQgghksgMJcQI\n0dUltimN0rjcWeE55o71tfR230uUdy1X23lts+Vzy5ozgJ6a8T0s0iyEEEIkkWYhxAiS6l3muGAO\nOjcVN8wWqk1p5LijtuVfumptrnzX9bWGmRjXdm3DagMpjU1rQwkhhJh21FgIIYRIoj24hWhgpvfg\nzpArSrfE5FJqAptuhhkELk0rx2TW597jbaRMVKXLnFdoD24hhBDThga4hRghcledLV2TaJBMU96l\nlLrC9rHGVc6Wq4PCSijd3jZXc+n63HKdHKZjprg0CyGEEEmkWQgxQuT2hHPcXXO3Ki3JJw4r3daz\nbfJYrntpTvn6XJ23VEvJdQvO1SDb4krXyZLrrBBCiGlHjYUQQogkMkMJMULkumoOO2Bb6uLZ5wY+\nbZsNNZ1bev1dB4+b8ip1NBhmG9au7rTra0l2aRZCCCGSSLMQYgRJuUF23aI0lVfO+V1Xru1zK9HS\n9atyKV3Hqi2/YbdCbUor12khx3GgFGkWQgghkkizEGKE6Kox9OEKWy9Dbrqlvdc+et9NlGo1fS4P\n0kbuOMOwrrm54yVdkWYhhBAiiRoLIYQQSWSGEmKE6Lqe0zDmlbYy5MjlDpyWypUOAueYdnIHgbuu\nHjuMS3Lb4HrOc9YAtxBCiBlHmoUQI0TuAHdOjzZ3n4ZhtxfNdfEcZnC2Th8T0XIGwlM985wJhH2U\nq40+J222Ic1CCCFEEjUWQgghksgMJcQIkbuRTc5s6FxTUx9+/Tl0LUPpXIzScvVxn6Zje+q+N8DS\nPAshhBDTjjQLIUaQ3J5t7iqm0zHgOezgcW65pmvF23pcU3zutqo5ecfxOY4Gfawwq21VhRBCrFek\nWQgxQnR1Yx0m/ZzeehulvfXccYb15ULbVI6uvfthtklty6d0bCbHXbkUaRZCCCGSqLEQQgiRxKbD\n5Wu6MbPZV2gxq3DOTY8dKEH1bg+zxlBXum49GpMzuJzrEpqb/rCUmsByXZjb6LoGWC5taaxbt67T\nTZRmIYQQIokGuIUYIfrQKErXMCrttea4lw6zEm3X3nrXe1KqReQOXJeu/Fo6gN6Wdx+TFetIsxBC\nCJFEjYUQQogkMkMJMUKULo2du1FODqXmm1zzU9dNknLKkzsY38c8i1L6NKeVmtimY76ONAshhBBJ\nZqXrrBBCiPWLNAshhBBJ1FgIIYRIosZCCCFEEjUWQgghkqixEEIIkUSNhRBCiCRqLIQQQiRRYyGE\nECKJGgshhBBJ1FgIIYRIosZCCCFEEjUWQgghkqixEEIIkUSNhRBCiCRqLIQQQiRRYyGEECKJGgsh\nhBBJ1FgIIYRIosZCCCFEEjUWQgghkqixEEIIkUSNhRBCiCRqLIQQQiRRYyGEECKJGgshhBBJ1FgI\nIYRIosZCCCFEEjUWQgghkqixEEIIkUSNhRBCiCRqLIQQQiRRYyGEECKJGgshhBBJ1FgIIYRIosZC\nCCFEEjUWQgghkqixEEIIkUSNhRBCiCT/H6kNjXY5UIJaAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "T9q5R5Az7qUO",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Counterfactuals"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "4NA_mpKv7t4P",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# intervening on Shape, posX and PosY \n",
+ "intervened_model = pyro.do(scm.model, data={\n",
+ " \"Y_1\": torch.tensor(0.),\n",
+ "# \"Y_4\": torch.tensor(5.),\n",
+ "# \"Y_5\": torch.tensor(25.),\n",
+ "})\n",
+ "noise_data = {}\n",
+ "for term, d in cond_noise.items():\n",
+ " noise_data[term] = d.loc\n",
+ "# intervened_noise = scm.update_noise_svi(noise_data, intervened_model)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "rffN8m75cKrd",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 254
+ },
+ "outputId": "93cf401b-156a-4bb7-9796-30096fa82c46"
+ },
+ "source": [
+ "(rx1,ry,_), _ = intervened_model(scm.init_noise)\n",
+ "compare_to_density(ox, rx1)\n",
+ "print(ry)"
+ ],
+ "execution_count": 434,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "(tensor(0.), tensor(0.9218), tensor(15.6666), tensor(21.0625), tensor(18.7714))\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAADcCAYAAACBHI1wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnX2MpVld57+/menu6a7unmEFVFZw\nZBGCElbFaCIqE1ECKps1otnNGhnQPyS7KkaT9SXujhoXN5ldd+MKiAPbJiu+wKJLcB0RZMRBMDCC\nuuogICPDDi8j6nRXdXf1vDz+8Tzfrl/97veec+tW1dStnu8nqdyq55znPOc599Z9fu8nhmGAMcYY\n0+Kqg56AMcaY1ccPC2OMMV38sDDGGNPFDwtjjDFd/LAwxhjTxQ8LY4wxXfywWHEi4vaIGCLixj0a\n7+5pvBv2YrzDOgdjHiki4sz0eb/poOeyG/ywMMasFBFx0/Tleuag59IjIm6Y5nr3Qc9lv7nmoCdg\nunwngBMAPrZH4z0XwBEA/3+PxjPGtPkRAD8D4BMHPZHd4IfFijMMw149JDjeR/ZyPGNMm2EYPoFD\n/qAAbIbacyLiCyLiNZNdfjMiPhMRvxMR3yz6XvZHRMTXR8RbI+LvpmNfUvuI8z83Im6NiE9ExMWI\n+MuI+PcRcfU8v8AixyPiGyPiDyLiXEScjYjbIuLL5tzvt0422b+IiPsj4kJE3BURt0TEY5deyCuY\naZ2HGHlZRNwZEesR8Q+l38mI+NGI+OPpvTgfER+IiB+KiKON8b8pIt4cEZ+MiEsRcW9EvCMivk/0\nPRoRL4+I96Vr/GlE/HhEnBT9L5uIIuK6iPjvEXHP9Fn/SET8x4iYEUIj4nhEfG9EvDci7ps+r/dG\nxDsj4kdTv9sB/M/pzxentdpmliqf12+PiDumz98QEdfndZ6zRk3zUUQ8PiL+07QW69Pa3BURr46I\nZ0x9bgbw0emUzy9zvTuNNddnsWrr32QYBv/s0Q+ArwJwP4ABwF8B+BUA7wDw4HTsFaX/7dPxVwN4\nGMD7AbwewB8AeGbpc2M59/MwmqYGjCalXwPw2wAuAHgjgLunthvKeb3jrwDw0DSHXwfw4en4OoCn\nint+EMBZAO+Z+v9fAJ+azvkogMeJc+QcHi0/070PAF4J4AEAvzd9Vt6V+jwRwF1Tv08A+C0AbwHw\nt9OxdwA4WsYNAL84tT8E4N3T5+ltAD45/rtv638cwDun/mcB/J/ps8Nr/CmAx5ZzbprafhPAX0zj\nvgHA7wK4OLW9ppxz1TTfAcDfT/fx+unYpwBcTH1/GMAdU98PAziTfr5bfIZeOb3+4TTm+wBcl9d5\nzntww9R+t2h7VvoMf2q61zcAuHNa15unfv9yWi/+f+S53pLGOzP1uWnV17/5uT3of5wr5QfAtQDu\nmd6QnwYQqe2rAJyb2l6Qjt+OrS+Om+aMyz43luNvno6/EcC16fhTAdybxr2hnHd35/gFAM9Jx48A\n+I2p7XVift8G4LhYi1unc14tzpFzeLT8pPfm7wB8qWgPjA/fAcAtAI6ltusB3Da1/WQ57wen4x+r\n4wK4GsALy7Fbpv4fAPD4dPw0xgfYAODXyjn8shoAvKl89r4So/DwcH5vATxn6v8+AGtiXl835xpn\nGmvIz9AlAM9rrfOcthsgHhYATmEUvgYA/wWzD+QnAnhWb5xyzhnoh8VKrv/c+zjof5wr5QejI3rA\nKA1eJdpvntrflo7dPh27rTEu+9xYPqAPY5QkPkec87L0gbqhtN3dOf4zYrwvn9o+uoP1OI5Rar5P\ntMk5PFp+0nvzw3Pav3Fqvx1J6EjtnwtgE6MEGtOxI9iSSL92wfdnfer/bNH+lOmL5yEAT0rH+WV1\nFlprfMvU/uJ07NumY/9twfXhNc40+vAzNCOM1HWe03YD9MPiB6bjb19wrnKc0ucMysNildd/3o99\nFnvH106v/2sYhodF++um12dHxNWl7Td3eK2vwSh9vnMYhk+K9tfvcLzMb4tjH5xen6BOiIinT3bX\nn4uI10225VdhlPoeGxGP2cV8rmTmve8vmF7fOEz/7ZlhdJh+CMBnAfjC6fCXT39/eBiGdy5w7WcB\nWAPwkWEY3iWu8WGMJpKrMH7eKncOw3CfOK4+K+/H+KX30oj4noh4/ALzW5Sd/u/0eP70+rpmr91z\n6NbfD4u9459Orx+d0/5xjF+e12L8p878zZLXkucNw3A/Rt/JMtwjxjs3/brNqRoR10TEazHaTn8W\nwL8D8BIAL55+TkxdTy85lyudee/7k6fXnytO08s/AL546vO46fVJ0+sHsRi9zysA/HXpm5n5nEzw\ns3KMB6Yvvu/HqP28CsCnIuJDk2DxTRERC85ZsdP/nR47XcdlOXTr79DZ1eDCkufNSJ0Jpd0swk7O\nezmAl2K08f4ARofqp4dhuAQAEXEvRpPJbr4MrliGYZj3vlPz/D3M/1Ign+FwezKpxdnR52sYhp+P\niP8N4Jsx5vp8DUbB4iUA3h4Rzx+G4cEl5rHs/848QfmRXsdlecTX3w+LvYNJbk+e0/55GCXzixgd\nm7vh3un1SaoxIk4DeCRMPy+aXr9nGIa3lDmsAficR2AOVyJ8QLx+GIbXLngO83GeumD/3uc1t+1J\nAudkMr11+kFEfCXGKLDnAvguAL+wF9dJPADgSEScHIZhvbQ9cc45HwPwdIzr+L49nk/m0K2/zVB7\nB+3E/yYi1Lq+ZHp915ISVOaO6fU5EfHZov1f73L8Rfkn06uSfv8VrFEsy23T64uavbZzJ0Yt4wsj\n4qsX7L8B4MkR8ezaGBH/DKP0+TDGMOo9ZxiGP8L0xQXgmanp0vS6W2GWQtXTRNvz5pzz1un1pQte\nY9m5rvL6S/yw2DvegFECeBqAn8h2wOkJ/oPTn/91txcahuGvMeYzXIvRrn3ZPhkRTwHwH3Z7jQW5\na3p9WbnfL8GYr2GW4zcwOiWfHxE/O2mK25iSyr6Dfw/D8ADGkhIA8MsR8czS/+qIeGHqfwFbkuT/\niIjHpb6nprZrMDrZd1VFICK+LiJeUJPFYkws/Ibpz3wNStJP3811MeYRAMCP5WtHxPMwmk0Vt2LM\na3luRPznKMmPEfHEiHhWOnQfxgfGZ+8kkGPF13/upP2zRz8Ano2tpLy7MEYlvR39pLwbG2PKPhhN\nUB+f2j6OMSnvtzDacN+E0fE3AHhCOe9u7CCkNrXPhCFizB+5lO73VzHa2R+c7n2pa13pP2otRZ8n\nAfjzqe8/APh9AL+MMXHrr6bj7ynnBLbCNB8C8K7pffhd9JPy7sdW8tl907E/w/yksDNz5n3z1H5z\nOvZybCWEvS3dx6en4x8EcH3qfwzjFzZzA34J45f4S3byGcIouDE89UPTvb0Xo7T+CsxPyvuKtAaf\nnP6fZpLyUv83cazp3m5FCkHHYkl5K7P+c9fzoP9xrrQfjHbGX8T4ZX0Jo3/irQD+heh7O5Z8WExt\nTwDw2ukDfXF6039s+mfbnD7Y15Zz5D9Z758Pc77gAHwZxnDbT2NUq/9k+nBetey1rvSfeWsp+h3H\nGMVyx/SPfgmj1P1uAD+FKctfnPctGE1Zf5vOeTuAfyv6HsUoZd+J8Yv1AoD/h1E7PSn6L/Nl9RQA\nP4FR0r9n+qx+GuMX9w8BOC3G+ecYcwY+M32Ot11z0c/Q9Pm8DWNewgbGTO8XopMfgTEw4xaMQtCF\n6fy/BPDzAL6o9P0sjA+IezD6SbaNizkPi1Vef/XDhB5zBTHZQO8A8OfDMDzjoOdjjDn82GdxSJly\nHL5UHH8agNdMf/7SIzsrY8yVijWLQ8pUkfIcRnX8Loxq8udjzAy9BqMt9OuH0fFpjDG7wg+LQ8oU\n2fBTGGOkvwBjgbnzGO2qvwrglcOUHGeMMbvFDwtjjDFd7LMwxhjT5VCW+4g5u18Zs1cMw3Ag2een\nTp0aputfPnbVVaNM99BDD10+dvXVY/moBx4YXVIPP7xVKujIkSMz/ZkzyfOuuWbrX59j8JoPPjhb\nYIBzyGOQzc3Ny78fPXp02/XyPHheHqu2qevn++AcOT7vFQAuXZq1ul577bXb+l28eHGmfx6jjq9Q\n1hj2z+dxXK5vvibXKd8375Prk+/n2LEx7/bCha1SWHwPVX8ey+NzbhsbG0t9tg+lGcoPC7PfHNTD\n4siRI2Pw/dGtxGF+2eQvNT4c+GWQ/4/5ZcsvGGDrC53jqi8WvuYvPPbLDxdei3PIbbx2/pLi7/yi\n5hd4nfe88fODkNfimLmN65QfRqSul7oOMLs+6sGW71c9yCt5PXl91b+Or9Ymz4f9uOZKOFDnrq+v\nL/XZthnKGGNMFz8sjDHGdDmUPgtjrlRoKlCmkWz7pkmK/bPfgOQxCE1B2aRV/QbZrk+TUb42r6nM\nOMrvwXF5H7mN5pJscqm+DeWP2djYmLkPksev65SpfpyM8ilwXLUWNAnluSpTFq9ZTUj5GM87fvz4\nzLXzWNVHk9s4j2yOUmuwE6xZGGOM6WLNwpgVQkm5SkOobSqiJzs8q1SZpWM6wpVDVTli1RwreT6U\nipUkT41ISd9VIwFm10JpKdmBXrUmtZaLOsQ5Vp5P1a5UNFRewxrplceq2kPuqxzpPMb3Ur0vaq2X\nxZqFMcaYLn5YGGOM6WIzlDEriHJWKpMITRDKnKFyHapzOo9PlNNYoUw6OQeBVCe2Mm3l8zhuvUdg\n6z5pOstjcd45ca3ebzYTKbMb16nlzM73rXJBWvdWx1XzUe+RmisDEWqCZh5XvR/LYs3CGGNMF2sW\nxqwQLelbhUFSws6SrXIk8xilT1Xug2NlRyjHyNpGzTDOEm3LkcxjObOc81LOXKIyvlXYrsp0r05g\nlX2tSpOo8h31PtRcVUmW3L+WW2ndd7620iDrWPn9Vg7x3WoZ1iyMMcZ0sWZhzAqhwiYpTbZCYbNN\nWyWIVYlcJc3x2Nra2uU2HlPJZsou3qpJpLSlVm26Vt2olvaUk9WqhqC0gty/JuNlLYsaUV47HuP4\n+X1RiXR1DVrFHrOGpOZf3+d8nVocMo+xLNYsjDHGdPHDwhhjTBeboYxZQVTWb6tctsreVWYfOotz\nW80+Vm2ZWhY9O6Bz2Cqpe0/0QjzrfLKZqN6ncuBmJ3Atsa4yxbN5pq6BGkuFFtdw33pPhOfSvKVK\npitTodqfgvOpZsR595bXfRmsWRhjjOlizcKYFYKSbJboKblnZ2sNsc2OTHWsbrqTpcwq0SpU9VVK\nzuvr65fb6kZK+V6UtKsSCAn7qeqrnH9u4/rk8dmuNl6iFqQ0ALVOdGbnYwwGaO261wrvXbRCrkr+\nU8l7df753rKWtAzWLIwxxnTxw8IYY0wXm6GMWSFoSsnmBh5TuRSq5hFNG7l/zWTO/asZRtWNUvkM\nKm6f5o9sSqn5DC3Hb56jyoOouQjnz5+/3Hby5EkA2zdvqveY701lv3OtW/Wx8vw5N2U+5O/Z1EQz\nGOehzG/q/VObRLEf55rX6cSJEwDauTk7xZqFMcaYLtYsjFkhqiN63jFKwCoUVjkyKW1Tusx9OJaS\n+Fu1nurGQvn3LAFzfEryuTYUNQOlzdQqrGqsrN2osF2OobLgVahw1dRaW9nm/jWMNY+hwoJV1nXV\nmpSWkqmBDHldOYYKtV0WaxbGGGO6WLMwZoWgHTpLiVWaBrakSr5myZNjZO2B9nzls6hkSbuGhgKz\nUmuWvls+C75mn4KqG8VwV2oK2X/Ae1LnKQlbVcatba1qrFky5xorTUFVC+Zcs8+iagiL7kGh/EQ1\nLDhrZ3y/ehrnTrBmYYwxposfFsYYY7pEK2tzVYmIwzdpc6gYhmF39ZyXZG1tbQB0Zm/LyZlNNdU0\nArRNHDTpqHpFylzFudG5nk1UNInkzOo6nzwHVdeohuvm8evWpvm+uRZq7VQ2taIGE+S1UBsQ1bn2\ntkKtbS3nt6ozpbLyVQ0tMmfb2aU+29YsjDHGdLGD25gVouU0blWAzVKlqrBaE8SyE5jaACXPLI2q\nTXSY8EVNJF+H4+b5UANRSWqqom6tF6W2gFUVbNlPbRqkala1tAElyav3gf14nVy/S43RqudUa24p\nsqbH+2RtLqWxZdSxnWDNwhhjTBdrFsasELVCa/5dJXDV5LzcpkqAUPpU9m2VnKekb46ltvokyhav\nwl2VhlArsiqJuOUPyNI9/QtK21Jhvpwr55Ov3ZL4VZkWVa6k7jGS76NqYOp6uT81QmpzalvVfKw1\n/0WwZmGMMaaLHxbGGGO6OHTWGMFBh86qzYZU1VlltlKbGdU6S9k0wixtFXKbndGkVi9V2dFqPqoy\nq9qkp2ZDt6qlKnNXpn6/5XVVFWarA12FH+frVJNfa4vWfC+cd147tjFzPQch1AqzmZbpUq315uam\nQ2eNMcbsD9YsjBEclGZx4sSJAdDJcK3/VSXJZym6hoRmibZqBll6VRVvq6NaXVslxqmkP+VA5/hs\nU+GoyjGu6kBxbqo/yUl2La1JaUF1LVRNr7yedEpzXvnadT2VM15pD7VPblP919fXrVkYY4zZH/yw\nMMYY08V5FsasEDR7qDLbqty0cm6qGPtqxsgO7pp1rTYiytQaTNnMwmvmvA/2P3Xq1Mx91A18gNnM\n8LwWnKNy1HNN8nw4DzqNW7Wxcr+anwHoAIBqGszrpbLxa85INtdVp7TKale1sFRehjIHOs/CGGPM\nvmPNwpgVQjkmW5nGPJYlSOWUJqoWU3Uoq604s0RLyZ+VZdWmOyrsk22qamum1k9S2gPvLWswvE4e\nMzvHAe2MV857JZkTtSFSHROYDcPN11caoao2S3jfykHPtc5zVUEO6tydYM3CGGNMF2sWxqwQlBKV\ndKmSwWqyXUbZvNVYlDhroh+wJRXnrVCpUSjpuNZ1UveW4VgbGxuXj9V9GvJ98H5VwpvSjOpWq3le\nVUvJ/bkW58+fn+mvNDzlz6h98rl8VaGtao+PVq0npXmq/TjUe7ITrFkYY4zp4oeFMcaYLjZDGbNC\nKNNONTcAs9nQyjyhwjhVZnjd9CibLmooaR5XzVVt4FNNLsoJnKlzVE5gtdmQMivR/EQzVx5LlWmv\nW9hm57RyiNdggtzGMbL5je1qU6nq/FZ1oJS5sW7tmsdS4dbLYs3CGGNMF9eGukJQEopZnoOqDXXs\n2LGZqrN8P1UdJDqIlZNWJdexf5aqeSw7c0mt65SPkSyxco5qPkrSVuGiVSpWDmu1Pana0pXzYYht\ndvJyjLyuXIsa7gtsaUvKaawS8Oqcga21W2RTKVVlWGlD6jpKy+Lvly5dcm0oY4wx+4MfFsYYY7rY\nDHXIWbRstdkZB12iPEPzRDb31HyDjPpM0PxBM0Y2pTCHgiYYZUpRuQ4k9+cclbmkbvwzb3yV1U1q\n2W/l8M3z4bVUZnyrnLoqd64czvW9UeY6tdacR2ufbWXuyhnrNBuqbHaOn/NjUl0wm6GMMcbsDw6d\nPaQsohG2pE+zmrS2EFVbZFL6zFoHw11VJjY/C+yTr6lCMGuVV3XtLNG2UFK02jJWbWlKavZ4vg+V\nIc4xuHVszhRX0j3H5byyg5j3udMaS/k+atXZLPlz/uyf25Qz/sSJE9vmqGpD5fdNrc9OsGZhjDGm\ni30Wh4i9eK+sZSzGQfks1tbWBmC7fZwSY5aiVZ0i0qoXRVTCXmvrThVCWrcUzf1V2C6P5XujpKwS\nxlR4bNVqVIJcnmvt1wrDzXNtJeXl8WstrFaF2XystT8FUftTqL00VHXbWu8rX/OBBx6wz8IYY8z+\n4IeFMcaYLnZwHwIOo6nQLAdNItkRS8dqdlBWc08vhLbWNcomi+xIBbacyICuYVQzstX8s/mj1rFS\nmcmqppIq413JoaqcdzYr1W1qsylP9a8bQWWTk5pP/d/MbcpMR3jf+fzqxM5rqMqQ81qLmCRzv2Wx\nZmGMMaaLNYsVZb+0CYfTrjaUqpVUqRzPraqtWdKsG+qoWmIMxVRbtKqtRHle1ngoiStnvKpSSxbd\nLrQ6vZUTON93rSKb56ok/+osVlVb1WZGqsYVf1chzFwnrjmwlWSnrt3SHLkGeQ35fuVre/MjY4wx\n+44fFsYYY7rYDPUoxSXNVxNlzqhZv/mYqlek6iDRBMFMZpWLoMwUtW5Unlt1ROcxlENc/a3yLNTG\nPaTuYa3GzeaY1r7W9X4yyhmvckjqtVX2uzLT0XSUAxlameucYza7cdyWQzx/BuzgNsYYs+9Ys1gx\nHCb76EaFZyoHJkNG1SY3qkotoWSbx68bHCkpX4V4EuXwVe0qo7nOOV9fOcRboaoqhLRuKKS2h1U1\nq2p123xMSfBKO1FVZ6vDnZoeMKuJKA1JVfNVGl7NFM/9lsWahTHGmC7WLFaMll3UXPm0aiupkFDS\n2pMhj6vqCLXOo5SbpVIeo6StkuDy+DWMU0nAyr/CZMScNMhrsi3P9f7779/WBgDr6+vbrpnv4/Tp\n09vml69FDURVjFWhvPTpqP9btccF10kl/amaWCoculbIzVpTS4NcFmsWxhhjuvhhYYwxpotLlK8o\nO31fWk7AnZxrRg56W1XlZKZJpbbnPoDO0qdZ4uTJkwC2m0bYVjdIAnSoaR0/O3w5Rs7gJsqMQ5NW\ndhDXelHKFFRLidd5E66ZWhNV9vsxj3nMtjkqs4+aa3Wk5zY1Rg0qyOOqdeV9Z5NcXSeVsa8yuL2t\nqjHGmH3DmsUhQL1HLW1gkffU2kSbg9Isjh49OgA6Ea9XK4jUyqm5v5LIax+VRJb7U6JVznglkdca\nTD2HeE2uU2vB/nksajOssZTn2Nr8SGkbdftWYKuOU6vuVUZthcq1U8EBbFOhs0SNpSr3Ks2Fc9zc\n3LRmYYwxZn+wZmGM4KB9FipsMkvANWFN2fWVFKqk6ZoImM9TZSZqaGu+tkrsa7XVfTbyMaVR1ES6\nfG1qFErKb2kRqpov55PDcLPGVcdQe3zwPrOPRr2X9d5Um3qfW3uaqDXk+ttnYYwxZt/ww8IYY0wX\nZ3Abs+K0TE01FDOTj7GfMvvQXELHrQrZVFVwaV7Jpg6aavIx9lOhpGzLJp7q2M0Vb2lKUaYXVZ9J\nmdGIqlVVa1vlUFVVt4tOb5VhnZ3LhGYttXYMJlhkk6V8LueonOXZBJY3WloGaxbGGGO62MFtjOCg\nHNzHjh2b+WzzfzQ7Wyl9L7L3A6D3Z6jjKyet2nq0aghqK9EsMbNf1TDmzb86r7OmUesnsR7UvPlU\nTSH/vex3Xz6vJa2rEGZCrWOR7VJzv6wp1M+F0oLUZ2B9fd0ObmOMMfuDHxbGGGO62MFtzApB56aq\nz6TMJq3tRVvbbLa21c2O8VZJczWHWoIb2HJeKycw59EyD+X7qE7g1pazwGzOhcrByP1rPSqVg3H9\n9ddfPsb7raXNc39VG4propz9qq6TysHgxkm8dn6PlPNebZy0E6xZGGOM6WLNwpgVQkm+reqrqt6S\nclRXB7fK9qU2kENVVcgpJVil8fBYvjYlX1XlVdWqqtnHqsaVGksdW6QGmqpcW0ONAeC6664DoDdv\nUpoX551DaKsmmJ3x9Zr5vvm7qiTM9yjfR32Pcr9lsWZhjDGmizULY1YISp7Zvsw9KHK101b9JCVV\n1n0psoTKMWhvb20Dmq+txlI+iFY9JCUVK+2H1O1OW9VzM2pMtc1rDXdV+1Oo2llqS1qifEAka3F8\nzzlG/gxwDBWmrLaA5bn53nJo7TJYszDGGNPFDwtjjDFdnMFtjOCgNz/KzkiVmVxNO/n/mKYONYYK\nqayhmtn00qo7VK9X7mPm2sqZzTblqFYmKv6uyrBzrso0VcfuXZttdGoDOtCAIbAqI1tttcp1qaHM\ned5q0yT+njPGa52vXF9Lfa9z/I2NDWdwG2OM2R/s4DZmhVBbm66vr29rA2bDSrNTWjk8eS4dn0oK\n5VjZka42OOIYTArLzvMqaWdUwqEKPa0SvAptVduS1uq5+Z6UpK2ke4536tSpmTZK92rb2RZqcyiO\nkdeO1+ZaqPc7j1XvV2lge4k1C2OMMV38sDDGGNPFDm5jBAfl4F5bWxsAXa8oUx3VKtdBbVhE84cy\nn6iS4DR1ZHNPNRNlh3cri1qZcdS91XpUKnO9dX7uz3aaapSDW5mCcv5Dbctj1HLteZ14H6pWFa+p\nNqhSG1q1MtbVRkc8lt9nvpfnzp2zg9sYY8z+YM3CGMFBaRbHjx8fgO1SIiXCfKzWIlL1lrKGQEe1\nCrOsNYaUMzv3V85lojbpqRnVSjJvaSdKq6khtHmueZOoqj2cPXtWzoMwVFZpWZyjcogvuglV1crU\nGqpNopRmUUOL1edDBQ6cPXvWmoUxxpj9waGzxqwQqtZTa18EJclTC1B1hFTIJqHGkOegxmppNapS\nbA2BVcmFSoJXlV+rnV75RpQPgm2sswVshSQzTDbPg+Or+1Dz4TWzBqYk/lYV3LpXR4Zj5DbOn8dU\nVdnz589f/l19tnaCNQtjjDFd/LAwxhjTxWYoY1YImoJUSGU232RzB6C351QmD1XHqZq51GZGizq4\nVThqNYHlcE5eO5tIeK4an3B98ljq2tWkk9eEjvDch/NQZi61rtWcp8KVldmqbl6V56Gy8rl2eSzW\nieIaqJpbrTXcKdYsjDHGdHHorDGCgwqdPX369ABs1wBUKGVLYlQbI1H6ZLKZqiOkksiIkqZV0pza\nZIjj8jpZkleJfdVRnaXvVqiqCjHmNVtbj7bqainpXjno1aZVagOlet8q8Y7ka1OLyJ+L6hDPnwm2\nqfFdddYYY8y+4YeFMcaYLjZDGSM46NpQmVZ2cCsHQ+VGqDpCtb8yd6n+/O5QeQGqVpUyjais61o6\nXDm/lYNYbQRV27IZSt1vzbNQuQsqt0OZ31T9rTqeGkvte672Ma/5K6oulWJzc9NmKGOMMfuDQ2eN\nWSEoMSsHtpKKKYWqzW6yFMt2Sruq+mrdXhXQWg3b6SxXlU3VdqHKId7aBEhVz63SdM+53spaVluz\n1nVXmotCZXer7XAXoZVdrzZ7atWZytdWlXR3gjULY4wxXeyzMEZwUD6LkydPDsB2KZF7MaiQ05Yt\nW/kZOK4KVVW271qRNqNCdCnJZm2DEi01i4sXL87eeKJqIkqr6fkSCO+Na5fDUdXaVY0tw/chX7tq\ngnnNW1u5qnBXnqvOa4VKUzOoot2gAAAH4UlEQVTKnw+uU06mZL/19XX7LIwxxuwPflgYY4zpYge3\nMStE3VoT2DLjqJLjKhOYZPNQDTXNppdqHsrXqduG5nZVS4rk/hxDZYgrxzBpbRlLk4vKdM+ml2ry\nUmGyefyWE1tds66nCl/NYcH1Otksxvd5Y2MDALC2tjZzbyqQgfPKbarc/G5dDtYsjDHGdLFmYcwK\noWol1UQ0YNZRrWoSKVQbpWOV6KYk/xpyqrbuzFJxDddVznVVn6mlueT+9dqqXlRLqlZtKny1VqTN\n81eVcqnhtEKRWfMJ2NqoSFXnVWG4NWw6X1uFQat72gnWLIwxxnTxw8IYY0wXm6GMWUF6TtfqCM8m\nCFUjqdZBym21NlQ2V9DUkedAp7HK4FYb91Tne3Y600yU+/Bezp07B2C7qaaafbJZTTn5qykrm2Vo\n7slz5fpUJ36+Zh6D8+erytZWpeJrvSw1hjJf5VyKWiqeeSCZPIb34DbGGLPvOIPbGMFBZXAfP358\nAHSGspJClcSsHMOUKinV121Z8/gqE7hVqyprHSqTmVRJeF7/VgZ3nXc+jw535QRu1bhSVXNVnal6\nnronta2qGrdu4pTb+JpDbqk1qCAHpREqDZLnnjt3zhncxhhj9gf7LIxZQdQ+B2orVEqXaj8L5Weg\nxJn9BpSK1dagiuobyFoQJWBVk0hVnSUqXJfntfZpyGOp8WtFXYVa19bWsnk+lP5ViLF633huK9yV\nY6yvr8/cW14n+lqUFtqqM7Us1iyMMcZ08cPCGGNMF5uhjFkhVBgkj508efLyMWb70gSRwz9pGsnm\nj2qiaZmJsqOU4bEqY1pt/9ly9LaczNlcUh3OLSew2sZUmWXUNqPKhFdNTWpzJRXSqsqEc83yWtdM\n71ZJc3Vevl9mybfMUaqW1LJYszDGGNPFobPGCA4qdPbYsWMDoMNXsxRaHbAq/FMly3GMLIXW7Ut7\nG/KosEyiNhSi1K0kX1XDqDqBlSSvtALlEK+aiKrmm9d1kTpWLbLWVO8DmK3n1Kr1pDSqTHWg5+sw\nkTFrhJyPNz8yxhizb1izMEZwUJrFqVOnBmC7tKvs55QqKY1mCbImteVzq409t6kKrcrPUEM8VQmN\nPP+qbWTJX+27UO9baToqgZDjK41KaQVqLeo8cptKTKyVaPNa1BIr+VidQ0b5fdiPe10AW74p9s/3\nzXOVb+rChQvWLIwxxuwPflgYY4zp4tBZY1YImg+yCUnVWWrVLlKZzC1zcysLuVXDSGVH1+1e87mq\n0qqq/Fr7q2xztVGTCsOtJjNlElLO9Wrmm3dM1Wyq5PutwQEq9FfV0GJbrhdVPxcqCEGZM5fFmoUx\nxpgudnAbIzgoB/eJEydmPttK22iFl6oQWP6ukseqo1ppNUqzYFveb4JwL4rc3goJzd9DNdRW1Y2q\n28rOG78VXqq0E6Kc63V++fqU/HN/tqm6XSqZksf43uQ2NVeuf9Y26nVUMuLm5qYd3MYYY/YHPyyM\nMcZ0sRnKGMFBmaFOnz49AG3TS25XNYZqtnburzK9q3NWmUayaadeO5+vnMw1b0A5oLN5qH4nKcet\nMh0p81vNdejNtTqeW+XRczuvozaCajmW85jVsZ9NbNWMCMyWis9rSLOYqi9lM5Qxxph9w6GzxqwQ\nDMXMUqNyStc6QqoeUpaiq/agnMzKwc1+KgSTUisr4OY5Zmm91pdS4asq1FZlj9fzVI2r7PCtGznl\nsfi7CnvlvWVJno56pWGosZR2UrdyzdpA1eLyddQ6tTaCqkEI8+5zJ1izMMYY08WahTErhNo/gpJ+\n1gZqmGiWsClNZkmy2s1Voh9RdaAyvFbec6POq4VKUstUDUSNSclfhcLme6vnqrDgLJFzXNZgUrWV\nVM0mNR8171rbStW94ph57Xlt5QtapBYYoN+vnWDNwhhjTBc/LIwxxnSxGcqYFYImjpwVTZOCKiuu\nHNCkFbKpTE00Y+RaTKrUtcoor6iy33xV25hm8wrNKjTFKXOOOo+0SpTndaKTWTmSa02pPIYyndW6\nTsD2NSOt96SWO1eZ4q1tWPP7pkxZaj47wZqFMcaYLtYsjFkhKE1mCVJVka01j7K0yzGyhM1w0uqQ\nzf1VyC3J0j0lWGo/KhxVhbS2ajcpTUeNVSXm3jamdeOo3J/3ndeirqdKFlQJiiSH7SqtR4U6V+rW\nrnk+LSd2b4vZ3WLNwhhjTBc/LIwxxnSxGcqYFYKmnZxnoUwKNFPRSZtNIyrfoG48pDKNq2krozKB\nVdZyawMl3oeqz5Sv2bo3/s71Uaa5HBzAceuYeQwVOFDvNaNqNvHaeXxmtucxa4Z7Hp/mPZrF8n1z\n/nmutZ8yWymT3LJYszDGGNPlUFadNcYY88hizcIYY0wXPyyMMcZ08cPCGGNMFz8sjDHGdPHDwhhj\nTBc/LIwxxnTxw8IYY0wXPyyMMcZ08cPCGGNMFz8sjDHGdPHDwhhjTBc/LIwxxnTxw8IYY0wXPyyM\nMcZ08cPCGGNMFz8sjDHGdPHDwhhjTBc/LIwxxnTxw8IYY0wXPyyMMcZ08cPCGGNMFz8sjDHGdPHD\nwhhjTBc/LIwxxnTxw8IYY0wXPyyMMcZ08cPCGGNMFz8sjDHGdPHDwhhjTBc/LIwxxnTxw8IYY0wX\nPyyMMcZ08cPCGGNMFz8sjDHGdPHDwhhjTBc/LIwxxnTxw8IYY0wXPyyMMcZ0+UeDTZyBe1y4awAA\nAABJRU5ErkJggg==\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "WwvqXamE4hqb",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 276
+ },
+ "outputId": "364feed3-eb07-427f-964f-88c5b07bc240"
+ },
+ "source": [
+ "rxs = []\n",
+ "for i in range(100):\n",
+ " (cfo1,ny1,nz1), _= intervened_model(cond_noise)\n",
+ " rxs.append(cfo1)\n",
+ "compare_to_density(ox, torch.cat(rxs))\n",
+ "_ =plt.suptitle(\"SCM intervened on shape\", fontsize=18, fontstyle='italic')"
+ ],
+ "execution_count": 435,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEDCAYAAADEAyg+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnXm4JUV99z+/GWBmWAYIssm+iYIC\nA0aNoPCoIRCXhKjv6/YK7vK6gK/6alwQlIhRkKhRURExURQXUONCBIQgisZBQGWTVQZZhp0ZGGaA\nqfxR1ffW7anTVd2nz9xzh+/nec5z7q36dVV1n+6u+i1VZc45hBBCiCZmTXcDhBBCjD/qLIQQQmRR\nZyGEECKLOgshhBBZ1FkIIYTIos5CCCFEFnUWohVmdoKZOTP7i+luy2MFM7vSzC6a7nZ0xcwWmtkf\nprsdYjjUWUwjZraxmX3EzC41s6VmtszMbjKzs8zsjQOOOdDMvmZmNwT5pWZ2iZkda2abR3LzzWxl\neLH/94CyNjSzxUHmPjOzgmYvAG5wzt09xDkfbWb7dzn+sYaZzQN2AS6Z7rZ0wczWAp4MXDbdbRHD\nsdZ0N+CxipntDpwNbAScBnwZcPgXw98BfwN8MZJfF/h34B+A64CvAzcCGwDPAd4HHAg8LRyyN2DA\nMmA3MzO36gzMo4H54e9LEvkpDmpxmin2BT4EaKRZxu7AbODS6W5IR3YD5qDOYsajzmIaCCP4bwHr\nAU9zzv2hlv8OINYS1gJ+BBwAfAQ42jm3MjrkRDN7KnBwlLZ3+D4TeAWwI76Tqcp8IvAW4AzgfwMX\nl7TdObeiRK6BZ4TvXw1ZziqY2RxgRWGnN1PYM3zPSM0C2Ct8/25aWyGGRmao6WEP/Ijrp/WOAsA5\nt9I5d2uU9H58R/FJ59xRtY6iOmahc+4jUdI+4fuU8P2U2iEnAjcD54X/f5trtJkdHExWL47SXhDS\nDjGzNweT2LJgJntjJLeRmblwLgCLwnHOzHaulXeWmd1tZkvM7Ndm9vxEW35tZpeb2e5mdoaZ3QPc\nAxwVyvzLxDEbm9ldZnZWLX3fUMZiM3vQzC4zs0MTx38ntGsLM/u8mf3ZzB4wswvMbLeE/Hpm9gEz\n+124JreY2WlmtnVCdn0z+6iZ3RhkLwrnsAfwKPD75I8ytYxZZvYmM/vvcO3uN7P/CAODWG6DYKI8\n0cyea2bnBNnFZvavZrZ2rq5Qzh7hfK43s4fM7HYz+3l8f+DNlgDXh/O7Plzji81sv0SZrzSz75vZ\nn0KZfw51bF+T2z78zsea2QfN+3UeCvfdh8xsdqLsot9ZDMA5p89q/uBNMQ7/gl43I7sJsBS4DVin\nRR1XAn8C1gFWAB+I8l4Q6v8H4NPh7ycWlPm+ILtDlPbBkPZLvLbwbuDtwPXASmCvILcx8CrgXvwo\n+VXh80rAgszHQlk/AY4I5fwmpB0Y1TkbeBC4JpR3EvDGIP93Qf41ifafADwC7B6l/V/8y/iXwLvC\n/2eHMt5YO/46vOnvJnwn/GbgE6HM3yZ+tz8AS4BPhfYdi+/Qrol/d2D9cC8sC7JvCOXfje8k/lDw\n26yD1z4fBU4N9R0F3AXcCTw+kn1W9JstAo4B3gT8LKS/vaC+/YHlePPYe4HXhfvjHODwSO68cF6X\nAd/Ga7NH4+/p24G1I9lZwJ+BfwP+f7gOJwEPh2tpkewhoa23AtcC7wHeFu4tBxxRa2/x76zPgN98\nuhvwWPwA88KLx4WXxzfCg7FNQvaIIHdsi/LXDw/GmeH/S4HTw99rA1cD54b/LwwvtFkF5X4buKeW\ndkZo35fiMoDnhfQ3RGkbh7RPJMp+dch7c+Ja3YjXwqq03YPsCmC/mvwOIe+ERPpy4AtR2gHhOh1X\nkzXgIuCPUdqG+M7vYeA5Nfmvhbw5Udq5eM1tp5rswaF9r6gdvxx4ek228mN9reC3+Vw4l0Nq6QeF\nMj4cpb09pF0CbBClrxuu6dcL6vsRcAOZAUy4vx3wf2rpx4b0XaK02cDcRBkfDrI7R2nHhLSLgPWi\n9I0IA5Iuv7M+gz8yQ00DzrllwF8BH8c/TC/DO7P/ZGY/NrOtIvHnhe/vtqhiL/worTItXcqkGeoI\nYCfgSDObhbeJX+oSpq0EC1jV0boAP3p9a62M5eF7WZRW+VGmmLzMbB3gn/Ej2++Y2eOqD96vczne\n8R/XCfBx59yFtfbcCNyPj8CJOQ7/IjwqSjsBPyo9sVbnJviR8E7hGoG/pobvbH5WK3s5Xrt4OJzP\nC/FBB8cA99XKvjocs0uQ3RPvUzrZOffrWrlnh+9Gf4WZbYfXDL7pnDuzln1++N41Squu3xucc0tq\n5wFTf7NBbIzvQFcxv0Xt2h7/8v6+c+7fa9mr1OWce9Q591A4dq6ZbRKu2V1BZE50/F74DvpQ59wD\nURn34v1v20eybX5nMYjp7q30mRj1Hg5cgR8t/TjK+31IazRX1cqrtJHnh/+PxL/ItgXuAz4b0p8Y\n5D5VUOZ8/MP5yShto3D8yQn5w0PeU6O0d4e0XWuyfxPSmz7xSPGEkPaEAW39BXBz9P/TQtvfH6Xt\nWlDnPZH8kSFtv0R9vyYyFeE1xVzZRwTZj4b/d0qU+6qQ95zUeUZy7whyz0rkzQt5p0RplwLXJmSf\nFGTfVXA/vAhvSnLAQnyEW/13rUxFL0scfzq+U49NS7sAX8GblurX6xFgXiR7E3DegLadByzq8jvr\nM/ijaKgxwDl3A/B5M/su3o777Ci7ctQVOR0D9RH8pfjIt9PxD90Ha3IlkVAL8CPreJRbRbqkIpv2\nwXdQsWN2b/wL5o812aqcVwKLB9R/Z62cPzvn6uVU/A54pplt5PxI8wS8SeiTiTrfyeBInaU1+Ufw\nL8YJzEeq7YGPbotlL8PbxgdR1bk3cLdz7rqEzNPDdy5sttIar0jkVRrWZaG96+C1gW8lZKugiOz9\n4Jz7QdAcDsF39u8BPmRmRzrnPh3EKg2mrjFVdU2Ea5vZ0/CmuzuBz+J9bkvwnfxJwHLnNXLMTwjd\nBn8/TyGc3wLggpDU9ncWA1BnMV6swD8c90Vpf8SP+PZk8gHIsQ9wm5uMqKpeNs8A3uYmJ9RVL4ds\nJBSTD/5vE2mpl8s+wOXOueU1+UurF0TEhuH7N865awrashfN16J6KTzZzDYF9sPbzGPzSlXn751z\n5xTUuQC4wgUzScTuwFymXpcN8RMXS8pdD98JTyG8EA8FbnL5CZDV9UyFNb825H8vau/aDP7NoDBM\n1zl3J95X9aUQ4fVbvDms6iz2wo/ab4iPM7ON8KbQ/4iSj63aEJ+vme2CD/v+eiRbdQCPJJr1Ovz1\n/0b4v+3vLAYgO91qxsz2M7P5A7KPwv8m8YPxb+H7OPMT8+rlzTWzd5vZBuH/dfHmpYmXVxhdvwcf\ntvr56PC98TbjKwuaviDIXlVLW05tgp35+Q67s2ontCM+QqvOteH7FamKberM9B3w5q+m0W/VWSzA\nR1hdzNRr2rbOOfjReNMLNj7Xa4G9zWzXurCZzQkvy4rrgY0tCvU1MwOOx0+4LJmMd3n43r9W1wvw\nUVGnOOeq657r4K8P98tAzGyzRPKD+MHnoihtAemOJ+W72g64vdZRrI93/ltNtuos6ue7Nf4ZuoRJ\nraP4dxbNSLNY/XwYeKqZfR8fFroU2Ap4MV57+E+8/RcA59wZZnYSPkzzSjP7Gj4KZV38Q/Mi4AHn\n3CfCIXviTVdTXtTOuY8n2rIXcJlz7tGCdi/Aj84eTaQ9XJN9Cn70Wu8srgcOMrMj8E7L651zv8Sb\nRD6AN2PshtcaDO9j+Wu8eaIy6TS97CoqP88HgM2A/RPazPl489lhZrYlPlx3eajzWfhQ5ZcG2Sfj\nn5VUnXuHuuKX+seBHwC/NLMv4iPfNsHb5A/Bv+SqF/KX8JFgPzCzf8HfDy8DnhDyS0b5X8Zfn6+a\n2aeAW/Dh2a8G/gt4ayS7gMlIqAlCB7UX/v4bSDDz3GRmFxB8Q8DjgcPw4btHBblNgK2BbyaKSXUW\nFwCvN7PT8eG32+C1otsTslVQxdbhOfoR/hk6HB/19BI3GWxxPuW/s2hiup0mj7UP8EL8w305/oZ/\nGP9A/AR4OZHDr3bci/EP8p14c8Mt+Ifgn4AFkdxb8C+DQzLt2CnIfbagzXNCO0+K0uaGtC8k5N8U\nyv6rWvoz8PbrB0L+MVHepviJglfhNZi78RrCp4kcp/gZ7A7YMtPmG4LcmQ0y6zO59MhS/Av8SvwL\n/C8judenzifkXQRcnUjfF/gxft7A8vD9M/z8gdk12ZeF+2E5fmR+EpPmo78vvK+eBPwQb+dfin+5\nHpGo68IB7a0cwf+YqWc+XutZGO7fh/DzRk4mctIzGTr98kQZ3wj3wOxauV8J9/eScK0OxPsvVgLz\nI9nf4zuUPcP5LMP7uk4Btur6O+vT/KkmQwkhxNgTTIJLgc845/7fdLfnsYR8FkKImcRT8CbB7PIn\nol/UWQghZhJamHCaUGchhJhJVDO3U3NKxAiRz0IIIUQWaRZCCCGyqLMQQgiRRZ2FEEKILOoshBBC\nZFFnIYQQIos6CyGEEFnUWQghhMiizkIIIUQWdRZCCCGyqLMQQgiRRZ2FEEKILOoshBBCZFFnIYQQ\nIos6CyGEEFnUWQghhMiizkIIIUQWdRZCCCGyqLMQQgiRRZ2FEEKILOoshBBCZFFnIYQQIos6izHH\nzM43M2dmB/RU3o2hvO37KG+mtkGI1YWZnRru98Omuy3DoM5CCDFWmNlh4eV66nS3JYeZbR/aeuN0\nt2XUrDXdDRBZXg2sC9zUU3nPBdYG/txTeUKIZv4R+Bhw63Q3ZBjUWYw5zrm+OomqvOv6LE8I0Yxz\n7lZmeEcBMkP1jpntYGZfDHb55WZ2l5n9p5m9ICE74Y8ws+eZ2U/N7O6QtlddJnH8lmZ2spndamYP\nmdmVZvYeM5s9yC9Qkm5mf2tmPzezJWZ2v5mdZWZ7DzjfFweb7BVmdp+ZLTOzq8zseDN7XOcLuQYT\nrrMzz+FmdrGZLTWze2ty65vZ+8zst+G3eNDMLjWzd5nZOg3lP9/MfmBmt5nZCjO7xczOM7O3J2TX\nMbMjzWxhVMfvzOyDZrZ+Qn7CRGRmG5rZp8xsUbjXrzOzD5nZKoNQM5tnZm8zs9+Y2R3hfr3FzC4w\ns/dFcucDXwn/Hhpdqylmqdr9+r/M7MJw/zkz2yi+zgOuUaP5yMw2M7OPhmuxNFybq8zsJDN7cpA5\nGrghHLJdra03RmUN9FmM2/VvxDmnT08f4JnAfYAD/gh8AzgPeCSkHVeTPz+knwSsBC4BTgN+DuxR\nkzmgduzWeNOUw5uUTgd+AiwDvgPcGPK2rx2XSz8OeDS04VvAtSF9KfCExDk/AtwP/CrI/xi4PRxz\nA7Bp4phkGx4rn3DuDvgc8DDws3Cv/CKS2Qa4KsjdCvwI+CFwZ0g7D1inVq4BXwr5jwIXhfvpHOA2\n/7hPkZ8HXBDk7we+H+6dqo7fAY+rHXNYyPsecEUo99vA2cBDIe+LtWNmhfY64J5wHqeFtNuBhyLZ\n9wIXBtlrgVOjz+sT99DnwvcvQ5kLgQ3j6zzgN9g+5N+YyNsnuodvD+f6beDicF2PDnJ/H65X9XzE\nbT0+Ku/UIHPYuF//xvt2uh+cNeUDzAUWhR/knwCL8p4JLAl5B0fp5zP54jhsQLmVzAG19B+E9O8A\nc6P0JwC3ROVuXzvuxkz6MmD/KH1t4MyQd0qifS8F5iWuxcnhmJMSxyTb8Fj5RL/N3cCCRL7hO18H\nHA/MifI2As4KeR+uHffOkH5TvVxgNvDCWtrxQf5SYLMofT6+A3PA6bVjqpeVA86o3XtPxw8eVsa/\nLbB/kF8IrJdo13MG1HFqwzWs7qEVwIFN13lA3vYkOgtgA/zgywEnsGqHvA2wT66c2jGnku4sxvL6\nDzyP6X5w1pQP3hHt8KPBWYn8o0P+OVHa+SHtrIZyK5kDajfoSvxIYovEMYdHN9T2tbwbM+kfS5T3\n1JB3Q4vrMQ8/ar4jkZdsw2PlE/027x2Q/7ch/3yiQUeUvyWwHD8CtZC2NpMj0mcX/j5Lg/y+ifyd\nw4vnUWDbKL16Wd1PWmv8Ycg/NEp7aUj7l8LrU9VxaoNMdQ+tMhipX+cBeduT7izeEdLPLWxrspya\nzKnUOotxvv6DPvJZ9Mezw/fXnHMrE/mnhO99zWx2Le97Let6Fn70eYFz7rZE/mkty4v5SSLt6vD9\n+NQBZvakYHf9jJmdEmzLn8eP+h5nZhsP0Z41mUG/+8Hh+zsuPO0xzjtMrwE2AXYJyU8N/1/rnLug\noO59gPWA65xzv0jUcS3eRDILf7/Vudg5d0ciPXWvXIJ/6b3WzN5sZpsVtK+Uts9OjoPC9ymNUsMz\n466/Oov+2Cp83zAg/2b8y3Mu/qGO+VPHupLHOefuw/tOurAoUd6S8OcUp6qZrWVmX8bbTk8E3gq8\nBjg0fNYNovM7tmVNZ9DvvmP4/kzNaTrxAXYPMpuG723D99WUkbtfAa6vycascp8EqntlTpUQXnxH\n4LWfzwO3m9k1YWDxfDOzwjanaPvs5Gh7Hbsy466/QmfHg2Udj1tl1BmR0m5KaHPckcBr8Tbed+Ad\nqoudcysAzOwWvMlkmJfBGotzbtDvXmmeP2PwS6Hirqq4XhpVTqv7yzn3WTP7LvAC/FyfZ+EHFq8B\nzjWzg5xzj3RoR9dnZ9BAeXVfx66s9uuvzqI/qkluOw7I3xo/Mn8I79gchlvC97apTDObD6wO089L\nwvebnXM/rLVhPWCL1dCGNZGqgzjNOfflwmOq+ThPKJTP3a9xXi8TOIPJ9OTwwcyejo8Cey7wOuAL\nfdQT8TCwtpmt75xbWsvbZsAxNwFPwl/HhT23J2bGXX+ZofqjshO/0sxS1/U14fsXHUdQMReG7/3N\nbPNE/suHLL+UvwjfqdHvy5BG0ZWzwvdLGqWmcjFey9jFzPYrlH8A2NHM9q1nmtlO+NHnSnwYde84\n535NeHEBe0RZK8L3sIPZalC1ayLvwAHH/DR8v7awjq5tHefrn0SdRX98Gz8C2BU4JrYDhh78neHf\nTw5bkXPuevx8hrl4u/aEfdLMdgaOGraOQq4K34fXzncv/HwN0Y0z8U7Jg8zsxKApTiFMKntV9b9z\n7mH8khIAXzezPWrys83shZH8MiZHkv9qZptGshuEvLXwTvahVhEws+eY2cH1yWLmJxb+dfg3rqMa\nST9pmHrx8wgA3h/XbWYH4s2mKU7Gz2t5rpn9s9UmP5rZNma2T5R0B77D2LxNIMeYX/+Bjdanpw+w\nL5OT8q7CRyWdS35S3gENZSZl8Caom0PezfhJeT/C23DPwDv+HPD42nE30iKkNspfJQwRP39kRXS+\n38Tb2R8J596prjX9k7qWCZltgcuD7L3AfwFfx0/c+mNI/1XtGGMyTPNR4Bfhdzib/KS8+5icfHZH\nSPs9gyeFnTqg3UeH/KOjtCOZnBB2TnQei0P61cBGkfwc/Au7mhvwVfxL/DVt7iH8wK0KT70mnNtv\n8KP14xg8Ke9p0TW4LTxPq0zKi+TPqMoK53YyUQg6ZZPyxub6D7ye0/3grGkfvJ3xS/iX9Qq8f+Kn\nwIsSsufTsbMIeY8Hvhxu6IfCj/7+8LAtDzf23NoxyYcs9/Ax4AUH7I0Pt12MV6svCzfnrK51remf\nQdcyITcPH8VyYXjQV+BH3RcBHyHM8k8cdwjelHVndMy5wFsSsuvgR9kX41+sy4A/4LXT9RPyXV5W\nOwPH4Ef6i8K9uhj/4n4XMD9Rzp74OQN3hft4Sp2l91C4P8/Cz0t4AD/T+4Vk5kfgAzOOxw+CloXj\nrwQ+C+xWk90E30EswvtJppTLgM5inK9/6lNN6BFrEMEGeiFwuXPuydPdHiHEzEc+ixlKmOOwIJG+\nK/DF8O9XV2+rhBBrKtIsZihhRcoleHX8KryavB1+ZuhaeFvo85x3fAohxFCos5ihhMiGj+BjpHfA\nLzD3IN6u+k3gcy5MjhNCiGFRZyGEECLLjJzBbQM2NBGiL5xz0zKhsPTeri/nkxv0VfIpuaa8Nm3I\nlZ+iRL6pXbk2lNRdujRVSr6elqq7tP0lZbVt1wD5Tvf2jOwshFhTKXmpx/mpl0jpS3xQXouXzsD6\nSusuoW1nU5rf9QXf9uXfVGdp3SUdwqitRIqGEkIIkUWdhRBCiCwyQwkx5rQ1T/RZflezTKkpq63J\nbBSUmvfamuRKymh7nUrNVl39UE1IsxBCCJFlRobOKhpKjJpxj4ZKHLdKWtdnu9RJO2w9ufr7LHcY\nJ3m9jKYopdJosBJtoFTraGprqs6u97Y0CyGEEFnUWQghhMgiB7cQawClE7hSNMmnTC8lczyGMZuU\nOO07TERrJd/Urqby+6SPCX59Is1CCCFEFmkWQsxAmhyrXUeVwzh6B8nk6umqnZRqMF2XRSmlJNS2\n7RIoTfWkymiqZ5g660izEEIIkUWahRBjTtuw2LZ+g65llfozhpnQl/o/V0+fa2gNOxrPlZHSsvoI\n8+37WJBmIYQQogB1FkIIIbLIDCXEmNPWfDCKsMk+9q7oWlbbGdPDLPs9rNM4V/ews7Snc80taRZC\nCCGySLMQYoYwjKN30HGl9ZRSMnod1Q5wJe1JldXUntwIveRa97FhUemmWE3lKnRWCCHEyFFnIYQQ\nIovMUEKMEW1nAtdlhqkz5yxuc1yf8qVl9LEuVZ22Zp8+5mWk6mlrkmuaA9MVaRZCCCGySLMQYoxo\nO0ItCS/tUm5TWfUyS48rDV8ddmvTpva01Z5y9dRnXfe57lUuLdWeUSLNQgghRBZpFkKMIW0ndw06\ntk1ek3zb8NU+/RjD7OHQ1sfR57paTXW2nSzYVYvrc48LaRZCCCGyqLMQQgiRRWYoIcacrqGgXU0p\npTOsUzLDLnce53c1m5Q68Uuc68OY09oGGvQxS7uLTCnSLIQQQmSRZiHEDKF0NNp2XaASJ23bDXza\najVtQ4VTdTe1p6QNubQ+nNIlq872eQ3l4BZCCLFakWYhxAxkFEs9lIbjlozuS+m6PEbpyH/YpT/a\nlpWT6boicB+T8rTqrBBCiJGjzkIIIUQWW13rivSJmc28RosZhXOuv+VDW9D13m4bWjnM1qajps36\nSSnz2KxZk2PgRx99NHlcafkxfaxm26aMvtsQhSR3urelWQghhMgiB7cQY86w6yfFZXQ5tn58SYjn\nMJpOXW727NnZ9sXUtYm4/D7WfErR9Xy7ri2VYxSaoDQLIYQQWdRZCCGEyCIzlBBjRB8O5T7MVm3l\n6sRO5orKPBSbldZZZ52BZVR5sVmpas+KFSsAWGutyVdYldZkqhnVOlap/4ddX6qPIIS2s9mbkGYh\nhBAiizQLIcaIUWxaM6jcrjSNZCutIdYe1l57bWBSU0hpFrGG8Mgjj0xJq46P/37wwQenyAKsXLkS\ngGXLlk2kVVrJAw88kD2f+JzaamXDbM06KG0Y7W8UIc/SLIQQQmSRZiHEGNF1RFi62mkftF0ddc6c\nOQDMnTsXmOqn2HDDDYGpPo6ddtoJmNQQYq1j+fLlwKSmsHTp0lXqq2QAFi9eDExqGJVfYxCVXOmW\nq13Xkmoz8TCWKw2B7kP7qSPNQgghRBZ1FkIIIbLIDCXEmFNiEhnVdp6pMts6gSszUmV+2mKLLSby\nKtPUdtttN5E2f/78Kd+x2apyjt9xxx3AVJPT/fffD8CiRYsm0ipTVuX8TjnEc+2v08fs9LZO6T7M\nVjJDCSGEGDnSLNYQ+px8I6aP0hFh1wl3bUecXR3usVO6+nuDDTYApmoRlZax7bbbTqRttdVWAGy5\n5ZartLUKna0c4ldcccVEXqVZbLrpphNplQZSaSe5Z6PtxLh6ftftYUvrbKvVaFKeEEKI1Yo6CyGE\nEFlkhprhNJkoZI6aefQR319SblP5paaaVN2VAzp2SlfzLHbbbTcANttss4m8nXfeGYAdd9xxIq3K\nrxzc66233kTevHnzptRdmaoArr76agDuueeeibQ999wTgIULFwJw9913T+Sl1q+qqOZbpK5T6dpQ\nTbTdz7uPZ1kObiGEECNHmsUMpc3IRBrGzKbP2bh9OFbrdcd5lWYRawDV6L9yZsdaxC677AJMah2w\nashsXH71d2oF2/XXXx+ArbfeeiKt0jKqvLhd1WzuOJy2fh6pFW9T7SkNd22zIvAw26qO4pmXZiGE\nECKLNIsZRNfRpMJq13yGmZDV5/4XFfFovQqZrXwPO+yww0TeJptsAkwd8VehttV37Fuo2pHyN2y+\n+eYA3H777RNplZZSUbpFa32NqJg+/RJtt7ttas8wW+uWIM1CCCFEFnUWQgghssgMNQPocwMTMd6U\nmpDahMKW1lk/vrSMWKYyD1XhsjC5/lP1XTmbYdL8FM/4rkxFVVrKfBI7nisq01dsoqqc2A8//PAq\neamymtaLKqF05n3XLVfbtkObHwkhhFitSLMYU0alTSicdmYwjKZQPy5V7jArp9bl421P66GtMDni\nr7SI1Eg+Lj/WMnLtSWkY1ZarAEuWLJnSxlhzSE28q5ff96S8prWk2gYYtHWOa1KeEEKIkaPOQggh\nRBaZoR6jaO7FeJP7TZqc0iW0NXm0jflPOZKreRbxPtgpuTb7TafMSvfdd99E2kMPPQRMboKUWscq\nbk9b006dYfbNLimr6dhR78MuzUIIIUQWaRZjhsJkRZ0+Nr4poa0jvdIKUhsdpVadrVZ8jTc/qkJa\nK0c0rOqMjh3o9VDhSmOIy6pmhcftqLSH+BpWjvd4VnfKYV5REjCQu3Z9Pt9N4dOjqFuahRBCiCzS\nLMaMUUymETOPtvtZ5Gjas6JpZNp0P6YmsFUj88pXEFNtcXrzzTdPpFXrRqXWhlp33XWBSY0hrrP6\nvvfeeyfyqm1VFy9ePJFW1x5SmkN8Hm3XySrZVrXtxLtSTWF1I81CCCFEFnUWQgghssgMtYbQdj0f\nMd60XV48psnU1HXdqFRZlYM7Nu3UHcoAV1111ZQyY4dyJR87uCuTVJUXm4mqOqutV+O677rrLgDu\nvPPOibRrr70WgAceeGCVuqty4/Dd1IzyOqWmpq6hqqWhtsOG5rZFmoUQQogs0izGlFFPsNFEvJlN\n282MStYkatI6UmkpZ/HSpUufKjwsAAAIX0lEQVSBqaGz1SS5yrG9fPnyibxqVL9o0aKJtGryXkUq\nDPfss89eRbZybC9cuHAi7bbbbgMmHd1xqG2KYTcsKi2rj7W52h6n0FkhhBAjR5rFDKBrmKR47JD6\nzUvvg657K1TysYaR2ra00jYqn0IVEgtwzTXXAFO1h2oSXhV+G5dfpW200UbAZOgtTPo9brnllom0\neihvvN1rybmVhhiXlDUMw2qSfSDNQgghRBZ1FkIIIbLYTAyzNLOZ12gxo3DOTYstb9asWS7UP5E2\nnbP6m+puctKmwmM33HBDYHJ71ZjYNFWFzlZmpdgMVa37VJm0YqqZ3nHYblVGnFZRlRufR9e1oVIy\nJVvkDjPjuy1VXStXrux0b0uzEEIIkUWahRAJpkuzaLq3204G6zN0c5gyKod1FfYaO7OrdaDiSXlV\nfqWdxJPmKu2h0gBSzvUmLSKVNuqRf9tQ2z5CczPHSbMQQggxGtRZCCGEyCIzlBAJptsM1XZtqFyM\nfZvnvHT1gK4msHgzo3j58YrUXI0SUqamellNDmxoP1O6Lp+7zk3zOOplDDP7uum+kINbCCHEyNAM\nbiHGkLarwnY5tmIUm+40hZCmtImYptF/XVPItTW1QVO9jamVXOsyg9JKQmebGOZar67VZiukWQgh\nhMgin4UQCdYEn0XpWkHD7pfRtyYyqL5hyuq6nlMujLXtb9OlDam8LuVGx8lnIYQQYjSosxBCCJFF\nZighEky3GWqE5QOjW3uqNBS0ia5mrq5LdQ9j5mq7WdLqft+mzl+hs0IIIUaGQmeFGCO6TgordYY2\nhdq2LautFtF1El+pfNe6+3SI5/JG4fSuy8RyfWoy0iyEEEJkUWchhBAii8xQQowRfc4HaDJ/9Lnp\nTts5Hk0zpgeVUdLGrmtWrc49tds6xNscN2qkWQghhMgizUKIMSQXbllPy408S+TbjpxLV0cddkOh\n0tDTtqu1th2tj2o12Dbtaqq7bchwW6RZCCGEyCLNQogxpO9Jc8OuvVQ6ah125JzKb3stSv0SfU5G\nbGpHH6P8tmX0WXeFNAshhBBZ1FkIIYTIIjOUEGNIbsb0sKaTrsuSdzm2JC9F03mXmpCGDcMtlW9b\nRh+zu7vOHu+KNAshhBBZpFkIMUMYZhRa4vDsGrLZd2hr05pNJQyj3ZSsqdRVs0qV0bRta6k2l6p7\nFE57aRZCCCGyqLMQQgiRRWYoIcaIUcfk97GEeNflspvqaaKPzY9Kr2vTOlYpmWHPre2x07nxkjQL\nIYQQWaRZCDFGtHVYV7QNtW0KqWy71lOOphVvu4YFlziIu9CnY3jYWfNtjxv1yrTSLIQQQmSRZiHE\nDKGtrbzrnhKl/ok+t4DtcxXc6aDtFrMlvqA+/BOalCeEEGK1os5CCCFEFpmhhJghDLNBUNdZx12X\num4yiZQ6uNs66lN1N8k1yacoqStn9qmfU6mpqcRs1dZU2BZpFkIIIbJIsxBihtA0Uh1m0lxJnX1M\njGsaAbcNCe3TuZ46to+w167nVKoptF2BV5PyhBBCjBx1FkIIIbLIDCXEGFG6ac0oNz9Ktae0DX3M\ndWizz/Ywjve2dXf9HUpNR02z8ptYXbO6pVkIIYTIIs1CiDGiz7DXlNyoVnxtqqdkpduuq6P2qQGU\n1hWXP+ys6xRtZ4OXliUHtxBCiJFj07meSlfMbOY1WswonHP9zWZqwaxZs1yofyJt2P0jUnJt138a\npu6u5XddlyqmxA8wqndgSV19aiItQno73dvSLIQQQmRRZyGEECKLHNxCjDltNgOKKTVxtJ0J3ERJ\nSGhT3aV5JWXn2lXPa1NeCW3NgU2UzN4vNQd2RZqFEEKILNIshBgj2obJtp2Q1VW+7cqvw9B1ZdmS\nMtuGGufaUZdvq7l01f5ydfehJdaRZiGEECKLNAshxpBhQjyHtU23DasddGxfdQ4TFtwmL9eulHxb\nv0FJ+U31lDKKpVikWQghhMiizkIIIUQWmaGEGCNKV50ddvOfUYWLphiFWWlUgQBtzVVtN3ZqYyor\nPe/VZZ6UZiGEECKLNAshxojSEM8+97Poug1r17Wk2m4XOowjvW2oatc9JVI0Hdu13GG2VR0WaRZC\nCCGyqLMQQgiRRWYoIcaIPrYJHbUTu8S0M6wzdVBZJaaXXHBAPa+rSStXflN7Supsms9RKj+oHV2Q\nZiGEECKLNAshxoi+V1Hti7ahpE3HDuMs7xqG23VEXqqlNJXV1Sk9jMYzCqRZCCGEyCLNQogxp8Qv\nkdo/ou3Wo11Xck21dXWNivvYn6LUp1CiKQxTfltK6uxTy5RmIYQQIos6CyGEEFlkhhJihtDW2dp1\nWetSJ23bMlL/d924p/S4rmafrmWNOuBgmHW1tDaUEEKIkSPNQogZQlvHbdsJe20nurUdRXfdenRQ\nO0ry2q6f1LQuVdO168Oh33bzprahvMMizUIIIUQWdRZCCCGyyAwlxBiSMx+UzGROmSzaziOo5+Xo\nY6/oev3DLJ3edsZ3m7pTxw3jXG+75lbXmfFdkWYhhBAii62udUWEEELMXKRZCCGEyKLOQgghRBZ1\nFkIIIbKosxBCCJFFnYUQQogs6iyEEEJkUWchhBAiizoLIYQQWdRZCCGEyKLOQgghRBZ1FkIIIbKo\nsxBCCJFFnYUQQogs6iyEEEJkUWchhBAiizoLIYQQWdRZCCGEyKLOQgghRBZ1FkIIIbKosxBCCJFF\nnYUQQogs6iyEEEJkUWchhBAiizoLIYQQWdRZCCGEyKLOQgghRBZ1FkIIIbKosxBCCJFFnYUQQogs\n6iyEEEJkUWchhBAiizoLIYQQWdRZCCGEyKLOQgghRBZ1FkIIIbKosxBCCJFFnYUQQogs6iyEEEJk\n+R9HuLwlzlJD7gAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MR4jlBojGdr1",
+ "colab_type": "text"
+ },
+ "source": [
+ "## ======== Test code ======="
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "HNzpNC4EzkA6",
+ "colab_type": "code",
+ "outputId": "4c161478-3e67-478b-b2f8-e077615cfb52",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 52
+ }
+ },
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "xs, ys = next(data_iter)\n",
+ "x = xs[0].reshape(1,-1).cuda()\n",
+ "y = ys[0].reshape(1,-1).cuda()\n",
+ "mu, sigma = vae.encoder.forward(x,vae.remap_y(y))\n",
+ "mu = mu.cpu()\n",
+ "sigma = sigma.cpu()\n",
+ "recon_x1, y1, z1 = OldSCM(vae, mu, sigma)\n",
+ "print(y1)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "tensor([ 1.9037e-04, 9.9998e-01, 4.9999e+00, -1.2039e-04, 2.1000e+01,\n",
+ " 4.0000e+00])\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sLOw_NXQziDa",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def OldSCM(vae, mu, sigma):\n",
+ " z_dim = vae.z_dim\n",
+ " Ny, Y, ys = [], [], []\n",
+ " Nx = pyro.sample(\"Nx\", dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)))\n",
+ " Nz = pyro.sample(\"Nz\", dist.Normal(torch.zeros(z_dim), torch.ones(z_dim)))\n",
+ " m = torch.distributions.gumbel.Gumbel(torch.tensor(0.0), torch.tensor(1.0))\n",
+ " for label_id in range(6):\n",
+ " name = vae.label_names[label_id]\n",
+ " length = vae.label_shape[label_id]\n",
+ " new = pyro.sample(\"Ny_%s\"%name, dist.Uniform(torch.zeros(length), torch.ones(length)) )\n",
+ " Ny.append(new)\n",
+ " gumbel_vars = torch.tensor([m.sample() for _ in range(length)])\n",
+ " max_ind = torch.argmax(torch.log(new) + gumbel_vars).item()\n",
+ " Y.append(pyro.sample(\"Y_%s\"%name, dist.Normal(torch.tensor(max_ind * 1.0), 1e-4)))\n",
+ "# Y.append(pyro.sample(\"Y_%s\"%name, dist.Delta(torch.tensor(max_ind*1.0))))\n",
+ " ys.append(torch.nn.functional.one_hot(torch.tensor(max_ind), int(length))) \n",
+ " Y = torch.tensor(Y)\n",
+ " ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()\n",
+ " Z = pyro.sample(\"Z\", dist.Normal(mu + Nz*sigma, 1e-4))\n",
+ "# Z = pyro.sample(\"Z\", dist.Delta(mu + Nz*sigma))\n",
+ " zs = Z.cuda()\n",
+ " p = vae.decoder.forward(zs,ys)\n",
+ " X = pyro.sample(\"X\", dist.Normal((Nx < p.cpu()).type(torch.float), 1e-4))\n",
+ "# X = pyro.sample(\"X\", dist.Delta((Nx < p.cpu()).type(torch.float)))\n",
+ " return X, Y, Z"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_0EX3FetiTi-",
+ "colab_type": "code",
+ "outputId": "3007827d-15f3-4f55-f959-081e2122f249",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 255
+ }
+ },
+ "source": [
+ "from pyro.infer.importance import Importance\n",
+ "from pyro.infer.mcmc import MCMC\n",
+ "from pyro.infer.mcmc.nuts import HMC\n",
+ "from pyro.infer import SVI\n",
+ "\n",
+ "intervened_model = pyro.do(OldSCM, data={\"Y_shape\": torch.tensor(0.)})\n",
+ "conditioned_model = pyro.condition(OldSCM, data={\n",
+ " \"X\": recon_x1,\n",
+ "# \"Z\": z1,\n",
+ " \"Y_color\":y1[0],\n",
+ " \"Y_shape\": y1[1],\n",
+ " \"Y_scale\": y1[2],\n",
+ " \"Y_orientation\": y1[3],\n",
+ " \"Y_posX\": y1[4],\n",
+ " \"Y_posY\": y1[5]\n",
+ " })\n",
+ "\n",
+ "# change to svi\n",
+ "posterior = pyro.infer.Importance(conditioned_model, num_samples = 100)\n",
+ "posterior.run(vae, mu, sigma)\n",
+ "\n",
+ "result = []\n",
+ "for i in range(500):\n",
+ " trace = posterior()\n",
+ " x = trace.nodes['Nx']['value']\n",
+ " ny_shape = trace.nodes['Ny_shape']['value']\n",
+ " ny_scale = trace.nodes['Ny_scale']['value']\n",
+ " ny_orientation = trace.nodes['Ny_orientation']['value']\n",
+ " ny_posX = trace.nodes['Ny_posX']['value']\n",
+ " ny_posY = trace.nodes['Ny_posY']['value']\n",
+ " z = trace.nodes['Nz']['value']\n",
+ " con_obj = pyro.condition(intervened_model, data = {\n",
+ " \"Nx\": x,\n",
+ " \"Ny_shape\": ny_shape, \n",
+ " \"Ny_scale\": ny_scale, \n",
+ " \"Ny_orientation\": ny_orientation, \n",
+ " \"Ny_posX\": ny_posX, \n",
+ " \"Ny_posY\": ny_posY, \n",
+ " \"Nz\": z\n",
+ " })\n",
+ " \n",
+ "recon_x2,y2,z2 = con_obj(vae, mu, sigma)\n",
+ "print(y2)\n",
+ "recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "tensor([-9.3067e-05, 0.0000e+00, 2.0000e+00, 3.9999e+00, 2.1000e+01,\n",
+ " 1.6000e+01])\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAADMCAYAAAB+36QhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAABA9JREFUeJzt3W1O20AARdFOxf63PP0VNYqitHw4\n4/E9ZwFgJND1GxsYc85fAHT9Xn0BAKwlBABxQgAQJwQAcUIAECcEAHFCABAnBABxQgAQJwQAcUIA\nECcEAHEfqy/gK8YY/lIeh5pzjhWf1/c2R3v2vW0RAMQJAUCcEADECQFAnBAAxAkBQJwQAMQJAUCc\nEADECQFAnBAAxAkBQJwQAMQJAUCcEADECQFAnBAAxAkBQJwQAMQJAUCcEADECQFAnBAAxAkBQJwQ\nAMQJAUCcEADECQFAnBAAxH2svgDea8556McfYxz68YGfZxEAxFkEF3T0Xf9XP7e1AOdkEQDECQFA\nnKOhza08Bvqs27U6IoJzsQgA4iyCTe20BB49u3YrAdaxCADihAAgztHQRnY+DvqX+6/NMRG8l0UA\nEGcRbODKSwBYzyIAiBMCTmfOaQXBGwkBQJwQAMR5WMzpeH0U3ssiAIgTAoA4IQCIEwKAOCHYwBjD\nA1TgMEIAEOf1UU7B4uHd/OvUvywCgDiLYCPP7lx2/5s87sZ4p2c/L/4XhkUAkCcEAHGOhjZ3m7I7\nHRFV5zfr/O/PR/UBskUAEGcRXMSrO5gVa6F2R8V5/MT3e20ZWAQAcUIAEOdoKOA787Y2kdnXTi9M\nnI1FABBnEfCSJcCZHb0CKr91bBEAxFkEwHZWPA+48vMyiwAgTggA4hwNAad2ttdCr/gA2SIAiBMC\ngDghAIjzjADgix6fX+z6zMAiAIgTAoA4R0PAqd0ft5ztVdKrsAgA4oQA2MYYY9sHsmcmBABxQgAQ\n52ExsJ3b8dDqh8dXOaayCADiLAKAT7jKCrhnEQDEWQTAto7+ZbMr3v0/YxEAxAkBQJyjIeASfvKV\n0sqR0I1FABBnEQCX8p0HyLUlcGMRAMQJAUCcoyEgrXocdM8iAIizCIDLenyl1N3/cxYBQJxFAFye\nJfCaRQAQJwQAcUIAECcEAHFCABAnBABxQgAQJwQAcUIAECcEAHFCABAnBABxQgAQJwQAcUIAECcE\nAHFCABAnBABxQgAQJwQAcUIAECcEAHFCABAnBABxQgAQJwQAcUIAECcEAHFCABAnBABxQgAQJwQA\ncUIAECcEAHFCABAnBABxQgAQJwQAcUIAECcEAHFCABAnBABxQgAQJwQAcUIAECcEAHFCABAnBABx\nQgAQJwQAcUIAECcEAHFjzrn6GgBYyCIAiBMCgDghAIgTAoA4IQCIEwKAOCEAiBMCgDghAIgTAoA4\nIQCIEwKAOCEAiBMCgDghAIgTAoA4IQCIEwKAOCEAiBMCgDghAIgTAoA4IQCIEwKAOCEAiBMCgDgh\nAIgTAoA4IQCIEwKAOCEAiBMCgDghAIgTAoA4IQCIEwKAOCEAiBMCgLg/wmR+1GplMnEAAAAASUVO\nRK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Y1MuTqXb2in_",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "result = []\n",
+ "for i in range(500):\n",
+ " trace = posterior()\n",
+ " x = trace.nodes['Nx']['value']\n",
+ " ny_shape = trace.nodes['Ny_shape']['value']\n",
+ " ny_scale = trace.nodes['Ny_scale']['value']\n",
+ " ny_orientation = trace.nodes['Ny_orientation']['value']\n",
+ " ny_posX = trace.nodes['Ny_posX']['value']\n",
+ " ny_posY = trace.nodes['Ny_posY']['value']\n",
+ " z = trace.nodes['Nz']['value']\n",
+ " con_obj = pyro.condition(intervened_model, data = {\n",
+ " \"Nx\": x,\n",
+ " \"Ny_shape\": ny_shape, \n",
+ " \"Ny_scale\": ny_scale, \n",
+ " \"Ny_orientation\": ny_orientation, \n",
+ " \"Ny_posX\": ny_posX, \n",
+ " \"Ny_posY\": ny_posY, \n",
+ " \"Nz\": z\n",
+ " })\n",
+ " \n",
+ "recon_x2,y2,z2 = con_obj(vae, mu, sigma)\n",
+ "print(y2)\n",
+ "recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "cou0AUQhejcz",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "vae.label_names"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "H7A27SBfl8aM",
+ "colab_type": "code",
+ "outputId": "feb97a2d-c831-436b-fab8-aa3f49052e4e",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 390
+ }
+ },
+ "source": [
+ "adam_params = {\"lr\": 0.00042, \"betas\": (0.9, 0.999)}\n",
+ "optimizer = Adam(adam_params)\n",
+ "# set up the loss(es) for inference. wrapping the guide in config_enumerate builds\n",
+ "# the loss as a sum\n",
+ "# by enumerating each class label for the sampled discrete categorical distribution\n",
+ "# in the model\n",
+ "guide = config_enumerate(guide, \"parallel\", expand=True)\n",
+ "elbo = Trace_ELBO(max_plate_nesting=1)\n",
+ "loss_basic = SVI(sup_vae.model, guide, optimizer, loss=elbo)\n",
+ "\n",
+ "# build a list of all losses considered\n",
+ "losses = [loss_basic]\n",
+ "loss_aux = SVI(sup_vae.model_classify, sup_vae.guide_classify, optimizer, loss=elbo)\n",
+ "losses.append(loss_aux)\n",
+ "\n",
+ "\n",
+ "\n",
+ "from pyro.optim import Adam\n",
+ "from pyro.infer import SVI, Trace_ELBO\n",
+ "\n",
+ "adam = Adam({\"lr\": 0.005, \"betas\": (0.90, 0.999)})\n",
+ "svi = SVI(vae.model, vae.guide, adam, loss=Trace_ELBO())\n",
+ "\n",
+ "\n",
+ "param_vals = []\n",
+ "pbar = tqdm(range(5))\n",
+ "for _ in pbar:\n",
+ " xs, ys = next(data_iter)\n",
+ " xs = xs.cuda()\n",
+ " ys = ys.cuda()\n",
+ " svi.step(xs, ys)\n",
+ "# param_vals.append({k: param(k).item() for k in [\"fl\", \"ia\"]})\n",
+ "\n",
+ "# pd.DataFrame(param_vals).plot(subplots=True)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " 0%| | 0/5 [00:00, ?it/s]\u001b[A"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "error",
+ "ename": "StopIteration",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mStopIteration\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mpbar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mxs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0mxs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 792\u001b[0m \u001b[0;31m# no valid `self.rcvd_idx` is found (i.e., didn't break)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 793\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_shutdown_workers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 794\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 795\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 796\u001b[0m \u001b[0;31m# Now `self.rcvd_idx` is the batch index we want to fetch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mStopIteration\u001b[0m: "
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "pa3ygX4zCYdH",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/causal_vae_dsprites.py b/causal_vae_dsprites.py
new file mode 100644
index 0000000..13f2044
--- /dev/null
+++ b/causal_vae_dsprites.py
@@ -0,0 +1,972 @@
+# -*- coding: utf-8 -*-
+"""causal_vae_dsprites.ipynb
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+ https://colab.research.google.com/drive/17TusqgELDORkovcWZxzlyScvCa7CgJyP
+
+## Deep Causal Variational Inference
+
+### Introduction:
+To train a supervised variational autoencoder using Deepmind's [dSprites](https://github.com/deepmind/dsprites-dataset) dataset.
+
+dSprites is a dataset of sprites, which are 2D shapes procedurally generated from 5 ground truth independent "factors." These factors are color, shape, scale, rotation, x and y positions of a sprite.
+
+All possible combinations of these variables are present exactly once, generating N = 737280 total images.
+
+Factors and their values:
+
+* Shape: 3 values {square, ellipse, heart}
+* Scale: 6 values linearly spaced in (0.5, 1)
+* Orientation: 40 values in (0, 2$\pi$)
+* Position X: 32 values in (0, 1)
+* Position Y: 32 values in (0, 1)
+
+
+Further, the objective of any generative model is essentially to capture underlying data generative factors, the disentangled representation would mean a single latent unit being sensitive to variations in single generative factors
+
+
+### Goal:
+To include the latent factors as labels in the training and to invent a causal story that relates these factors and the images in a DAG.
+
+Reference
+
+[Structured Disentangled Representation](https://arxiv.org/pdf/1804.02086.pdf)
+"""
+
+#Install dependencies
+!pip3 install pyro-ppl
+!pip3 install torch torchvision
+!pip3 install pydrive --upgrade
+!pip3 install tqdm
+
+# Commented out IPython magic to ensure Python compatibility.
+# Load necessary libraries
+from matplotlib import pyplot as plt
+import numpy as np
+import seaborn as sns
+
+import os
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+
+from tqdm import tqdm
+import pyro
+import pyro.distributions as dist
+from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, EmpiricalMarginal
+from pyro.optim import Adam, SGD
+import torch.distributions.constraints as constraints
+
+# Change figure aesthetics
+# %matplotlib inline
+sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})
+
+from ipywidgets import interact, interactive, fixed, interact_manual
+import ipywidgets as widgets
+
+#to utilize GPU capabilities
+USE_CUDA = True
+
+pyro.enable_validation(True)
+pyro.distributions.enable_validation(False)
+
+# Mount Google drive to load data
+from google.colab import drive
+drive.mount('/content/gdrive')
+
+# Mount G drive to access files
+from pydrive.auth import GoogleAuth
+from pydrive.drive import GoogleDrive
+from google.colab import auth
+from oauth2client.client import GoogleCredentials
+
+auth.authenticate_user()
+gauth = GoogleAuth()
+gauth.credentials = GoogleCredentials.get_application_default()
+drive = GoogleDrive(gauth)
+
+# Hack to get all available GPU ram.
+
+import tensorflow as tf
+tf.test.gpu_device_name()
+
+!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
+!pip install gputil
+!pip install psutil
+!pip install humanize
+import psutil
+import humanize
+import os
+import GPUtil as GPU
+GPUs = GPU.getGPUs()
+# XXX: only one GPU on Colab and isn’t guaranteed
+gpu = GPUs[0]
+def printm():
+ process = psutil.Process(os.getpid())
+ print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
+ print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
+printm()
+
+class Encoder(nn.Module):
+ """
+ MLPs (multi-layered perceptrons or simple feed-forward networks)
+ where the provided activation parameter is used on every linear layer except
+ for the output layer where we use the provided output_activation parameter
+ """
+ def __init__(self, image_dim, label_dim, z_dim):
+ super(Encoder, self).__init__()
+ #setup image and label dimensions from the dataset
+ self.image_dim = image_dim
+ self.label_dim = label_dim
+ self.z_dim = z_dim
+ # setup the three linear transformations used
+ self.fc1 = nn.Linear(self.image_dim+self.label_dim, 1000)
+ self.fc2 = nn.Linear(1000, 1000)
+ self.fc31 = nn.Linear(1000, z_dim) # mu values
+ self.fc32 = nn.Linear(1000, z_dim) # sigma values
+ # setup the non-linearities
+ self.softplus = nn.Softplus()
+
+ def forward(self, xs, ys):
+ xs = xs.reshape(-1, self.image_dim)
+ #now concatenate the image and label
+ inputs = torch.cat((xs,ys), -1)
+ # then compute the hidden units
+ hidden1 = self.softplus(self.fc1(inputs))
+ hidden2 = self.softplus(self.fc2(hidden1))
+ # then return a mean vector and a (positive) square root covariance
+ # each of size batch_size x z_dim
+ z_loc = self.fc31(hidden2)
+ z_scale = torch.exp(self.fc32(hidden2))
+ return z_loc, z_scale
+
+
+class Decoder(nn.Module):
+ def __init__(self, image_dim, label_dim, z_dim):
+ super(Decoder, self).__init__()
+ # setup the two linear transformations used
+ hidden_dim = 1000
+ self.fc1 = nn.Linear(z_dim+label_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+ self.fc3 = nn.Linear(hidden_dim, hidden_dim)
+ self.fc4 = nn.Linear(hidden_dim, image_dim)
+ # setup the non-linearities
+ self.softplus = nn.Softplus()
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, zs, ys):
+ inputs = torch.cat((zs, ys),-1)
+ # then compute the hidden units
+ hidden1 = self.softplus(self.fc1(inputs))
+ hidden2 = self.softplus(self.fc2(hidden1))
+ hidden3 = self.softplus(self.fc3(hidden2))
+ # return the parameter for the output Bernoulli
+ # each is of size batch_size x 784
+ loc_img = self.sigmoid(self.fc4(hidden3))
+ return loc_img
+
+class CVAE(nn.Module):
+ """
+ This class encapsulates the parameters (neural networks) and models & guides
+ needed to train a supervised variational auto-encoder
+ """
+ def __init__(self, config_enum=None, use_cuda=False, aux_loss_multiplier=None):
+
+ super(CVAE, self).__init__()
+ self.image_dim = 64**2
+ self.label_shape = np.array((1,3,6,40,32,32))
+ self.label_names = np.array(('color', 'shape', 'scale', 'orientation', 'posX', 'posY'))
+ self.label_dim = np.sum(self.label_shape)
+ self.z_dim = 50
+ self.allow_broadcast = config_enum == 'parallel'
+ self.use_cuda = use_cuda
+ self.aux_loss_multiplier = aux_loss_multiplier
+ # define and instantiate the neural networks representing
+ # the paramters of various distributions in the model
+ self.setup_networks()
+
+ def setup_networks(self):
+ """
+ Setup and initialize Encoder and decoder units
+ """
+ self.encoder = Encoder(self.image_dim, self.label_dim, self.z_dim)
+ self.decoder = Decoder(self.image_dim, self.label_dim, self.z_dim)
+ # using GPUs for faster training of the networks
+ if self.use_cuda:
+ self.cuda()
+
+ def model(self, xs, ys):
+ pyro.module("cvae", self)
+ batch_size = xs.size(0)
+ options = dict(dtype=xs.dtype, device=xs.device)
+ with pyro.plate("data"):
+ prior_loc = torch.zeros(batch_size, self.z_dim, **options)
+ prior_scale = torch.ones(batch_size, self.z_dim, **options)
+ zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))
+ # if the label y (which digit to write) is supervised, sample from the
+ # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
+ loc = self.decoder.forward(zs, self.remap_y(ys))
+ pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
+ # return the loc so we can visualize it later
+ return loc
+
+ def guide(self, xs, ys):
+ with pyro.plate("data"):
+ # sample (and score) the latent handwriting-style with the variational
+ # distribution q(z|x) = normal(loc(x),scale(x))
+ loc, scale = self.encoder.forward(xs, self.remap_y(ys))
+ pyro.sample("z", dist.Normal(loc, scale).to_event(1))
+
+ def remap_y(self, ys):
+ new_ys = []
+ options = dict(dtype=ys.dtype, device=ys.device)
+ for i, label_length in enumerate(self.label_shape):
+ prior = torch.ones(ys.size(0), label_length, **options) / (1.0 * label_length)
+ new_ys.append(pyro.sample("y_%s" % self.label_names[i], dist.OneHotCategorical(prior),
+ obs=torch.nn.functional.one_hot(ys[:,i].to(torch.int64), int(label_length))))
+ new_ys = torch.cat(new_ys, -1)
+ return new_ys.to(torch.float32)
+
+ def reconstruct_image(self, xs, ys):
+ # backward
+ sim_z_loc, sim_z_scale = self.encoder.forward(xs, self.remap_y(ys))
+ zs = dist.Normal(sim_z_loc, sim_z_scale).to_event(1).sample()
+ # forward
+ loc = self.decoder.forward(zs, self.remap_y(ys))
+ return dist.Bernoulli(loc).to_event(1).sample()
+
+def setup_data_loaders(train_x, test_x, train_y, test_y, batch_size=128, use_cuda=False):
+ train_dset = torch.utils.data.TensorDataset(
+ torch.from_numpy(train_x.astype(np.float32)).reshape(-1, 4096),
+ torch.from_numpy(train_y.astype(np.float32))
+ )
+ test_dset = torch.utils.data.TensorDataset(
+ torch.from_numpy(test_x.astype(np.float32)).reshape(-1, 4096),
+ torch.from_numpy(test_y.astype(np.float32))
+ )
+ kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
+ train_loader = torch.utils.data.DataLoader(
+ dataset=train_dset, batch_size=batch_size, shuffle=False, **kwargs
+ )
+
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dset, batch_size=batch_size, shuffle=False, **kwargs
+ )
+ return {"train":train_loader, "test":test_loader}
+
+dataset_zip = np.load(
+ '/content/gdrive/My Drive/data-science/causal-ml/projects/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
+ encoding = 'bytes',
+ allow_pickle=True
+)
+
+imgs = dataset_zip['imgs']
+labels = dataset_zip['latents_classes']
+label_sizes = dataset_zip['metadata'][()][b'latents_sizes']
+label_names = dataset_zip['metadata'][()][b'latents_names']
+
+# Sample imgs randomly
+indices_sampled = np.arange(imgs.shape[0])
+np.random.shuffle(indices_sampled)
+imgs_sampled = imgs[indices_sampled]
+labels_sampled = labels[indices_sampled]
+
+data_loaders = setup_data_loaders(
+ imgs_sampled[1000:],
+ imgs_sampled[:1000],
+ labels_sampled[1000:],
+ labels_sampled[:1000],
+ batch_size=256,
+ use_cuda=USE_CUDA
+)
+
+def train(svi, train_loader, use_cuda=False):
+ # initialize loss accumulator
+ epoch_loss = 0.
+ # do a training epoch over each mini-batch x returned
+ # by the data loader
+ for xs,ys in train_loader:
+ # if on GPU put mini-batch into CUDA memory
+ if use_cuda:
+ xs = xs.cuda()
+ ys = ys.cuda()
+ # do ELBO gradient and accumulate loss
+ epoch_loss += svi.step(xs, ys)
+ # return epoch loss
+ normalizer_train = len(train_loader.dataset)
+ total_epoch_loss_train = epoch_loss / normalizer_train
+ return total_epoch_loss_train
+
+def evaluate(svi, test_loader, use_cuda=False):
+ # initialize loss accumulator
+ test_loss = 0.
+ # compute the loss over the entire test set
+ for xs, ys in test_loader:
+ # if on GPU put mini-batch into CUDA memory
+ if use_cuda:
+ xs = xs.cuda()
+ ys = ys.cuda()
+ # compute ELBO estimate and accumulate loss
+ test_loss += svi.evaluate_loss(xs, ys)
+ normalizer_test = len(test_loader.dataset)
+ total_epoch_loss_test = test_loss / normalizer_test
+ return total_epoch_loss_test
+
+# Run options
+LEARNING_RATE = 1.0e-3
+
+# Run only for a single iteration for testing
+NUM_EPOCHS = 10
+TEST_FREQUENCY = 5
+
+#################################
+### FOR SAVING AND LOADING MODEL
+################################
+# clear param store
+
+pyro.clear_param_store()
+
+network_path = "/content/gdrive/My Drive/data-science/causal-ml/projects/trained_model.save"
+
+PATH = "trained_model.save"
+
+# new model
+# vae = CVAE(use_cuda=USE_CUDA)
+
+# save current model
+# torch.save(vae.state_dict(), PATH)
+
+# to load params from trained model
+vae = CVAE(use_cuda=USE_CUDA)
+vae.load_state_dict(torch.load(network_path))
+
+"""### **DONT RUN THE BELOW CODE AS WE'VE ALREADY TRAINED THE MODEL AND WE'VE STORED THE NETWORK PARAMS**
+
+## ==================================================================================
+"""
+
+import warnings
+warnings.filterwarnings('ignore')
+
+# clear param store
+pyro.clear_param_store()
+
+# setup the VAE
+vae = CVAE(use_cuda=USE_CUDA)
+
+# setup the optimizer
+adam_args = {"lr": LEARNING_RATE}
+optimizer = Adam(adam_args)
+
+# setup the inference algorithm
+svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
+
+train_elbo = []
+test_elbo = []
+# training loop
+
+VERBOSE = True
+pbar = tqdm(range(NUM_EPOCHS))
+for epoch in pbar:
+ total_epoch_loss_train = train(svi, data_loaders["train"], use_cuda=USE_CUDA)
+ train_elbo.append(-total_epoch_loss_train)
+ if VERBOSE:
+ print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))
+ if epoch % TEST_FREQUENCY == 0:
+ # report test diagnostics
+ total_epoch_loss_test = evaluate(svi, data_loaders["test"], use_cuda=USE_CUDA)
+ test_elbo.append(-total_epoch_loss_test)
+ if VERBOSE:
+ print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
+
+"""## ==================================================================================
+
+## Visualizing the reconstruction accuracy of VAE
+"""
+
+
+
+import warnings
+warnings.filterwarnings("ignore")
+
+data_iter = iter(data_loaders["train"])
+xs, ys = next(data_iter)
+
+if USE_CUDA:
+ xs = xs.cuda()
+ ys = ys.cuda()
+rs = vae.reconstruct_image(xs, ys)
+if USE_CUDA:
+ xs = xs.cpu()
+ rs = rs.cpu()
+originals = xs.numpy().reshape(-1, 64,64)
+recons = rs.reshape(-1,64,64)
+
+# [ 0, 2, 1, 34, 4, 24]
+def recon_check(original, recon):
+ fig = plt.figure()
+ ax0 = fig.add_subplot(121)
+ plt.imshow(original, cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+ ax1 = fig.add_subplot(122)
+ plt.imshow(recon , cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+
+def f(x):
+ fig = plt.figure()
+ ax0 = fig.add_subplot(121)
+ plt.imshow(originals[x], cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+ ax1 = fig.add_subplot(122)
+ plt.imshow(recons[x], cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+
+interact(f, x=widgets.IntSlider(min=0, max=xs.shape[0], step=1, value=0))
+
+y_names = ['shape', 'scale', 'orientation', 'posX', 'posY']
+y_shapes = np.array((3,6,40,32,32))
+img_dict = {}
+
+for i, img in enumerate(imgs_sampled):
+ img_dict[tuple(labels_sampled[i])] = img
+
+def find_in_dataset(shape, scale, orient, posX, posY):
+ fig = plt.figure()
+ img = img_dict[(0, shape, scale, orient, posX, posY)]
+ plt.imshow(img.reshape(64,64), cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+
+interact(find_in_dataset,
+ shape=widgets.IntSlider(min=0, max=2, step=1, value=0),
+ scale=widgets.IntSlider(min=0, max=5, step=1, value=0),
+ orient=widgets.IntSlider(min=0, max=39, step=1, value=0),
+ posX=widgets.IntSlider(min=0, max=31, step=1, value=0),
+ posY=widgets.IntSlider(min=0, max=31, step=1, value=0))
+
+def get_specific_data(args=dict(), cuda=False):
+ '''
+ use this function to get examples of data with specific class labels
+ inputs:
+ args - dictionary whose keys can include {shape, scale, orientation,
+ posX, posY} and values can include any integers less than the
+ corresponding size of that label dimension
+ cuda - bool to indicate whether the output should be placed on GPU
+ '''
+ names_dict = {'shape': 1, 'scale': 2, 'orientation': 3, 'posX': 4, 'posY': 5}
+ selected_ind = np.ones(imgs.shape[0], dtype=bool)
+ for k,v in args.items():
+ col_id = names_dict[k]
+ selected_ind = np.bitwise_and(selected_ind, labels[:, col_id] == v)
+ ind = np.random.choice(np.arange(imgs.shape[0])[selected_ind])
+ x = torch.from_numpy(imgs[ind].reshape(1,64**2).astype(np.float32))
+ y = torch.from_numpy(labels[ind].reshape(1,6).astype(np.float32))
+ if not cuda:
+ return x,y
+ x = x.cuda()
+ y = y.cuda()
+ return x,y
+
+def plot_image(x):
+ """
+ helper to plot dSprites images
+ """
+ x = x.cpu()
+ plt.figure()
+ plt.imshow(x.reshape(64,64), interpolation='nearest', cmap='Greys_r')
+ plt.axis('off')
+
+def see_specific_image(args=dict(), verbose=True):
+ '''
+ use this function to get examples of data with specific class labels
+ inputs:
+ args - dictionary whose keys can include {shape, scale, orientation,
+ posX, posY} and values can include any integers less than the
+ corresponding size of that label dimension
+ verbose - bool to indicate whether the full class label should be written
+ as the title of the plot
+ '''
+ x,y = get_specific_data(args, cuda=False)
+ plot_image(x)
+ if verbose:
+ string = ''
+ for i, s in enumerate(['Shape', 'Scale', 'Orientation', 'PosX', 'PosY']):
+ string += '%s: %d, ' % (s, int(y[0][i+1]))
+ if i == 2:
+ string = string[:-2] + '\n'
+ plt.title(string[:-2], fontsize=12)
+
+def compare_reconstruction(original, recon):
+ """
+ compare two images side by side
+ inputs:
+ original - array for original image
+ recon - array for recon image
+ """
+ fig = plt.figure()
+ ax0 = fig.add_subplot(121)
+ plt.imshow(original.cpu().reshape(64,64), cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+ plt.title('original')
+ ax1 = fig.add_subplot(122)
+ plt.imshow(recon.cpu().reshape(64,64), cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+ plt.title('reconstruction')
+
+def compare_to_density(original, recons):
+ """
+ compare two images side by side
+ inputs:
+ original - array for original image
+ recon - array of multiple recon images
+ """
+ fig = plt.figure()
+ ax0 = fig.add_subplot(121)
+ plt.imshow(original.cpu().reshape(64,64), cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+ plt.title('original')
+ ax1 = fig.add_subplot(122)
+ plt.imshow(torch.mean(recons.cpu(), 0).reshape(64,64), cmap='Greys_r', interpolation='nearest')
+ plt.axis('off')
+ plt.title('reconstructions')
+
+
+see_specific_image()
+
+label_dims = vae.label_shape
+label_dim_offsets = np.cumsum(label_dims)
+label_dim_offsets
+
+class SCM():
+ """
+ Structural causal model
+
+ args:
+ vae: instance of vae
+ mu: loc of q(z|x) given by the vae encoder
+ sigma: scale of q(z|x) given by the vae encoder
+
+ """
+ def __init__(self, vae, mu, sigma):
+ """
+ Constructor
+
+ Intializes :
+ image dimensions - 4096(64*64),
+ z dimensions: size of the tensor representing the latent random variable z,
+ label dimensions: 114 labels y that correspond to an image(one hot encoded)
+ f(x) = p(x|y,z)
+ Noise variables in the model N_#
+ """
+ self.vae = vae
+ self.image_dim = vae.image_dim
+ self.z_dim = vae.z_dim
+ # these are used for f_X
+ self.label_dims = vae.label_shape
+
+ def f_X(Y, Z, N):
+ """
+ Generating one hots for the factors
+ """
+ zs = Z.cuda()
+ # convert the labels to one hot
+ ys = [torch.tensor([0])]
+ ys.append(torch.nn.functional.one_hot(torch.round(Y[0]).to(torch.long), int(self.label_dims[1])))
+ ys.append(torch.nn.functional.one_hot(torch.round(Y[1]).to(torch.long), int(self.label_dims[2])))
+ ys.append(torch.nn.functional.one_hot(torch.round(Y[2]).to(torch.long), int(self.label_dims[3])))
+ ys.append(torch.nn.functional.one_hot(torch.round(Y[3]).to(torch.long), int(self.label_dims[4])))
+ ys.append(torch.nn.functional.one_hot(torch.round(Y[4]).to(torch.long), int(self.label_dims[5])))
+ ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()
+ p = vae.decoder.forward(zs, ys)
+ return (N < p.cpu()).type(torch.float)
+
+ def f_Y(N):
+ """
+ Gumbel distribution - to model the distribution of the maximum of a number of samples
+ m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0])).sample() # sample from Gumbel distribution with loc=1, scale=2
+ tensor([ 1.0124])
+
+ https://pytorch.org/docs/stable/_modules/torch/distributions/gumbel.html
+ """
+# m = torch.distributions.gumbel.Gumbel(torch.zeros(N.size(0)), torch.ones(N.size(0)))
+ beta = 12
+ indices = torch.tensor(np.arange(N.size(0))).to(torch.float32)
+ smax = nn.functional.softmax(beta*N)
+ argmax_ind = torch.sum(smax*indices)
+ return argmax_ind
+
+ def f_Z(N):
+ """
+ Z ~ Normal(mu, sigma)
+ """
+ return N * sigma + mu
+
+ def model(noise):
+ """
+ The model corresponds to a generative process
+
+ args: noise variables
+ return: X(image), Y(labels), Z(latents)
+ """
+ N_X = pyro.sample( 'N_X', noise['N_X'].to_event(1) )
+ # denoted using the index in the sequence
+ # that they are stored in as vae.label_names:
+ # ['shape', 'scale', 'orientation', 'posX', 'posY']
+ N_Y_1 = pyro.sample( 'N_Y_1', noise['N_Y_1'].to_event(1) )
+ N_Y_2 = pyro.sample( 'N_Y_2', noise['N_Y_2'].to_event(1) )
+ N_Y_3 = pyro.sample( 'N_Y_3', noise['N_Y_3'].to_event(1) )
+ N_Y_4 = pyro.sample( 'N_Y_4', noise['N_Y_4'].to_event(1) )
+ N_Y_5 = pyro.sample( 'N_Y_5', noise['N_Y_5'].to_event(1) )
+
+ # Z ~ Normal(Nx_mu, Nx_sigma)
+ N_Z = pyro.sample( 'N_Z', noise['N_Z'].to_event(1) )
+ Z = pyro.sample('Z', dist.Normal( f_Z( N_Z ), 1e-1).to_event(1) )
+
+ # Y ~ Gumbel max of Ny
+# Y_1_mu = f_Y(N_Y_1)
+# Y_2_mu = f_Y(N_Y_2)
+# Y_3_mu = f_Y(N_Y_3)
+# Y_4_mu = f_Y(N_Y_4)
+# Y_5_mu = f_Y(N_Y_5)
+
+ Y_1 = pyro.sample('Y_1', dist.Normal( f_Y(N_Y_1), 1e-2) )
+ Y_2 = pyro.sample('Y_2', dist.Normal( f_Y(N_Y_2), 1e-1) )
+ Y_3 = pyro.sample('Y_3', dist.Normal( f_Y(N_Y_3), 1e-1) )
+ Y_4 = pyro.sample('Y_4', dist.Normal( f_Y(N_Y_4), 1e-1) )
+ Y_5 = pyro.sample('Y_5', dist.Normal( f_Y(N_Y_5), 1e-1) )
+
+# Y_mu = (Y_1_mu, Y_2_mu, Y_3_mu, Y_4_mu, Y_5_mu)
+
+ # X ~ p(x|y,z) = bernoulli(loc(y,z))
+ X = pyro.sample('X', dist.Normal( f_X( (Y_1, Y_2, Y_3,Y_4,Y_5), Z, N_X ), 1e-2).to_event(1))
+
+ # return noise and variables
+ noise_samples = N_X, (N_Y_1, N_Y_2, N_Y_3, N_Y_4, N_Y_5), N_Z
+ variable_samples = X, (Y_1, Y_2, Y_3, Y_4, Y_5), Z
+ return variable_samples, noise_samples
+
+ self.model = model
+ #Initialize all noise variables in the model
+ self.init_noise = {
+ 'N_X' : dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)),
+ 'N_Z' : dist.Normal(torch.zeros(vae.z_dim), torch.ones(vae.z_dim)),
+ 'N_Y_1' : dist.Uniform(torch.zeros(self.label_dims[1]),torch.ones(self.label_dims[1])),
+ 'N_Y_2' : dist.Uniform(torch.zeros(self.label_dims[2]),torch.ones(self.label_dims[2])),
+ 'N_Y_3' : dist.Uniform(torch.zeros(self.label_dims[3]),torch.ones(self.label_dims[3])),
+ 'N_Y_4' : dist.Uniform(torch.zeros(self.label_dims[4]),torch.ones(self.label_dims[4])),
+ 'N_Y_5' : dist.Uniform(torch.zeros(self.label_dims[5]),torch.ones(self.label_dims[5]))
+ }
+
+ def update_noise_svi(self, obs_data, intervened_model=None):
+ """
+ Use svi to find out the mu, sigma of the distributionsfor the
+ condition outlined in obs_data
+ """
+
+ def guide(noise):
+ """
+ The guide serves as an approximation to the posterior p(z|x).
+ The guide provides a valid joint probability density over all the
+ latent random variables in the model.
+
+ https://pyro.ai/examples/svi_part_i.html
+ """
+ # create params with constraints
+ mu = {
+ 'N_X': pyro.param('N_X_mu', 0.5*torch.ones(self.image_dim),constraint = constraints.interval(0., 1.)),
+ 'N_Z': pyro.param('N_Z_mu', torch.zeros(self.z_dim),constraint = constraints.interval(-3., 3.)),
+ 'N_Y_1': pyro.param('N_Y_1_mu', 0.5*torch.ones(self.label_dims[1]),constraint = constraints.interval(0., 1.)),
+ 'N_Y_2': pyro.param('N_Y_2_mu', 0.5*torch.ones(self.label_dims[2]),constraint = constraints.interval(0., 1.)),
+ 'N_Y_3': pyro.param('N_Y_3_mu', 0.5*torch.ones(self.label_dims[3]),constraint = constraints.interval(0., 1.)),
+ 'N_Y_4': pyro.param('N_Y_4_mu', 0.5*torch.ones(self.label_dims[4]),constraint = constraints.interval(0., 1.)),
+ 'N_Y_5': pyro.param('N_Y_5_mu', 0.5*torch.ones(self.label_dims[5]),constraint = constraints.interval(0., 1.))
+ }
+ sigma = {
+ 'N_X': pyro.param('N_X_sigma', 0.1*torch.ones(self.image_dim),constraint = constraints.interval(0.0001, 0.5)),
+ 'N_Z': pyro.param('N_Z_sigma', torch.ones(self.z_dim),constraint = constraints.interval(0.0001, 3.)),
+ 'N_Y_1': pyro.param('N_Y_1_sigma', 0.1*torch.ones(self.label_dims[1]),constraint = constraints.interval(0.0001, 0.5)),
+ 'N_Y_2': pyro.param('N_Y_2_sigma', 0.1*torch.ones(self.label_dims[2]),constraint = constraints.interval(0.0001, 0.5)),
+ 'N_Y_3': pyro.param('N_Y_3_sigma', 0.1*torch.ones(self.label_dims[3]),constraint = constraints.interval(0.0001, 0.5)),
+ 'N_Y_4': pyro.param('N_Y_4_sigma', 0.1*torch.ones(self.label_dims[4]),constraint = constraints.interval(0.0001, 0.5)),
+ 'N_Y_5': pyro.param('N_Y_5_sigma', 0.1*torch.ones(self.label_dims[5]),constraint = constraints.interval(0.0001, 0.5))
+ }
+ for noise_term in noise.keys():
+ pyro.sample(noise_term, dist.Normal(mu[noise_term], sigma[noise_term]).to_event(1))
+
+ # Condition the model
+ if intervened_model is not None:
+ obs_model = pyro.condition(intervened_model, obs_data)
+ else:
+ obs_model = pyro.condition(self.model, obs_data)
+
+ pyro.clear_param_store()
+
+ # Once we’ve specified a guide, we’re ready to proceed to inference.
+ # Now, this an optimization problem where each iteration of training takes
+ # a step that moves the guide closer to the exact posterior
+
+ # https://arxiv.org/pdf/1601.00670.pdf
+ svi = SVI(
+ model= obs_model,
+ guide= guide,
+ optim= SGD({"lr": 1e-5, 'momentum': 0.1}),
+ loss=Trace_ELBO(retain_graph=True)
+ )
+
+ num_steps = 1500
+ samples = defaultdict(list)
+ for t in range(num_steps):
+ loss = svi.step(self.init_noise)
+# if t % 100 == 0:
+# print("step %d: loss of %.2f" % (t, loss))
+ for noise in self.init_noise.keys():
+ mu = '{}_mu'.format(noise)
+ sigma = '{}_sigma'.format(noise)
+ samples[mu].append(pyro.param(mu).detach().numpy())
+ samples[sigma].append(pyro.param(sigma).detach().numpy())
+ means = {k: torch.tensor(np.array(v).mean(axis=0)) for k, v in samples.items()}
+
+ # update the inferred noise
+ updated_noise = {
+ 'N_X' : dist.Normal(means['N_X_mu'], means['N_X_sigma']),
+ 'N_Z' : dist.Normal(means['N_Z_mu'], means['N_Z_sigma']),
+ 'N_Y_1': dist.Normal(means['N_Y_1_mu'], means['N_Y_1_sigma']),
+ 'N_Y_2': dist.Normal(means['N_Y_2_mu'], means['N_Y_2_sigma']),
+ 'N_Y_3': dist.Normal(means['N_Y_3_mu'], means['N_Y_3_sigma']),
+ 'N_Y_4': dist.Normal(means['N_Y_4_mu'], means['N_Y_4_sigma']),
+ 'N_Y_5': dist.Normal(means['N_Y_5_mu'], means['N_Y_5_sigma']),
+ }
+ return updated_noise
+
+ def __call__(self):
+ return self.model(self.init_noise)
+
+"""## Sanity check: 1
+### Making sure VAE works
+"""
+
+# Generate an instance of dSprites image
+ox, y = get_specific_data(cuda=True)
+plot_image(ox)
+# Pass it through VAE to get q(z|x) => N(mu, sigma)
+mu, sigma = vae.encoder.forward(ox,vae.remap_y(y))
+# Feed these params to our custom SCM
+scm = SCM(vae, mu.cpu(), sigma.cpu())
+print(y)
+# Check for reconstruction
+
+"""## Sanity check 2
+
+### To check if the decoder is able to generate the image if the latents are changed:
+#### To achieve this we manually change the labels in the code and run it through the decoder and check for reconstruction
+"""
+
+original, y_original = get_specific_data(cuda=True)
+print('top: ',y_original)
+mu, sigma = vae.encoder.forward(original,vae.remap_y(y_original))
+B = 100
+zs = torch.cat([dist.Normal(mu.cpu(), sigma.cpu()).sample() for a in range(B)], 0)
+ys = torch.cat([vae.remap_y(y_original) for a in range(B)], 0)
+rs = vae.decoder.forward(zs.cuda(), ys).detach()
+compare_to_density(original,rs)
+
+y_new = torch.tensor(y_original)
+y_new[0,1] = (y_original[0,1] + 1) % 2
+print('bottom: ', y_new)
+zs = torch.cat([dist.Normal(mu.cpu(), sigma.cpu()).sample() for a in range(B)], 0)
+ys = torch.cat([vae.remap_y(y_new) for a in range(B)], 0)
+rs = vae.decoder.forward(zs.cuda(), ys).detach()
+compare_to_density(original,rs)
+
+"""## Conditioning the model"""
+
+cond_data = {}
+for i in range(1, 6):
+ cond_data["Y_{}".format(i)] = torch.tensor(y[0,i].cpu()).to(torch.float32)
+print(cond_data)
+
+# cond_data['Y_1'] = torch.tensor(1.)
+# cond_data['Y_2'] = torch.tensor(4.)
+conditioned_model = pyro.condition(scm.model, data=cond_data)
+cond_noise = scm.update_noise_svi(cond_data)
+print(cond_data)
+
+rxs = []
+for i in range(100):
+ (rx,ry,_), _ = scm.model(cond_noise)
+ rxs.append(rx)
+compare_to_density(ox, torch.cat(rxs))
+_ =plt.suptitle("SCM Conditioned on Original", fontsize=18, fontstyle='italic')
+
+"""## Counterfactuals"""
+
+# intervening on Shape, posX and PosY
+intervened_model = pyro.do(scm.model, data={
+ "Y_1": torch.tensor(0.),
+ "Y_4": torch.tensor(30.),
+ "Y_5": torch.tensor(25.),
+})
+noise_data = {}
+for term, d in cond_noise.items():
+ noise_data[term] = d.loc
+# intervened_noise = scm.update_noise_svi(noise_data, intervened_model)
+
+(rx1,ry,_), _ = intervened_model(scm.init_noise)
+compare_to_density(ox, rx1)
+print(ry)
+
+rxs = []
+for i in range(5000):
+ (cfo1,ny1,nz1), _= intervened_model(cond_noise)
+ rxs.append(cfo1)
+compare_to_density(ox, torch.cat(rxs))
+_ =plt.suptitle("SCM intervened on shape", fontsize=18, fontstyle='italic')
+
+"""## ======== Test code ======="""
+
+import warnings
+warnings.filterwarnings("ignore")
+
+xs, ys = next(data_iter)
+x = xs[0].reshape(1,-1).cuda()
+y = ys[0].reshape(1,-1).cuda()
+mu, sigma = vae.encoder.forward(x,vae.remap_y(y))
+mu = mu.cpu()
+sigma = sigma.cpu()
+recon_x1, y1, z1 = OldSCM(vae, mu, sigma)
+print(y1)
+
+def OldSCM(vae, mu, sigma):
+ z_dim = vae.z_dim
+ Ny, Y, ys = [], [], []
+ Nx = pyro.sample("Nx", dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)))
+ Nz = pyro.sample("Nz", dist.Normal(torch.zeros(z_dim), torch.ones(z_dim)))
+ m = torch.distributions.gumbel.Gumbel(torch.tensor(0.0), torch.tensor(1.0))
+ for label_id in range(6):
+ name = vae.label_names[label_id]
+ length = vae.label_shape[label_id]
+ new = pyro.sample("Ny_%s"%name, dist.Uniform(torch.zeros(length), torch.ones(length)) )
+ Ny.append(new)
+ gumbel_vars = torch.tensor([m.sample() for _ in range(length)])
+ max_ind = torch.argmax(torch.log(new) + gumbel_vars).item()
+ Y.append(pyro.sample("Y_%s"%name, dist.Normal(torch.tensor(max_ind * 1.0), 1e-4)))
+# Y.append(pyro.sample("Y_%s"%name, dist.Delta(torch.tensor(max_ind*1.0))))
+ ys.append(torch.nn.functional.one_hot(torch.tensor(max_ind), int(length)))
+ Y = torch.tensor(Y)
+ ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()
+ Z = pyro.sample("Z", dist.Normal(mu + Nz*sigma, 1e-4))
+# Z = pyro.sample("Z", dist.Delta(mu + Nz*sigma))
+ zs = Z.cuda()
+ p = vae.decoder.forward(zs,ys)
+ X = pyro.sample("X", dist.Normal((Nx < p.cpu()).type(torch.float), 1e-4))
+# X = pyro.sample("X", dist.Delta((Nx < p.cpu()).type(torch.float)))
+ return X, Y, Z
+
+from pyro.infer.importance import Importance
+from pyro.infer.mcmc import MCMC
+from pyro.infer.mcmc.nuts import HMC
+from pyro.infer import SVI
+
+intervened_model = pyro.do(OldSCM, data={"Y_shape": torch.tensor(0.)})
+conditioned_model = pyro.condition(OldSCM, data={
+ "X": recon_x1,
+# "Z": z1,
+ "Y_color":y1[0],
+ "Y_shape": y1[1],
+ "Y_scale": y1[2],
+ "Y_orientation": y1[3],
+ "Y_posX": y1[4],
+ "Y_posY": y1[5]
+ })
+
+# change to svi
+posterior = pyro.infer.Importance(conditioned_model, num_samples = 100)
+posterior.run(vae, mu, sigma)
+
+result = []
+for i in range(500):
+ trace = posterior()
+ x = trace.nodes['Nx']['value']
+ ny_shape = trace.nodes['Ny_shape']['value']
+ ny_scale = trace.nodes['Ny_scale']['value']
+ ny_orientation = trace.nodes['Ny_orientation']['value']
+ ny_posX = trace.nodes['Ny_posX']['value']
+ ny_posY = trace.nodes['Ny_posY']['value']
+ z = trace.nodes['Nz']['value']
+ con_obj = pyro.condition(intervened_model, data = {
+ "Nx": x,
+ "Ny_shape": ny_shape,
+ "Ny_scale": ny_scale,
+ "Ny_orientation": ny_orientation,
+ "Ny_posX": ny_posX,
+ "Ny_posY": ny_posY,
+ "Nz": z
+ })
+
+recon_x2,y2,z2 = con_obj(vae, mu, sigma)
+print(y2)
+recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])
+
+result = []
+for i in range(500):
+ trace = posterior()
+ x = trace.nodes['Nx']['value']
+ ny_shape = trace.nodes['Ny_shape']['value']
+ ny_scale = trace.nodes['Ny_scale']['value']
+ ny_orientation = trace.nodes['Ny_orientation']['value']
+ ny_posX = trace.nodes['Ny_posX']['value']
+ ny_posY = trace.nodes['Ny_posY']['value']
+ z = trace.nodes['Nz']['value']
+ con_obj = pyro.condition(intervened_model, data = {
+ "Nx": x,
+ "Ny_shape": ny_shape,
+ "Ny_scale": ny_scale,
+ "Ny_orientation": ny_orientation,
+ "Ny_posX": ny_posX,
+ "Ny_posY": ny_posY,
+ "Nz": z
+ })
+
+recon_x2,y2,z2 = con_obj(vae, mu, sigma)
+print(y2)
+recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])
+
+vae.label_names
+
+adam_params = {"lr": 0.00042, "betas": (0.9, 0.999)}
+optimizer = Adam(adam_params)
+# set up the loss(es) for inference. wrapping the guide in config_enumerate builds
+# the loss as a sum
+# by enumerating each class label for the sampled discrete categorical distribution
+# in the model
+guide = config_enumerate(guide, "parallel", expand=True)
+elbo = Trace_ELBO(max_plate_nesting=1)
+loss_basic = SVI(sup_vae.model, guide, optimizer, loss=elbo)
+
+# build a list of all losses considered
+losses = [loss_basic]
+loss_aux = SVI(sup_vae.model_classify, sup_vae.guide_classify, optimizer, loss=elbo)
+losses.append(loss_aux)
+
+
+
+from pyro.optim import Adam
+from pyro.infer import SVI, Trace_ELBO
+
+adam = Adam({"lr": 0.005, "betas": (0.90, 0.999)})
+svi = SVI(vae.model, vae.guide, adam, loss=Trace_ELBO())
+
+
+param_vals = []
+pbar = tqdm(range(5))
+for _ in pbar:
+ xs, ys = next(data_iter)
+ xs = xs.cuda()
+ ys = ys.cuda()
+ svi.step(xs, ys)
+# param_vals.append({k: param(k).item() for k in ["fl", "ia"]})
+
+# pd.DataFrame(param_vals).plot(subplots=True)
+
diff --git a/deep-fakes.jpg b/deep-fakes.jpg
new file mode 100644
index 0000000..6541c59
Binary files /dev/null and b/deep-fakes.jpg differ
diff --git a/figs/cs-vae-1.png b/figs/cs-vae-1.png
new file mode 100644
index 0000000..0c0ab7c
Binary files /dev/null and b/figs/cs-vae-1.png differ
diff --git a/figs/dag.png b/figs/dag.png
new file mode 100644
index 0000000..afdc92d
Binary files /dev/null and b/figs/dag.png differ
diff --git a/figs/deep-fakes.jpg b/figs/deep-fakes.jpg
new file mode 100644
index 0000000..6541c59
Binary files /dev/null and b/figs/deep-fakes.jpg differ
diff --git a/figs/intervention-2.png b/figs/intervention-2.png
new file mode 100644
index 0000000..9b4b489
Binary files /dev/null and b/figs/intervention-2.png differ
diff --git a/figs/intervention.png b/figs/intervention.png
new file mode 100644
index 0000000..26350e9
Binary files /dev/null and b/figs/intervention.png differ
diff --git a/figs/manual-intervention.png b/figs/manual-intervention.png
new file mode 100644
index 0000000..daaa308
Binary files /dev/null and b/figs/manual-intervention.png differ
diff --git a/figs/scm-conditioned.png b/figs/scm-conditioned.png
new file mode 100644
index 0000000..fdb14df
Binary files /dev/null and b/figs/scm-conditioned.png differ
diff --git a/figs/scm.png b/figs/scm.png
new file mode 100644
index 0000000..4953807
Binary files /dev/null and b/figs/scm.png differ
diff --git a/figs/vae capturing latent.png b/figs/vae capturing latent.png
new file mode 100644
index 0000000..a18b095
Binary files /dev/null and b/figs/vae capturing latent.png differ
diff --git a/figs/vae-recons.png b/figs/vae-recons.png
new file mode 100644
index 0000000..23a565f
Binary files /dev/null and b/figs/vae-recons.png differ
diff --git a/projects.pdf b/projects.pdf
new file mode 100644
index 0000000..b9db080
Binary files /dev/null and b/projects.pdf differ
diff --git a/vae_ajeya.ipynb b/vae_ajeya.ipynb
new file mode 100644
index 0000000..3cb3a1e
--- /dev/null
+++ b/vae_ajeya.ipynb
@@ -0,0 +1,1135 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "vae-ajeya.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.1"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6rtbzfmeMGvr",
+ "colab_type": "code",
+ "outputId": "5511d557-1110-448e-ad34-ce6095067c64",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 496
+ }
+ },
+ "source": [
+ "!pip3 install pyro-ppl\n",
+ "!pip3 install torch torchvision\n",
+ "!pip3 install pydrive --upgrade"
+ ],
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: pyro-ppl in /usr/local/lib/python3.6/dist-packages (0.3.4)\n",
+ "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (1.12.0)\n",
+ "Requirement already satisfied: contextlib2 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (0.5.5)\n",
+ "Requirement already satisfied: tqdm>=4.31 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (4.32.2)\n",
+ "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (2.3.2)\n",
+ "Requirement already satisfied: torch>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (1.1.0)\n",
+ "Requirement already satisfied: graphviz>=0.8 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (0.10.1)\n",
+ "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.6/dist-packages (from pyro-ppl) (1.16.4)\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.1.0)\n",
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.3.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.16.4)\n",
+ "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.12.0)\n",
+ "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.3.0)\n",
+ "Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)\n",
+ "Requirement already up-to-date: pydrive in /usr/local/lib/python3.6/dist-packages (1.3.1)\n",
+ "Requirement already satisfied, skipping upgrade: oauth2client>=4.0.0 in /usr/local/lib/python3.6/dist-packages (from pydrive) (4.1.3)\n",
+ "Requirement already satisfied, skipping upgrade: PyYAML>=3.0 in /usr/local/lib/python3.6/dist-packages (from pydrive) (3.13)\n",
+ "Requirement already satisfied, skipping upgrade: google-api-python-client>=1.2 in /usr/local/lib/python3.6/dist-packages (from pydrive) (1.7.9)\n",
+ "Requirement already satisfied, skipping upgrade: httplib2>=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.0.0->pydrive) (0.11.3)\n",
+ "Requirement already satisfied, skipping upgrade: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.0.0->pydrive) (4.0)\n",
+ "Requirement already satisfied, skipping upgrade: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.0.0->pydrive) (0.4.5)\n",
+ "Requirement already satisfied, skipping upgrade: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.0.0->pydrive) (0.2.5)\n",
+ "Requirement already satisfied, skipping upgrade: six>=1.6.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.0.0->pydrive) (1.12.0)\n",
+ "Requirement already satisfied, skipping upgrade: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.2->pydrive) (1.4.2)\n",
+ "Requirement already satisfied, skipping upgrade: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.2->pydrive) (3.0.0)\n",
+ "Requirement already satisfied, skipping upgrade: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.2->pydrive) (0.0.3)\n",
+ "Requirement already satisfied, skipping upgrade: cachetools>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.2->pydrive) (3.1.1)\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "OouaeUpQMEjN",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from __future__ import absolute_import\n",
+ "from __future__ import division\n",
+ "from __future__ import print_function\n",
+ "from matplotlib import pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "import torch\n",
+ "import torchvision.datasets as dset\n",
+ "import torch.nn as nn\n",
+ "import torchvision.transforms as transforms\n",
+ "\n",
+ "import pyro\n",
+ "from pyro.contrib.examples.util import print_and_log\n",
+ "import pyro.distributions as dist\n",
+ "from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate\n",
+ "from pyro.optim import Adam\n",
+ "\n",
+ "# Change figure aesthetics\n",
+ "%matplotlib inline\n",
+ "sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})\n",
+ "\n",
+ "USE_CUDA = True\n",
+ "\n",
+ "pyro.enable_validation(True)\n",
+ "pyro.distributions.enable_validation(False)\n",
+ "\n",
+ "# from custom_mlp import MLP, Expxt"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "s1p7kozIT9YN",
+ "colab_type": "code",
+ "outputId": "b43bca1d-270d-4b0f-a5db-72480fc19511",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.test.gpu_device_name()"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'/device:GPU:0'"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ux30_pwSUEJX",
+ "colab_type": "code",
+ "outputId": "84958572-1246-4eab-f56b-118092b21e7a",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
+ }
+ },
+ "source": [
+ "# Hack to get all available GPU ram.\n",
+ "!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi\n",
+ "!pip install gputil\n",
+ "!pip install psutil\n",
+ "!pip install humanize\n",
+ "import psutil\n",
+ "import humanize\n",
+ "import os\n",
+ "import GPUtil as GPU\n",
+ "GPUs = GPU.getGPUs()\n",
+ "# XXX: only one GPU on Colab and isn’t guaranteed\n",
+ "gpu = GPUs[0]\n",
+ "def printm():\n",
+ " process = psutil.Process(os.getpid())\n",
+ " print(\"Gen RAM Free: \" + humanize.naturalsize( psutil.virtual_memory().available ), \" | Proc size: \" + humanize.naturalsize( process.memory_info().rss))\n",
+ " print(\"GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB\".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))\n",
+ "printm()"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: gputil in /usr/local/lib/python3.6/dist-packages (1.4.0)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.6/dist-packages (5.4.8)\n",
+ "Requirement already satisfied: humanize in /usr/local/lib/python3.6/dist-packages (0.5.1)\n",
+ "Gen RAM Free: 10.6 GB | Proc size: 3.5 GB\n",
+ "GPU RAM Free: 14310MB | Used: 769MB | Util 5% | Total 15079MB\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "BPXd84WMYr3W",
+ "colab_type": "code",
+ "outputId": "62d05def-f7f5-4b45-ee41-b9632f8a7034",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/gdrive')"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0GhM07OkM4_9",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "\n",
+ "\n",
+ "from pydrive.auth import GoogleAuth\n",
+ "from pydrive.drive import GoogleDrive\n",
+ "from google.colab import auth\n",
+ "from oauth2client.client import GoogleCredentials\n",
+ "\n",
+ "auth.authenticate_user()\n",
+ "gauth = GoogleAuth()\n",
+ "gauth.credentials = GoogleCredentials.get_application_default()\n",
+ "drive = GoogleDrive(gauth)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Y8E4_2q1NTZt",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# https://drive.google.com/open?id=1mAtimHIWHJM2UJNxsIylPyI1jUebGIx-\n",
+ "custom_mlp_module = drive.CreateFile({'id':'1mAtimHIWHJM2UJNxsIylPyI1jUebGIx-'})"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "R-A-TTIFPfip",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "custom_mlp = custom_mlp_module.GetContentFile('custom_mlp.py')\n",
+ "from custom_mlp import MLP, Exp"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Vomid1y5MEjP",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class SVAE(nn.Module):\n",
+ " \"\"\"\n",
+ " This class encapsulates the parameters (neural networks) and models & guides needed to train a\n",
+ " semi-supervised variational auto-encoder on the Dsprites image dataset\n",
+ "\n",
+ " :param output_size: size of the tensor representing the class label \n",
+ " :param input_size: size of the tensor representing the image (64*64 = 4096 for our Dsprites dataset\n",
+ " since we flatten the images and scale the pixels to be in [0,1])\n",
+ " :param z_dim: size of the tensor representing the latent random variable z\n",
+ " (for our Dsprites dataset)\n",
+ " :param hidden_layers: a tuple (or list) of MLP layers to be used in the neural networks\n",
+ " representing the parameters of the distributions in our model\n",
+ " :param use_cuda: use GPUs for faster training\n",
+ " :param aux_loss_multiplier: the multiplier to use with the auxiliary loss\n",
+ " \"\"\"\n",
+ " def __init__(self, output_size=6, input_size=4096, z_dim=50, hidden_layers=(500,),\n",
+ " config_enum=None, use_cuda=USE_CUDA, aux_loss_multiplier=None):\n",
+ "\n",
+ " super(SVAE, self).__init__()\n",
+ "\n",
+ " # initialize the class with all arguments provided to the constructor\n",
+ " self.output_size = output_size\n",
+ " self.input_size = input_size\n",
+ " self.z_dim = z_dim\n",
+ " self.hidden_layers = hidden_layers\n",
+ " self.allow_broadcast = config_enum == 'parallel'\n",
+ " self.use_cuda = use_cuda\n",
+ " self.aux_loss_multiplier = aux_loss_multiplier\n",
+ "\n",
+ " # define and instantiate the neural networks representing\n",
+ " # the paramters of various distributions in the model\n",
+ " self.setup_networks()\n",
+ "\n",
+ " def setup_networks(self):\n",
+ "\n",
+ " z_dim = self.z_dim\n",
+ " hidden_sizes = self.hidden_layers\n",
+ "\n",
+ " # define the neural networks used later in the model and the guide.\n",
+ " # these networks are MLPs (multi-layered perceptrons or simple feed-forward networks)\n",
+ " # where the provided activation parameter is used on every linear layer except\n",
+ " # for the output layer where we use the provided output_activation parameter\n",
+ " self.encoder_y = MLP([self.input_size] + hidden_sizes + [self.output_size],\n",
+ " activation=nn.Softplus,\n",
+ " output_activation=nn.Softmax,\n",
+ " allow_broadcast=self.allow_broadcast,\n",
+ " use_cuda=self.use_cuda)\n",
+ "\n",
+ " # a split in the final layer's size is used for multiple outputs\n",
+ " # and potentially applying separate activation functions on them\n",
+ " # e.g. in this network the final output is of size [z_dim,z_dim]\n",
+ " # to produce loc and scale, and apply different activations [None,Exp] on them\n",
+ " self.encoder_z = MLP([self.input_size + self.output_size] +\n",
+ " hidden_sizes + [[z_dim, z_dim]],\n",
+ " activation=nn.Softplus,\n",
+ " output_activation=[None, Exp],\n",
+ " allow_broadcast=self.allow_broadcast,\n",
+ " use_cuda=self.use_cuda)\n",
+ "\n",
+ " self.decoder = MLP([z_dim + self.output_size] +\n",
+ " hidden_sizes + [self.input_size],\n",
+ " activation=nn.Softplus,\n",
+ " output_activation=nn.Sigmoid,\n",
+ " allow_broadcast=self.allow_broadcast,\n",
+ " use_cuda=self.use_cuda)\n",
+ "\n",
+ " # using GPUs for faster training of the networks\n",
+ " if self.use_cuda:\n",
+ " self.cuda()\n",
+ "\n",
+ " def model(self, xs, ys=None):\n",
+ " \"\"\"\n",
+ " The model corresponds to the following generative process:\n",
+ " p(z) = normal(0,I) # dsprites label (latent)\n",
+ " p(y|x) = categorical(I/10.) # which digit (supervised)\n",
+ " p(x|y,z) = bernoulli(loc(y,z)) # an image\n",
+ " loc is given by a neural network `decoder`\n",
+ "\n",
+ " :param xs: a batch of scaled vectors of pixels from an image\n",
+ " :param ys: (optional) a batch of the class labels i.e.\n",
+ " the digit corresponding to the image(s)\n",
+ " :return: None\n",
+ " \"\"\"\n",
+ " # register this pytorch module and all of its sub-modules with pyro\n",
+ " pyro.module(\"ss_vae\", self)\n",
+ "\n",
+ " batch_size = xs.size(0)\n",
+ " options = dict(dtype=xs.dtype, device=xs.device)\n",
+ " with pyro.plate(\"data\"):\n",
+ "\n",
+ " # sample the handwriting style from the constant prior distribution\n",
+ " prior_loc = torch.zeros(batch_size, self.z_dim, **options)\n",
+ " prior_scale = torch.ones(batch_size, self.z_dim, **options)\n",
+ " zs = pyro.sample(\"z\", dist.Normal(prior_loc, prior_scale).to_event(1))\n",
+ "\n",
+ " # if the label y (which digit to write) is supervised, sample from the\n",
+ " # constant prior, otherwise, observe the value (i.e. score it against the constant prior)\n",
+ " alpha_prior = torch.ones(batch_size, self.output_size, **options) / (1.0 * self.output_size)\n",
+ " ys = pyro.sample(\"y\", dist.OneHotCategorical(alpha_prior), obs=ys)\n",
+ "\n",
+ " # finally, score the image (x) using the handwriting style (z) and\n",
+ " # the class label y (which digit to write) against the\n",
+ " # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))\n",
+ " # where `decoder` is a neural network\n",
+ " loc = self.decoder.forward([zs, ys])\n",
+ " pyro.sample(\"x\", dist.Bernoulli(loc).to_event(1), obs=xs)\n",
+ " # return the loc so we can visualize it later\n",
+ " return loc\n",
+ "\n",
+ " def guide(self, xs, ys=None):\n",
+ " \"\"\"\n",
+ " The guide corresponds to the following:\n",
+ " q(y|x) = categorical(alpha(x)) # infer label from an image\n",
+ " q(z|x,y) = normal(loc(x,y),scale(x,y)) # infer latent class from an image and the label\n",
+ " loc, scale are given by a neural network `encoder_z`\n",
+ " alpha is given by a neural network `encoder_y`\n",
+ "\n",
+ " :param xs: a batch of scaled vectors of pixels from an image\n",
+ " :param ys: (optional) a batch of the class labels i.e.\n",
+ " the digit corresponding to the image(s)\n",
+ " :return: None\n",
+ " \"\"\"\n",
+ " # inform Pyro that the variables in the batch of xs, ys are conditionally independent\n",
+ " with pyro.plate(\"data\"):\n",
+ "\n",
+ " # if the class label (the digit) is not supervised, sample\n",
+ " # (and score) the digit with the variational distribution\n",
+ " # q(y|x) = categorical(alpha(x))\n",
+ " if ys is None: \n",
+ " alpha = self.encoder_y.forward(xs)\n",
+ " ys = pyro.sample(\"y\", dist.OneHotCategorical(alpha))\n",
+ "\n",
+ " # sample (and score) the latent handwriting-style with the variational\n",
+ " # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))\n",
+ " loc, scale = self.encoder_z.forward([xs, ys])\n",
+ " pyro.sample(\"z\", dist.Normal(loc, scale).to_event(1))\n",
+ "\n",
+ " def classifier(self, xs):\n",
+ " \"\"\"\n",
+ " classify an image (or a batch of images)\n",
+ "\n",
+ " :param xs: a batch of scaled vectors of pixels from an image\n",
+ " :return: a batch of the corresponding class labels (as one-hots)\n",
+ " \"\"\"\n",
+ " # use the trained model q(y|x) = categorical(alpha(x))\n",
+ " # compute all class probabilities for the image(s)\n",
+ " alpha = self.encoder_y.forward(xs)\n",
+ "\n",
+ " # get the index (digit) that corresponds to\n",
+ " # the maximum predicted class probability\n",
+ " res, ind = torch.topk(alpha, 1)\n",
+ "\n",
+ " # convert the digit(s) to one-hot tensor(s)\n",
+ " ys = torch.zeros_like(alpha).scatter_(1, ind, 1.0)\n",
+ " return ys\n",
+ "\n",
+ " def model_classify(self, xs, ys=None):\n",
+ " \"\"\"\n",
+ " this model is used to add an auxiliary (supervised) loss as described in the\n",
+ " Kingma et al., \"Semi-Supervised Learning with Deep Generative Models\". It \n",
+ " probably isn't needed here.\n",
+ " \"\"\"\n",
+ " # register all pytorch (sub)modules with pyro\n",
+ " pyro.module(\"ss_vae\", self)\n",
+ "\n",
+ " # inform Pyro that the variables in the batch of xs, ys are conditionally independent\n",
+ " with pyro.plate(\"data\"):\n",
+ " # this here is the extra term to yield an auxiliary loss that we do gradient descent on\n",
+ " if ys is not None:\n",
+ " alpha = self.encoder_y.forward(xs)\n",
+ " with pyro.poutine.scale(scale=self.aux_loss_multiplier):\n",
+ " pyro.sample(\"y_aux\", dist.OneHotCategorical(alpha), obs=ys)\n",
+ "\n",
+ " def guide_classify(self, xs, ys=None):\n",
+ " \"\"\"\n",
+ " dummy guide function to accompany model_classify in inference\n",
+ " \"\"\"\n",
+ " pass"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qPdgbk1kMEjR",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def setup_data_loaders(train_x, train_y, test_x, test_y, batch_size=128, use_cuda=USE_CUDA):\n",
+ " train_dset = torch.utils.data.TensorDataset(\n",
+ " torch.from_numpy(train_x.astype(np.float32)).reshape(-1, 4096),\n",
+ " torch.from_numpy(train_y.astype(np.float32))\n",
+ " )\n",
+ " test_dset = torch.utils.data.TensorDataset(\n",
+ " torch.from_numpy(test_x.astype(np.float32)).reshape(-1, 4096),\n",
+ " torch.from_numpy(test_y.astype(np.float32))\n",
+ " ) \n",
+ " kwargs = {'num_workers': 1, 'pin_memory': use_cuda}\n",
+ " loader = {}\n",
+ " loader[\"sup\"] = torch.utils.data.DataLoader(\n",
+ " dataset=train_dset, batch_size=batch_size, shuffle=False, **kwargs\n",
+ " )\n",
+ " test_loader = torch.utils.data.DataLoader(\n",
+ " dataset=test_dset, batch_size=batch_size, shuffle=False, **kwargs\n",
+ " )\n",
+ " loader[\"valid\"] = test_loader\n",
+ " loader[\"test\"] = test_loader\n",
+ " return loader\n",
+ "\n",
+ "def run_supervized_inference_for_epoch(data_loaders, losses):\n",
+ " \"\"\"\n",
+ " runs the inference algorithm for an epoch\n",
+ " returns the values of all losses separately on supervised and unsupervised parts\n",
+ " \"\"\"\n",
+ " num_losses = len(losses)\n",
+ "\n",
+ " # compute number of batches for an epoch\n",
+ " batches_per_epoch = len(data_loaders[\"sup\"])\n",
+ "\n",
+ " # initialize variables to store loss values\n",
+ " epoch_losses = [0.] * num_losses\n",
+ "\n",
+ " # setup the iterators for training data loaders\n",
+ " sup_iter = iter(data_loaders[\"sup\"])\n",
+ "\n",
+ " for i in range(batches_per_epoch):\n",
+ "\n",
+ " # extract the corresponding batch\n",
+ " (xs, ys) = next(sup_iter)\n",
+ " xs = xs.cuda()\n",
+ " ys = ys.cuda()\n",
+ " # run the inference for each loss with supervised or un-supervised\n",
+ " # data as arguments\n",
+ " for loss_id in range(num_losses):\n",
+ " new_loss = losses[loss_id].step(xs, ys)\n",
+ " epoch_losses[loss_id] += new_loss\n",
+ "\n",
+ " # return the values of all losses\n",
+ " return epoch_losses\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Ad6H4X0HMEjT",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "dataset_zip = np.load(\n",
+ " '/content/gdrive/My Drive/data-science/causal-ml/projects/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',\n",
+ " encoding = 'bytes',\n",
+ " allow_pickle=True\n",
+ ")\n",
+ "imgs = dataset_zip['imgs']\n",
+ "latents_values = dataset_zip['latents_values']\n",
+ "latents_classes = dataset_zip['latents_classes']\n",
+ "metadata = dataset_zip['metadata'][()]\n",
+ "\n",
+ "latents_sizes = metadata[b'latents_sizes']\n",
+ "latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:],\n",
+ " np.array([1,])))\n",
+ "\n",
+ "def latent_to_index(latents):\n",
+ " return np.dot(latents, latents_bases).astype(int)\n",
+ "\n",
+ "\n",
+ "def sample_latent(size=1):\n",
+ " samples = np.zeros((size, latents_sizes.size))\n",
+ " for lat_i, lat_size in enumerate(latents_sizes):\n",
+ " samples[:, lat_i] = np.random.randint(lat_size, size=size)\n",
+ "\n",
+ " return samples\n",
+ "\n",
+ "# Sample latents randomly\n",
+ "latents_sampled = sample_latent(size=70000)\n",
+ "\n",
+ "# Select images\n",
+ "indices_sampled = latent_to_index(latents_sampled)\n",
+ "imgs_sampled = imgs[indices_sampled]\n",
+ "\n",
+ "data_loaders = setup_data_loaders(\n",
+ " imgs_sampled[1000:], latents_sampled[1000:],\n",
+ " imgs_sampled[:1000], latents_sampled[:1000],\n",
+ " batch_size=256,\n",
+ " use_cuda=USE_CUDA\n",
+ ")"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Xvn38eq0MEjU",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def get_accuracy(data_loader, classifier_fn, batch_size):\n",
+ " \"\"\"\n",
+ " compute the accuracy over the supervised training set or the testing set\n",
+ " \"\"\"\n",
+ " predictions, actuals = [], []\n",
+ "\n",
+ " # use the appropriate data loader\n",
+ " \n",
+ " for (xs, ys) in data_loader:\n",
+ " xs = xs.cuda()\n",
+ " ys = ys.cuda()\n",
+ " # use classification function to compute all predictions for each batch\n",
+ " predictions.append(classifier_fn(xs))\n",
+ " actuals.append(ys)\n",
+ "\n",
+ " # compute the number of accurate predictions\n",
+ " accurate_preds = 0\n",
+ " for pred, act in zip(predictions, actuals):\n",
+ " for i in range(pred.size(0)):\n",
+ " v = torch.sum(pred[i] == act[i])\n",
+ " accurate_preds += (v.item() == 10)\n",
+ "\n",
+ " # calculate the accuracy between 0 and 1\n",
+ " accuracy = (accurate_preds * 1.0) / (len(predictions) * batch_size)\n",
+ " return accuracy\n",
+ "\n",
+ "\n",
+ "def visualize(s_vae, viz, test_loader):\n",
+ " if viz:\n",
+ " plot_conditional_samples_ssvae(s_vae, viz)\n",
+ " mnist_test_tsne_ssvae(ssvae=s_vae, test_loader=test_loader)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_pz3enLBMEjW",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "sup_vae = SVAE(\n",
+ " output_size=6,\n",
+ " input_size=4096,\n",
+ " z_dim=50,\n",
+ " hidden_layers=[500],\n",
+ " use_cuda=USE_CUDA,\n",
+ " config_enum=\"parallel\",\n",
+ " aux_loss_multiplier=46\n",
+ ")"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "T4kr9m3UMEja",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "adam_params = {\"lr\": 0.00042, \"betas\": (0.9, 0.999)}\n",
+ "optimizer = Adam(adam_params)\n",
+ "# set up the loss(es) for inference. wrapping the guide in config_enumerate builds\n",
+ "# the loss as a sum\n",
+ "# by enumerating each class label for the sampled discrete categorical distribution\n",
+ "# in the model\n",
+ "guide = config_enumerate(sup_vae.guide, \"parallel\", expand=True)\n",
+ "elbo = Trace_ELBO(max_plate_nesting=1)\n",
+ "loss_basic = SVI(sup_vae.model, guide, optimizer, loss=elbo)\n",
+ "\n",
+ "# build a list of all losses considered\n",
+ "losses = [loss_basic]\n",
+ "loss_aux = SVI(sup_vae.model_classify, sup_vae.guide_classify, optimizer, loss=elbo)\n",
+ "losses.append(loss_aux)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "r0dRPWlWeG64",
+ "colab_type": "code",
+ "outputId": "270c8c1e-5d1d-4af3-9072-5b503f0f88f0",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 731
+ }
+ },
+ "source": [
+ "sup_vae"
+ ],
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "SVAE(\n",
+ " (encoder_y): MLP(\n",
+ " (sequential_mlp): Sequential(\n",
+ " (0): ConcatModule()\n",
+ " (1): DataParallel(\n",
+ " (module): Linear(in_features=4096, out_features=500, bias=True)\n",
+ " )\n",
+ " (2): Softplus(beta=1, threshold=20)\n",
+ " (3): Linear(in_features=500, out_features=6, bias=True)\n",
+ " (4): Softmax()\n",
+ " )\n",
+ " )\n",
+ " (encoder_z): MLP(\n",
+ " (sequential_mlp): Sequential(\n",
+ " (0): ConcatModule()\n",
+ " (1): DataParallel(\n",
+ " (module): Linear(in_features=4102, out_features=500, bias=True)\n",
+ " )\n",
+ " (2): Softplus(beta=1, threshold=20)\n",
+ " (3): ListOutModule(\n",
+ " (0): Sequential(\n",
+ " (0): Linear(in_features=500, out_features=50, bias=True)\n",
+ " )\n",
+ " (1): Sequential(\n",
+ " (0): Linear(in_features=500, out_features=50, bias=True)\n",
+ " (1): Exp()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (decoder): MLP(\n",
+ " (sequential_mlp): Sequential(\n",
+ " (0): ConcatModule()\n",
+ " (1): DataParallel(\n",
+ " (module): Linear(in_features=56, out_features=500, bias=True)\n",
+ " )\n",
+ " (2): Softplus(beta=1, threshold=20)\n",
+ " (3): Linear(in_features=500, out_features=4096, bias=True)\n",
+ " (4): Sigmoid()\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qfoFJZPVMEjb",
+ "colab_type": "code",
+ "outputId": "1c457d2c-6478-412c-99f6-36969a8173f5",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ }
+ },
+ "source": [
+ "sup_num = 60000\n",
+ "batch_size = 256\n",
+ "num_epochs = 100\n",
+ "\n",
+ "logger = open(\"./tmp.log\", \"w\")\n",
+ "# Number of supervised number\n",
+ "\n",
+ "# initializing local variables to maintain the best validation accuracy\n",
+ "# seen across epochs over the supervised training set\n",
+ "# and the corresponding testing set and the state of the networks\n",
+ "best_valid_acc, corresponding_test_acc = 0.0, 0.0\n",
+ "\n",
+ "# run inference for a certain number of epochs\n",
+ "for i in range(0, num_epochs):\n",
+ "\n",
+ " # get the losses for an epoch\n",
+ " epoch_losses_sup = run_supervized_inference_for_epoch(\n",
+ " data_loaders,\n",
+ " losses\n",
+ " )\n",
+ "\n",
+ " # compute average epoch losses i.e. losses per example\n",
+ " avg_epoch_losses_sup = map(lambda v: v / sup_num, epoch_losses_sup)\n",
+ "\n",
+ " # store the loss and validation/testing accuracies in the logfile\n",
+ " str_loss_sup = \" \".join(map(str, avg_epoch_losses_sup))\n",
+ "\n",
+ " str_print = \"{} epoch: avg loss {}\".format(i, \"{}\".format(str_loss_sup))\n",
+ "\n",
+ " validation_accuracy = get_accuracy(data_loaders[\"valid\"], sup_vae.classifier, batch_size)\n",
+ " str_print += \" validation accuracy {}\".format(validation_accuracy)\n",
+ "\n",
+ " # this test accuracy is only for logging, this is not used\n",
+ " # to make any decisions during training\n",
+ " test_accuracy = get_accuracy(data_loaders[\"test\"], sup_vae.classifier, batch_size)\n",
+ " str_print += \" test accuracy {}\".format(test_accuracy)\n",
+ "\n",
+ " # update the best validation accuracy and the corresponding\n",
+ " # testing accuracy and the state of the parent module (including the networks)\n",
+ " if best_valid_acc < validation_accuracy:\n",
+ " best_valid_acc = validation_accuracy\n",
+ " corresponding_test_acc = test_accuracy\n",
+ " \n",
+ " visualize(sup_vae, None, data_loaders[\"test\"])\n",
+ "\n",
+ " print_and_log(logger, str_print)\n",
+ "\n",
+ "final_test_accuracy = get_accuracy(data_loaders[\"test\"], sup_vae.classifier, batch_size)\n",
+ "print_and_log(logger, \"best validation accuracy {} corresponding testing accuracy {} \"\n",
+ " \"last testing accuracy {}\".format(best_valid_acc, corresponding_test_acc, final_test_accuracy))\n",
+ "\n",
+ "# close the logger file object if we opened it earlier\n",
+ "logger.close()"
+ ],
+ "execution_count": 36,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0 epoch: avg loss 73.4194898127238 10.337283599853516 validation accuracy 0.0 test accuracy 0.0\n",
+ "1 epoch: avg loss 73.1135249282837 9.959152315266927 validation accuracy 0.0 test accuracy 0.0\n",
+ "2 epoch: avg loss 72.90056605784098 9.599737705485026 validation accuracy 0.0 test accuracy 0.0\n",
+ "3 epoch: avg loss 72.72961992340088 9.309077065022786 validation accuracy 0.0 test accuracy 0.0\n",
+ "4 epoch: avg loss 72.53395691884359 8.995561258951822 validation accuracy 0.0 test accuracy 0.0\n",
+ "5 epoch: avg loss 72.29241495920817 8.699117222086588 validation accuracy 0.0 test accuracy 0.0\n",
+ "6 epoch: avg loss 72.0966392674764 8.405718975830078 validation accuracy 0.0 test accuracy 0.0\n",
+ "7 epoch: avg loss 71.84623018544515 8.126094775390625 validation accuracy 0.0 test accuracy 0.0\n",
+ "8 epoch: avg loss 71.72414229482015 7.864214937337239 validation accuracy 0.0 test accuracy 0.0\n",
+ "9 epoch: avg loss 71.49476098785401 7.599710852050781 validation accuracy 0.0 test accuracy 0.0\n",
+ "10 epoch: avg loss 71.34665864410401 7.383735764567057 validation accuracy 0.0 test accuracy 0.0\n",
+ "11 epoch: avg loss 71.24757972157796 7.4970786661783855 validation accuracy 0.0 test accuracy 0.0\n",
+ "12 epoch: avg loss 71.0358139175415 7.385145209757487 validation accuracy 0.0 test accuracy 0.0\n",
+ "13 epoch: avg loss 70.79814100087484 6.938219022623698 validation accuracy 0.0 test accuracy 0.0\n",
+ "14 epoch: avg loss 70.690160206604 6.606387151082357 validation accuracy 0.0 test accuracy 0.0\n",
+ "15 epoch: avg loss 70.46839635569255 6.533454721069336 validation accuracy 0.0 test accuracy 0.0\n",
+ "16 epoch: avg loss 70.39087748362223 6.275628011067709 validation accuracy 0.0 test accuracy 0.0\n",
+ "17 epoch: avg loss 70.16816714019775 6.026749960327148 validation accuracy 0.0 test accuracy 0.0\n",
+ "18 epoch: avg loss 70.02727948557536 5.86120785929362 validation accuracy 0.0 test accuracy 0.0\n",
+ "19 epoch: avg loss 69.95993371734619 5.694548964436849 validation accuracy 0.0 test accuracy 0.0\n",
+ "20 epoch: avg loss 69.76063533681234 5.675880317179362 validation accuracy 0.0 test accuracy 0.0\n",
+ "21 epoch: avg loss 69.65645001780192 5.492274950154623 validation accuracy 0.0 test accuracy 0.0\n",
+ "22 epoch: avg loss 69.47397186838786 5.404019834391276 validation accuracy 0.0 test accuracy 0.0\n",
+ "23 epoch: avg loss 69.3778068456014 5.669653151448568 validation accuracy 0.0 test accuracy 0.0\n",
+ "24 epoch: avg loss 69.24922564442953 5.432655401611328 validation accuracy 0.0 test accuracy 0.0\n",
+ "25 epoch: avg loss 69.02797335764566 5.363767221069336 validation accuracy 0.0 test accuracy 0.0\n",
+ "26 epoch: avg loss 68.97380993804931 5.430132021077474 validation accuracy 0.0 test accuracy 0.0\n",
+ "27 epoch: avg loss 68.8614875096639 5.341963358561198 validation accuracy 0.0 test accuracy 0.0\n",
+ "28 epoch: avg loss 68.72352076975504 5.49090717569987 validation accuracy 0.0 test accuracy 0.0\n",
+ "29 epoch: avg loss 68.57598680165609 5.599548092651367 validation accuracy 0.0 test accuracy 0.0\n",
+ "30 epoch: avg loss 68.42424849599202 5.55614609375 validation accuracy 0.0 test accuracy 0.0\n",
+ "31 epoch: avg loss 68.29375264638264 5.479485826619466 validation accuracy 0.0 test accuracy 0.0\n",
+ "32 epoch: avg loss 68.1904503595988 5.2717307159423825 validation accuracy 0.0 test accuracy 0.0\n",
+ "33 epoch: avg loss 68.08769537913004 5.024212093098958 validation accuracy 0.0 test accuracy 0.0\n",
+ "34 epoch: avg loss 67.929296534729 4.744880869547526 validation accuracy 0.0 test accuracy 0.0\n",
+ "35 epoch: avg loss 67.83656690419515 4.545162293497722 validation accuracy 0.0 test accuracy 0.0\n",
+ "36 epoch: avg loss 67.77494875640869 4.303533932495117 validation accuracy 0.0 test accuracy 0.0\n",
+ "37 epoch: avg loss 67.67430809885661 4.098325904337565 validation accuracy 0.0 test accuracy 0.0\n",
+ "38 epoch: avg loss 67.48732435862223 3.956320953369141 validation accuracy 0.0 test accuracy 0.0\n",
+ "39 epoch: avg loss 67.43751215159098 3.842155121358236 validation accuracy 0.0 test accuracy 0.0\n",
+ "40 epoch: avg loss 67.3197662612915 3.7160155166625977 validation accuracy 0.0 test accuracy 0.0\n",
+ "41 epoch: avg loss 67.20157636871338 3.5679363571166993 validation accuracy 0.0 test accuracy 0.0\n",
+ "42 epoch: avg loss 67.11693908843993 3.5223314137776693 validation accuracy 0.0 test accuracy 0.0\n",
+ "43 epoch: avg loss 66.99766852366129 3.3062897201538086 validation accuracy 0.0 test accuracy 0.0\n",
+ "44 epoch: avg loss 66.94425606435139 3.274656864420573 validation accuracy 0.0 test accuracy 0.0\n",
+ "45 epoch: avg loss 66.87721158192953 3.2535403299967447 validation accuracy 0.0 test accuracy 0.0\n",
+ "46 epoch: avg loss 66.76631597646077 3.224982407124837 validation accuracy 0.0 test accuracy 0.0\n",
+ "47 epoch: avg loss 66.68057815093994 3.1721942459106445 validation accuracy 0.0 test accuracy 0.0\n",
+ "48 epoch: avg loss 66.53252608388264 2.9452966857910154 validation accuracy 0.0 test accuracy 0.0\n",
+ "49 epoch: avg loss 66.4851595067342 2.8821236943562827 validation accuracy 0.0 test accuracy 0.0\n",
+ "50 epoch: avg loss 66.37938383127849 3.09201468556722 validation accuracy 0.0 test accuracy 0.0\n",
+ "51 epoch: avg loss 66.31664015452067 3.177724250793457 validation accuracy 0.0 test accuracy 0.0\n",
+ "52 epoch: avg loss 66.19034068349202 2.9874297996520998 validation accuracy 0.0 test accuracy 0.0\n",
+ "53 epoch: avg loss 66.05708732248942 3.232627372741699 validation accuracy 0.0 test accuracy 0.0\n",
+ "54 epoch: avg loss 66.01276608225504 3.173233707173665 validation accuracy 0.0 test accuracy 0.0\n",
+ "55 epoch: avg loss 65.88708648427327 3.110121321105957 validation accuracy 0.0 test accuracy 0.0\n",
+ "56 epoch: avg loss 65.81027199045818 3.161735363260905 validation accuracy 0.0 test accuracy 0.0\n",
+ "57 epoch: avg loss 65.73718862457275 3.3055569646199543 validation accuracy 0.0 test accuracy 0.0\n",
+ "58 epoch: avg loss 65.66992208811442 3.2017634826660157 validation accuracy 0.0 test accuracy 0.0\n",
+ "59 epoch: avg loss 65.60472373199462 3.028648991394043 validation accuracy 0.0 test accuracy 0.0\n",
+ "60 epoch: avg loss 65.51507672678629 3.0213762514750164 validation accuracy 0.0 test accuracy 0.0\n",
+ "61 epoch: avg loss 65.38191290842693 3.0188735694885254 validation accuracy 0.0 test accuracy 0.0\n",
+ "62 epoch: avg loss 65.30616182607015 2.9055699577331544 validation accuracy 0.0 test accuracy 0.0\n",
+ "63 epoch: avg loss 65.21349840647379 2.852216015370687 validation accuracy 0.0 test accuracy 0.0\n",
+ "64 epoch: avg loss 65.18253641916911 2.813871074167887 validation accuracy 0.0 test accuracy 0.0\n",
+ "65 epoch: avg loss 65.12543276519776 2.7915143440246584 validation accuracy 0.0 test accuracy 0.0\n",
+ "66 epoch: avg loss 65.08151169586182 2.802122728474935 validation accuracy 0.0 test accuracy 0.0\n",
+ "67 epoch: avg loss 64.97802228342692 2.7456317456563313 validation accuracy 0.0 test accuracy 0.0\n",
+ "68 epoch: avg loss 64.95747393544515 2.843347436014811 validation accuracy 0.0 test accuracy 0.0\n",
+ "69 epoch: avg loss 64.95950721995035 2.7493250696818032 validation accuracy 0.0 test accuracy 0.0\n",
+ "70 epoch: avg loss 64.85206590321859 2.549251330820719 validation accuracy 0.0 test accuracy 0.0\n",
+ "71 epoch: avg loss 64.77169641265868 2.617051190185547 validation accuracy 0.0 test accuracy 0.0\n",
+ "72 epoch: avg loss 64.71160574696859 2.5421745646158853 validation accuracy 0.0 test accuracy 0.0\n",
+ "73 epoch: avg loss 64.58634377593994 2.465168037160238 validation accuracy 0.0 test accuracy 0.0\n",
+ "74 epoch: avg loss 64.57885850575765 2.361467325337728 validation accuracy 0.0 test accuracy 0.0\n",
+ "75 epoch: avg loss 64.42846132151286 2.4082817609151204 validation accuracy 0.0 test accuracy 0.0\n",
+ "76 epoch: avg loss 64.38538060048421 2.298162012990316 validation accuracy 0.0 test accuracy 0.0\n",
+ "77 epoch: avg loss 64.3141145767212 2.2881553919474285 validation accuracy 0.0 test accuracy 0.0\n",
+ "78 epoch: avg loss 64.24270374501546 2.2224127756754557 validation accuracy 0.0 test accuracy 0.0\n",
+ "79 epoch: avg loss 64.09710526682535 2.1379135854085285 validation accuracy 0.0 test accuracy 0.0\n",
+ "80 epoch: avg loss 64.07882555491129 2.261145238494873 validation accuracy 0.0 test accuracy 0.0\n",
+ "81 epoch: avg loss 64.05341328277588 2.1424091829935707 validation accuracy 0.0 test accuracy 0.0\n",
+ "82 epoch: avg loss 63.90139794260661 2.0546640805562335 validation accuracy 0.0 test accuracy 0.0\n",
+ "83 epoch: avg loss 63.886747120666506 2.0148360023498535 validation accuracy 0.0 test accuracy 0.0\n",
+ "84 epoch: avg loss 63.863132594299316 2.0213047768910726 validation accuracy 0.0 test accuracy 0.0\n",
+ "85 epoch: avg loss 63.72724614003499 1.9374477088928224 validation accuracy 0.0 test accuracy 0.0\n",
+ "86 epoch: avg loss 63.707502190653486 1.969299099858602 validation accuracy 0.0 test accuracy 0.0\n",
+ "87 epoch: avg loss 63.60284013824463 1.9574982669830323 validation accuracy 0.0 test accuracy 0.0\n",
+ "88 epoch: avg loss 63.60544111480713 2.1192375803629555 validation accuracy 0.0 test accuracy 0.0\n",
+ "89 epoch: avg loss 63.47691573232015 2.071177163950602 validation accuracy 0.0 test accuracy 0.0\n",
+ "90 epoch: avg loss 63.45331827952067 1.870301713816325 validation accuracy 0.0 test accuracy 0.0\n",
+ "91 epoch: avg loss 63.43649330393473 1.7918146341959635 validation accuracy 0.0 test accuracy 0.0\n",
+ "92 epoch: avg loss 63.37312573394775 1.8475859204610188 validation accuracy 0.0 test accuracy 0.0\n",
+ "93 epoch: avg loss 63.23665246734619 1.9314768585205078 validation accuracy 0.0 test accuracy 0.0\n",
+ "94 epoch: avg loss 63.206928472391766 1.8032865263621012 validation accuracy 0.0 test accuracy 0.0\n",
+ "95 epoch: avg loss 63.235852890523276 1.798199406305949 validation accuracy 0.0 test accuracy 0.0\n",
+ "96 epoch: avg loss 63.176236850484216 1.788236808013916 validation accuracy 0.0 test accuracy 0.0\n",
+ "97 epoch: avg loss 63.0884791112264 1.8939534503936768 validation accuracy 0.0 test accuracy 0.0\n",
+ "98 epoch: avg loss 63.16810416819254 1.7823999228159586 validation accuracy 0.0 test accuracy 0.0\n",
+ "99 epoch: avg loss 63.05663078358968 1.8215582843780518 validation accuracy 0.0 test accuracy 0.0\n",
+ "best validation accuracy 0.0 corresponding testing accuracy 0.0 last testing accuracy 0.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ei0yHxTNMEjf",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def reconstruct_image_w_label(xs, ys, vae):\n",
+ " # backward\n",
+ " xs = xs.cuda()\n",
+ " ys = ys.cuda()\n",
+ " sim_z_loc, sim_z_scale = vae.encoder_z([xs, ys])\n",
+ " zs = dist.Normal(sim_z_loc, sim_z_scale).to_event(1).sample()\n",
+ " # forward\n",
+ " zs = zs.cuda()\n",
+ " ys = ys.cuda()\n",
+ " loc = vae.decoder([zs, ys])\n",
+ " return dist.Bernoulli(loc).to_event(1).sample()\n",
+ "\n",
+ "def reconstruct_image(xs, vae):\n",
+ " # backward\n",
+ " xs = xs.cuda()\n",
+ " sim_ys = vae.encoder_y(xs)\n",
+ " sim_z_loc, sim_z_scale = vae.encoder_z([xs, sim_ys])\n",
+ " zs = dist.Normal(sim_z_loc, sim_z_scale).to_event(1).sample()\n",
+ " # forward\n",
+ " zs = zs.cuda()\n",
+ " loc = vae.decoder([zs, ys.cuda()])\n",
+ " return dist.Bernoulli(loc).to_event(1).sample()\n",
+ "\n",
+ "def convert_back(x):\n",
+ " return x.cpu().reshape(-1, 64, 64).numpy().astype(np.uint8)\n",
+ "\n",
+ "def show_images_grid(imgs_, num_images):\n",
+ " ncols = int(np.ceil(num_images**0.5))\n",
+ " nrows = int(np.ceil(num_images / ncols))\n",
+ " _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))\n",
+ " axes = axes.flatten()\n",
+ " for ax_i, ax in enumerate(axes):\n",
+ " if ax_i < num_images:\n",
+ " ax.imshow(imgs_[ax_i].reshape(64, 64), cmap='Greys_r', interpolation='nearest')\n",
+ " ax.set_xticks([])\n",
+ " ax.set_yticks([])\n",
+ " else:\n",
+ " ax.axis('off')\n",
+ "\n",
+ "def viz_images(imgs, n):\n",
+ " imgs_ = []\n",
+ " for i, x in enumerate(imgs):\n",
+ " if i > n:\n",
+ " break\n",
+ " imgs_.append(convert_back(x))\n",
+ " show_images_grid(np.array(imgs_), n)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zm-CLHjBMEjg",
+ "colab_type": "text"
+ },
+ "source": [
+ "Four original images "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "g7trYLCqMEjh",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "sup_iter = iter(data_loaders[\"sup\"])\n",
+ "xs, ys = next(sup_iter)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "vVQmHBd_MEji",
+ "colab_type": "code",
+ "outputId": "d480f485-19bc-4c12-c317-735bdafa8eb1",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 392
+ }
+ },
+ "source": [
+ "print(\"With y and x:\")\n",
+ "xs_sim1 = reconstruct_image_w_label(xs, ys, sup_vae)\n",
+ "viz_images(xs_sim1, 4)"
+ ],
+ "execution_count": 38,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "With y and x:\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWsAAAFmCAYAAACr2LumAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAB2lJREFUeJzt3c1OG0kARtGqEWve/zlZR9QsUKSI\nGJK423bf9jnLUSAWLu58bv9krrUGAMf236NvAAB/JtYAAWINECDWAAFiDRAg1gABYg0QINYAAWIN\nECDWAAFiDRAg1gABYg0QINYAAS9bvnjO+WN8BP9tn5vDk3sdY7yvtTadyz042+xs89meWz7Pes75\nPsaYV38DuGCt9fAz5WxzC1vO9tbLIFYHZ+VscyiuWQMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVA\ngFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCA\nWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBY\nAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgD\nBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAME\niDVAgFgDBIg1QIBYP5G11lhrPfpmAFcQa4CAl0ffAG7v85r+bl3POW99c4ArWNYAAWINEOAyyEld\n+0Tir1/nkggch2UNECDWfMlL/eA4xBogQKwBAsQaIECsAQK8dO8kbvlE4KXv7WV9cF+WNUCAZQ0H\n869vTPr55z3aOTfLGiDAsg575BtWPv/dVt12l+7Pf7mPfZriuVnWAAFiDRDgMkjIkT+nw6f1Xe8e\n96v7p8+yBgiwrA/uyGv6K558/Dc/fz7F+5r7sawBAizrg7KyuBVvommyrAECxBogwGUQbs7LxmA7\nyxogwLI+qDO8nMuKhv1Y1gABljU8AY9y+ixrgACxBghwGeTgfn34Wn6ykfty2eN8LGuAAMua3Vl1\nsD/LGiDAsg75vFiPdA3bmn4898G5WdYAAWINEOAyCJt46L2fSz/Lv7nU5T54DpY1QIBlHfbIN8xY\nc/d1hk9hZBvLGiDAsj6Jeywva/r+Pv/M3QfPy7IGCBBrgACXQU7mFpdDPPSGx7OsAQIsa75kUcNx\nWNYAAZb1SX23ir+7nm1NwzFZ1gABYg0Q4DLIE7r0mSIuf8CxWdYAAZb1k7OoocGyBggQa4AAsQYI\nEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgIC55V/BnnO+jzF8EhC7Wms9/Ew529zClrO9\n9VP33sfHOn/b+H1gjDFex8eZOgJnmz1tPtubljUA9+GaNUCAWAMEiDVAgFgDBIg1QIBYAwSINUCA\nWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBY\nAwSINUCAWAMEiDVAwMuWL55z/hgfwX/b5+bw5F7HGO9rrU3ncg/ONjvbfLbnWuvqv33O+T7GmFd/\nA7hgrfXwM+VscwtbzvbWyyBWB2flbHMorlkDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSI\nNUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBLw8\n+gac3Vrrt/8253zALQHKLGuAAMt6Z5eW9J/+jKUN/IllDRAg1gABYg0QINYAAZ5g3MHfPKkIsIVl\nDRBgWW9gUQP3YlkDBIg1QIDLIMDD3eKS4tneGWxZAwRY1lfYewX8+v3OtgbgK7d+gv5sn8FjWQME\nWNbAQ/y6dO/xMtj6I1jLGiBArAECxPoKc86bPYxaa3lnJPAbsQYIEOsNbrmw4Zn4XfozsQYI8NK9\nHfzNIvjb69DWBc/s0vnf6zmc+u+WZQ0QINYAAS6D3Ml3D8G8VA++9vl359I7Eb/7Hapf/vjJsgYI\nsKwP4Cz/54d7uPT78gy/Q5Y1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgD\nBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAME\niDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSI\nNUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1\nQIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVA\ngFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QMBca13/xXO+\njzHmfjcHxlhrPfxMOdvcwpaz/bLx734fH+v8beP3gTHGeB0fZ+oInG32tPlsb1rWANyHa9YAAWIN\nECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINEPA//boH4H1pBIAAAAAA\nSUVORK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "fEvI9XGTMEjk",
+ "colab_type": "code",
+ "outputId": "b7db84ee-5957-4f0c-c35f-e9e7f42c2ab6",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 392
+ }
+ },
+ "source": [
+ "print(\"With just x:\")\n",
+ "# backward\n",
+ "xs_sim2 = reconstruct_image(xs, sup_vae)\n",
+ "viz_images(xs_sim2, 4)"
+ ],
+ "execution_count": 39,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "With just x:\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWsAAAFmCAYAAACr2LumAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAACJNJREFUeJzt3c1u29YCRlGy8Djv/5wZFz53YBgw\nXF1aFimRm1xr1iZ2hZTe+XT0N48xJgCO7Z+9bwAAPxNrgACxBggQa4AAsQYIEGuAALEGCBBrgACx\nBggQa4AAsQYIEGuAALEGCBBrgIC3NV88z/O/00fw/25zc7i4P9M0vY8xVl2XW3Bts7HV1/a85v2s\n53l+n6ZpfvgbwA1jjN2vKdc2z7Dm2l57DGJ1cFaubQ7FmTVAgFgDBIg1QIBYAwSINUCAWAMEiDVA\ngFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCA\nWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBJwq1mOMaYyx980A2NypYg1wVm9734AtzfO8\n900AeArLGiBArAECTnUMAhVLD4Q7zuMWyxogwLKGF7rnqaVff4+VzSfLGiBArOGF5nm2lnmIWAME\niDVAgFgDBIg1QICn7sEL3fPUPQ9AcotlDRBgWcMLWc08yrIGCBBrgACxBggQa4AAsQYIEGuAALEG\nCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsT6JMcZdHxkFNIk1QICP9Qq6taA/Py7q\n89d8fBSci2UNEGBZn8T3tf31n61s6LOsAQLEOmieZ2sZLkasAQLEGiDAA4xBv33xi6fzQZ9lDRAg\n1hfiJenQJdYAAc6sL8gLZqDHsgYIEGuAAMcgF+ToA3osa4AAy/qCPMAIPZY1QIBlvbHvLzqxXDma\n37ww6tHrd+nTjHiMZQ0QINYAAY5BNrB0t9LdQfaw1XvAbHn9evfHdSxrgADLegdLq2dpdXjHPL7a\n+3pYWsqf/87TRLdjWQMEWNYbe2Q9/PbMe0vOEY9t7/V8y/drZmk9u662Y1kDBIg1QIBjkA1s+Sqv\ne35tS+6mHs8Rjj5uXRf/78jMNfQaljVAgGW9o6VFcuvBm2d4xftE8Dtf/5xfvbLvuSbZh2UNEGBZ\nH9yeK+s7L3B4vWfew/L/sMWyBggQa4AAxyAhr3rQ8Te8ArLH/6smyxogwLLmIUda91ew9D4c7t1c\ng2UNEGBZs4r3LX6tpfeO5twsa4AAsQYIcAwSdKSn8B3hNsAVWNYAAZY1d7v1PiVHWvlwZpY1QIBl\nzY+WnhrmBRnwGpY1QIBYAwQ4Bglb+lDTLb//0vd0/AGvYVkDBFjWJ3PPg4G3fv/SA4XWM+zPsgYI\nsKwPbmkN/9bS161dz951D57LsgYIEGuAAMcgK7zi1XtHP1LwCkZ4DcsaIMCyfsBea/KIK/ZItwXO\nzLIGCLCsH7DXmrRi4bosa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQ\na4AAsQYIEGuAALEGCBBrgID583P9HvrieX6fpsnHl7CpMcbu15Rrm2dYc22v/Viv9+ljnf9d+X1g\nmqbpz/RxTR2Ba5strb62Vy1rAF7DmTVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSI\nNUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1\nQMDbmi+e5/nf6SP4f7e5OVzcn2ma3scYq67LLbi22djqa3seYzz8X5/n+X2apvnhbwA3jDF2v6Zc\n2zzDmmt77TGI1cFZubY5FGfWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWIN\nECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1jsaY0xrPl0euA6xBggQa4CAt71vwBU5+gB+\ny7IGCLCsX8SaZo17rp95nl9wS9iLZQ0QYFk/2T2L6OvvsY641+e14l7b4z7/7Ao/d5Y1QIBYAwQ4\nBoGo0l34I6gfF1nWAAGW9ZM8+re4tcQtt66Hz2vFNXMNljVAgGW9sbXnYtYR93Kt/Nfae7TTdNw/\nV8saIECsAQIcgwCH8+pjiaMefXxlWQMEWNYb2PLJ9t+/V+FvfNjKrZ+l+otZtmJZAwRY1hv4un6t\nAOAZLGuAALEGCHAMsrGt3xD+1vfxoCNso/SzZFkDBFjWT7K0sH+zvpfeba20CmDJMx+YP8vPiWUN\nEGBZP9nS3+pLv/b9vYrv/Tq4sjP/bFjWAAFiDRDgGOSgznx3DrZ2hZ8XyxogwLIGUq6wom+xrAEC\nLGtgd/e8iOzqLGuAALEGCHAMAuzOx9n9zLIGCLCsgd1Z0j+zrAECxBogQKwBAsQaIECsAQLEGiBA\nrAECxBogwItiDs67kAHTZFkDJIh10Bjj5uIGzkusAQLEGiDAA4wHd+vBREcgcD2WNUCAWAMEiDVA\ngDPrIC+K4eyWHpe56vVvWQMEiDVAgGMQ4HC+HnV4quoHyxogwLIGDu2qDyh+Z1kDBFjWwGWV3i/e\nsgYIEGuAAMcgwKnd89S/ox59fGVZAwSINXBq8zwnlvNPxBogwJk1cDnFpW1ZAwSINUCAYxDgEopH\nH19Z1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWIN\nECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0Q\nINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg\n1gAB8xjj8S+e5/dpmubtbg5M0xhj92vKtc0zrLm231b+t9+nj3X+d+X3gWmapj/TxzV1BK5ttrT6\n2l61rAF4DWfWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRDw\nP3rLqigPu5+WAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "H_OtIYKxdNp-",
+ "colab_type": "code",
+ "outputId": "edf1ceff-99c8-421c-86d9-b4a3766e64b5",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ }
+ },
+ "source": [
+ "xs_sim2.shape"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "torch.Size([256, 4096])"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 44
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "tCZoEdLh0rET",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file