{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
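    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Deep Variational CCA and Deep Canonically Correlated Autoencoders\n\nThis example demonstrates multiview models that can reconstruct their inputs: Deep Variational CCA (DVCCA), DVCCA with private encoders, and Deep Canonically Correlated Autoencoders (DCCAE). Each model is trained on a small subset of Noisy MNIST, and its reconstructions of both views of a training sample are plotted next to the originals.\n"
      ]
    },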
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import Subset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from cca_zoo.data import Noisy_MNIST_Dataset\nfrom cca_zoo.deepmodels import (\n    CCALightning,\n    get_dataloaders,\n    architectures,\n    DCCAE,\n    DVCCA,\n)\n\n# Fix the global random seeds (via Lightning) so weight initialisation and\n# training are reproducible under Restart & Run All.\nSEED = 42\npl.seed_everything(SEED)\n\n\ndef plot_reconstruction(model, dataset, idx):\n    \"\"\"Show both views of one sample next to the model's reconstructions.\n\n    Parameters\n    ----------\n    model\n        Trained cca_zoo model; its ``recon(x, y)`` method is called on the sample.\n    dataset\n        Indexable dataset whose items unpack as ``((x, y), _)``; each view is a\n        flattened 28x28 image tensor.\n    idx\n        Index of the sample to plot.\n\n    Returns\n    -------\n    matplotlib.figure.Figure\n        Four panels: the two original views, then their reconstructions.\n    \"\"\"\n    (x, y), _ = dataset[idx]\n    recon_x, recon_y = model.recon(x, y)\n    # ``recon`` may return a list of reconstructions per view; plot the first.\n    if isinstance(recon_x, list):\n        recon_x = recon_x[0]\n        recon_y = recon_y[0]\n    panels = [\n        (\"Original View 1\", x),\n        (\"Original View 2\", y),\n        (\"Reconstruction View 1\", recon_x),\n        (\"Reconstruction View 2\", recon_y),\n    ]\n    fig, axes = plt.subplots(ncols=len(panels))\n    for ax, (title, image) in zip(axes, panels):\n        ax.set_title(title)\n        ax.imshow(image.detach().numpy().reshape((28, 28)))\n    return fig\n\n\n# Disjoint train / validation subsets of the MNIST training split:\n# indices [0, n_train) for training, [n_train, n_train + n_val) for validation.\nn_train = 500\nn_val = 100\nmnist_train = Noisy_MNIST_Dataset(mnist_type=\"MNIST\", train=True, flatten=True)\ntrain_dataset = Subset(mnist_train, np.arange(n_train))\nval_dataset = Subset(mnist_train, np.arange(n_train, n_train + n_val))\ntrain_loader, val_loader = get_dataloaders(train_dataset, val_dataset)\n\n# The number of latent dimensions across models\nlatent_dims = 2\n# number of epochs for deep models\nepochs = 50\n\n# Encoders (variational=True) and decoders for the DVCCA model\n# (feature_size 784 = 28 * 28 flattened pixels per view).\nencoder_1 = architectures.Encoder(\n    latent_dims=latent_dims, feature_size=784, variational=True\n)\nencoder_2 = architectures.Encoder(\n    latent_dims=latent_dims, feature_size=784, variational=True\n)\ndecoder_1 = architectures.Decoder(\n    latent_dims=latent_dims, feature_size=784, norm_output=True\n)\ndecoder_2 = architectures.Decoder(\n    latent_dims=latent_dims, feature_size=784, norm_output=True\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deep VCCA (DVCCA)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dvcca_model = DVCCA(\n    latent_dims=latent_dims,\n    encoders=[encoder_1, encoder_2],\n    decoders=[decoder_1, decoder_2],\n)\n# CCALightning wraps the model so pl.Trainer can fit it; the wrapped\n# model is reached again through ``.model`` for plotting.\ndvcca = CCALightning(dvcca_model)\ntrainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)\ntrainer.fit(dvcca, train_loader, val_loader)\nplot_reconstruction(dvcca.model, train_dataset, 0)\nplt.suptitle(\"DVCCA\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deep VCCA (private)\n\nThe private variant adds a private encoder for each view, and the decoders' latent input dimensionality is doubled (`latent_dims=2 * latent_dims`).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Build fresh shared encoders: encoder_1 / encoder_2 were already trained by\n# the DVCCA cell above, and reusing them would warm-start this model and make\n# its result depend on whether (and how often) that cell was run.\nshared_encoder_1 = architectures.Encoder(\n    latent_dims=latent_dims, feature_size=784, variational=True\n)\nshared_encoder_2 = architectures.Encoder(\n    latent_dims=latent_dims, feature_size=784, variational=True\n)\nprivate_encoder_1 = architectures.Encoder(\n    latent_dims=latent_dims, feature_size=784, variational=True\n)\nprivate_encoder_2 = architectures.Encoder(\n    latent_dims=latent_dims, feature_size=784, variational=True\n)\n# The decoders' latent input is doubled relative to latent_dims.\nprivate_decoder_1 = architectures.Decoder(latent_dims=2 * latent_dims, feature_size=784)\nprivate_decoder_2 = architectures.Decoder(latent_dims=2 * latent_dims, feature_size=784)\ndvcca_private_model = DVCCA(\n    latent_dims=latent_dims,\n    encoders=[shared_encoder_1, shared_encoder_2],\n    decoders=[private_decoder_1, private_decoder_2],\n    private_encoders=[private_encoder_1, private_encoder_2],\n)\ndvcca_private = CCALightning(dvcca_private_model)\ntrainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)\ntrainer.fit(dvcca_private, train_loader, val_loader)\nplot_reconstruction(dvcca_private.model, train_dataset, 0)\nplt.suptitle(\"DVCCA Private\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deep Canonically Correlated Autoencoders (DCCAE)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# DCCAE gets its own freshly initialised encoders (built without\n# variational=True) and decoders. decoder_1 / decoder_2 were already trained\n# by the DVCCA cell, so reusing them would warm-start DCCAE and make its\n# result depend on that earlier run.\ndccae_encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\ndccae_encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\ndccae_decoder_1 = architectures.Decoder(\n    latent_dims=latent_dims, feature_size=784, norm_output=True\n)\ndccae_decoder_2 = architectures.Decoder(\n    latent_dims=latent_dims, feature_size=784, norm_output=True\n)\ndccae_model = DCCAE(\n    latent_dims=latent_dims,\n    encoders=[dccae_encoder_1, dccae_encoder_2],\n    decoders=[dccae_decoder_1, dccae_decoder_2],\n)\ndccae = CCALightning(dccae_model)\ntrainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)\ntrainer.fit(dccae, train_loader, val_loader)\nplot_reconstruction(dccae.model, train_dataset, 0)\nplt.suptitle(\"DCCAE\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}