{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Sparse CCA Methods\n\nThis notebook shows how regularised methods can be used to extract sparse solutions to the canonical correlation analysis (CCA) problem.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nfrom cca_zoo.data import generate_covariance_data\nfrom cca_zoo.model_selection import GridSearchCV\nfrom cca_zoo.models import PMD, SCCA, ElasticCCA, CCA, PLS, SCCA_ADMM, SpanCCA"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Seed NumPy's global random state; intended to make the simulated data repeatable.\n",
        "np.random.seed(42)\n",
        "\n",
        "# n is the first argument of the generator (presumably the number of samples);\n",
        "# p and q are the feature counts of the two views.\n",
        "n, p, q = 200, 100, 100\n",
        "view_1_sparsity = view_2_sparsity = 0.1\n",
        "latent_dims = 1\n",
        "\n",
        "(X, Y), (tx, ty) = generate_covariance_data(\n",
        "    n,\n",
        "    view_features=[p, q],\n",
        "    latent_dims=latent_dims,\n",
        "    view_sparsity=[view_1_sparsity, view_2_sparsity],\n",
        "    correlation=[0.9],\n",
        ")\n",
        "\n",
        "# Rescale the true weights in place by 1/sqrt(n).\n",
        "weight_scale = np.sqrt(n)\n",
        "tx /= weight_scale\n",
        "ty /= weight_scale"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def plot_true_weights_coloured(ax, weights, true_weights, title=\"\"):\n",
        "    \"\"\"Scatter ``weights`` against their index, coloured by the true support.\n",
        "\n",
        "    Entries whose true weight is non-zero are drawn in blue; entries whose\n",
        "    true weight equals zero are drawn in red.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ax : object providing ``scatter`` and ``set_title`` (e.g. a matplotlib Axes)\n",
        "    weights : array whose values are plotted on the y-axis\n",
        "    true_weights : array compared element-wise with 0 to build the colour mask\n",
        "    title : str, title set on ``ax``\n",
        "    \"\"\"\n",
        "    is_zero = np.squeeze(true_weights == 0)\n",
        "    positions = np.arange(len(true_weights))\n",
        "    # Blue points: the true weight is non-zero.\n",
        "    ax.scatter(positions[~is_zero], weights[~is_zero], c=\"b\")\n",
        "    # Red points: the true weight is zero.\n",
        "    ax.scatter(positions[is_zero], weights[is_zero], c=\"r\")\n",
        "    ax.set_title(title)\n",
        "\n",
        "\n",
        "def plot_model_weights(wx, wy, tx, ty):\n",
        "    \"\"\"Show true weights (top row) and model weights (bottom row) on a 2x2 grid.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    wx, wy : model weights for the x and y views\n",
        "    tx, ty : true weights for the x and y views; they also set the colouring\n",
        "    \"\"\"\n",
        "    fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)\n",
        "    # (axes, plotted weights, reference weights used for colouring, title)\n",
        "    panels = [\n",
        "        (axs[0, 0], tx, tx, \"true x weights\"),\n",
        "        (axs[0, 1], ty, ty, \"true y weights\"),\n",
        "        (axs[1, 0], wx, tx, \"model x weights\"),\n",
        "        (axs[1, 1], wy, ty, \"model y weights\"),\n",
        "    ]\n",
        "    for panel_ax, plotted, reference, panel_title in panels:\n",
        "        plot_true_weights_coloured(panel_ax, plotted, reference, title=panel_title)\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit CCA with its default settings and compare the learned weights with the true ones.\n",
        "cca = CCA().fit([X, Y])\n",
        "plot_model_weights(wx=cca.weights[0], wy=cca.weights[1], tx=tx, ty=ty)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit PLS with its default settings and compare the learned weights with the true ones.\n",
        "pls = PLS().fit([X, Y])\n",
        "plot_model_weights(wx=pls.weights[0], wy=pls.weights[1], tx=tx, ty=ty)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit PMD with c=[0.5, 0.5] (presumably one value per view) and plot its weights.\n",
        "pmd = PMD(c=[0.5, 0.5]).fit([X, Y])\n",
        "plot_model_weights(wx=pmd.weights[0], wy=pmd.weights[1], tx=tx, ty=ty)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plot the objective values recorded in pmd.track during the PMD fit.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(np.array(pmd.track[0][\"objective\"]).T)\n",
        "ax.set_title(\"Objective Convergence\")\n",
        "ax.set_ylabel(\"Objective\")\n",
        "ax.set_xlabel(\"#iterations\")\n",
        "# End with plt.show() so the cell renders the figure rather than a Text(...) repr.\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "c1 = [0.1, 0.3, 0.7, 0.9]\nc2 = [0.1, 0.3, 0.7, 0.9]\nparam_grid = {\"c\": [c1, c2]}\npmd = GridSearchCV(PMD(), param_grid=param_grid, cv=3, verbose=True).fit([X, Y])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pd.DataFrame(pmd.cv_results_)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit SCCA with c=[1e-3, 1e-3] and compare the learned weights with the true ones.\n",
        "scca = SCCA(c=[1e-3, 1e-3]).fit([X, Y])\n",
        "plot_model_weights(wx=scca.weights[0], wy=scca.weights[1], tx=tx, ty=ty)\n",
        "\n",
        "# Convergence: objective values recorded in scca.track during the fit.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(np.array(scca.track[0][\"objective\"]).T)\n",
        "ax.set_title(\"Objective Convergence\")\n",
        "ax.set_ylabel(\"Objective\")\n",
        "ax.set_xlabel(\"#iterations\")\n",
        "# End with plt.show() so the cell renders the figure rather than a Text(...) repr.\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit SCCA with the `positive` option enabled for both entries and plot its weights.\n",
        "scca_pos = SCCA(c=[1e-3, 1e-3], positive=[True, True]).fit([X, Y])\n",
        "plot_model_weights(wx=scca_pos.weights[0], wy=scca_pos.weights[1], tx=tx, ty=ty)\n",
        "\n",
        "# Convergence: objective values recorded in scca_pos.track during the fit.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(np.array(scca_pos.track[0][\"objective\"]).T)\n",
        "ax.set_title(\"Objective Convergence\")\n",
        "ax.set_ylabel(\"Objective\")\n",
        "ax.set_xlabel(\"#iterations\")\n",
        "# End with plt.show() so the cell renders the figure rather than a Text(...) repr.\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit ElasticCCA with a large c and a very small l1_ratio, then plot its weights.\n",
        "elasticcca = ElasticCCA(\n",
        "    c=[10000, 10000],\n",
        "    l1_ratio=[0.000001, 0.000001],\n",
        ").fit([X, Y])\n",
        "plot_model_weights(wx=elasticcca.weights[0], wy=elasticcca.weights[1], tx=tx, ty=ty)\n",
        "\n",
        "# Convergence: objective values recorded in elasticcca.track during the fit.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(np.array(elasticcca.track[0][\"objective\"]).T)\n",
        "ax.set_title(\"Objective Convergence\")\n",
        "ax.set_ylabel(\"Objective\")\n",
        "ax.set_xlabel(\"#iterations\")\n",
        "# End with plt.show() so the cell renders the figure rather than a Text(...) repr.\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit SCCA_ADMM with c=[1e-3, 1e-3] and compare the learned weights with the true ones.\n",
        "scca_admm = SCCA_ADMM(c=[1e-3, 1e-3]).fit([X, Y])\n",
        "plot_model_weights(wx=scca_admm.weights[0], wy=scca_admm.weights[1], tx=tx, ty=ty)\n",
        "\n",
        "# Convergence: objective values recorded in scca_admm.track during the fit.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(np.array(scca_admm.track[0][\"objective\"]).T)\n",
        "ax.set_title(\"Objective Convergence\")\n",
        "ax.set_ylabel(\"Objective\")\n",
        "ax.set_xlabel(\"#iterations\")\n",
        "# End with plt.show() so the cell renders the figure rather than a Text(...) repr.\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fit SpanCCA (c=[10, 10], max_iter=2000, rank=20) and plot its weights.\n",
        "spancca = SpanCCA(c=[10, 10], max_iter=2000, rank=20).fit([X, Y])\n",
        "plot_model_weights(wx=spancca.weights[0], wy=spancca.weights[1], tx=tx, ty=ty)\n",
        "\n",
        "# Convergence: objective values recorded in spancca.track during the fit.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(np.array(spancca.track[0][\"objective\"]).T)\n",
        "ax.set_title(\"Objective Convergence\")\n",
        "ax.set_ylabel(\"Objective\")\n",
        "ax.set_xlabel(\"#iterations\")\n",
        "# End with plt.show() so the cell renders the figure rather than a Text(...) repr.\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}