Sparse CCA Methods

This script shows how regularised methods can be used to extract sparse solutions to the CCA problem.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from cca_zoo.data import generate_covariance_data
from cca_zoo.model_selection import GridSearchCV
from cca_zoo.models import PMD, SCCA, ElasticCCA, CCA, PLS, SCCA_ADMM, SpanCCA
# Seed NumPy's global random number generator.
# NOTE(review): presumably this makes the generated data reproducible --
# confirm that generate_covariance_data draws from the global NumPy RNG.
np.random.seed(42)
# n is passed as the first positional argument of generate_covariance_data
# (presumably the number of samples -- confirm against the cca_zoo docs).
n = 200
# p and q are the feature counts requested for the two views (view_features).
p = 100
q = 100
# Per-view sparsity settings forwarded as view_sparsity.
view_1_sparsity = 0.1
view_2_sparsity = 0.1
latent_dims = 1

# Generate the two data views (X, Y) together with tx and ty, which are
# plotted further down under the titles "true x weights" / "true y weights".
(X, Y), (tx, ty) = generate_covariance_data(
    n,
    view_features=[p, q],
    latent_dims=latent_dims,
    view_sparsity=[view_1_sparsity, view_2_sparsity],
    correlation=[0.9],
)
# Rescale the true weights in place by 1 / sqrt(n).
# NOTE(review): presumably this puts them on a scale comparable to the fitted
# weights, since the plots below share their y-axis -- confirm.
tx /= np.sqrt(n)
ty /= np.sqrt(n)
def plot_true_weights_coloured(ax, weights, true_weights, title=""):
    """Scatter ``weights`` against their index, coloured by the true support.

    Entries whose true weight is non-zero are drawn in blue ("b"); entries
    whose true weight is exactly zero are drawn in red ("r").

    Args:
        ax: Object providing ``scatter`` and ``set_title`` (e.g. a matplotlib
            Axes) that receives the plot.
        weights: Array of weights to display, indexed along its first axis.
        true_weights: Array of reference weights; its length fixes the x
            positions and its zero entries define the colouring.
        title: Title set on ``ax``.
    """
    positions = np.arange(len(true_weights))
    # Squeeze so the boolean selector is 1-D even if true_weights is a column.
    is_zero = np.squeeze(true_weights == 0)
    # Non-zero (blue) entries are drawn first, then the zero (red) entries.
    for selector, colour in ((~is_zero, "b"), (is_zero, "r")):
        ax.scatter(positions[selector], weights[selector], c=colour)
    ax.set_title(title)


def plot_model_weights(wx, wy, tx, ty):
    """Show true and model weights for both views in a 2x2 grid of scatters.

    The top row plots the true weights (``tx``, ``ty``); the bottom row plots
    the model weights (``wx``, ``wy``). Every panel is coloured according to
    the zero pattern of the corresponding true weights, and all panels share
    their x and y axes.

    Args:
        wx: Model weights for the x view.
        wy: Model weights for the y view.
        tx: True weights for the x view.
        ty: True weights for the y view.
    """
    _, grid = plt.subplots(2, 2, sharex=True, sharey=True)
    # (axes, weights to draw, true weights used for colouring, panel title)
    panels = (
        (grid[0, 0], tx, tx, "true x weights"),
        (grid[0, 1], ty, ty, "true y weights"),
        (grid[1, 0], wx, tx, "model x weights"),
        (grid[1, 1], wy, ty, "model y weights"),
    )
    for axis, shown, reference, heading in panels:
        plot_true_weights_coloured(axis, shown, reference, title=heading)
    plt.tight_layout()
    plt.show()
# Fit a CCA model with its default settings on the two views, then compare
# its per-view weights with the true weights.
cca = CCA().fit([X, Y])
plot_model_weights(cca.weights[0], cca.weights[1], tx, ty)
true x weights, true y weights, model x weights, model y weights
# Fit a PLS model with its default settings on the two views, then compare
# its per-view weights with the true weights.
pls = PLS().fit([X, Y])
plot_model_weights(pls.weights[0], pls.weights[1], tx, ty)
true x weights, true y weights, model x weights, model y weights
# Fit a PMD model with c=[0.5, 0.5] and compare its weights with the true
# weights.
# NOTE(review): the two entries of `c` are presumably one regularisation
# value per view -- confirm against the cca_zoo PMD documentation.
pmd = PMD(c=[0.5, 0.5]).fit([X, Y])
plot_model_weights(pmd.weights[0], pmd.weights[1], tx, ty)
true x weights, true y weights, model x weights, model y weights
# Plot the objective values recorded in pmd.track[0]["objective"] on a new
# figure, with iterations on the x-axis.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(pmd.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Objective Convergence

Out:

Text(0.5, 23.52222222222222, '#iterations')
# Candidate `c` values: c1 and c2 are supplied as the two entries of the "c"
# grid. The captured output below reports 16 candidates (4 x 4 combinations),
# each evaluated with 3-fold cross-validation.
c1 = [0.1, 0.3, 0.7, 0.9]
c2 = [0.1, 0.3, 0.7, 0.9]
param_grid = {"c": [c1, c2]}
# Note: `pmd` is rebound here from the fitted PMD model to the fitted search.
pmd = GridSearchCV(PMD(), param_grid=param_grid, cv=3, verbose=True).fit([X, Y])

Out:

Fitting 3 folds for each of 16 candidates, totalling 48 fits
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."
# Wrap the grid search's cv_results_ in a DataFrame. The result is not
# assigned; the gallery page renders it as the table shown below.
pd.DataFrame(pmd.cv_results_)
mean_fit_time std_fit_time mean_score_time std_score_time param_c params split0_test_score split1_test_score split2_test_score mean_test_score std_test_score rank_test_score
0 0.016504 0.000436 0.001772 0.000018 [0.1, 0.1] {'c': [0.1, 0.1]} 0.052058 -0.033554 -0.071958 -0.017818 0.051838 13
1 0.021094 0.000017 0.001780 0.000054 [0.1, 0.3] {'c': [0.1, 0.3]} 0.153788 0.360482 0.036318 0.183529 0.134000 2
2 0.021084 0.000070 0.002274 0.000801 [0.1, 0.7] {'c': [0.1, 0.7]} 0.160803 0.331418 -0.013043 0.159726 0.140628 4
3 0.014632 0.000081 0.001692 0.000031 [0.1, 0.9] {'c': [0.1, 0.9]} 0.164070 0.330148 -0.001263 0.164318 0.135298 3
4 0.021207 0.000091 0.001814 0.000047 [0.3, 0.1] {'c': [0.3, 0.1]} 0.073034 -0.208602 -0.024809 -0.053459 0.116748 15
5 0.211220 0.088018 0.001860 0.000073 [0.3, 0.3] {'c': [0.3, 0.3]} 0.046186 0.600011 0.414416 0.353538 0.230160 1
6 0.291653 0.065017 0.001915 0.000018 [0.3, 0.7] {'c': [0.3, 0.7]} 0.014920 -0.091814 0.163878 0.028995 0.104859 10
7 0.287788 0.022791 0.001868 0.000086 [0.3, 0.9] {'c': [0.3, 0.9]} -0.060564 0.275958 0.113212 0.109535 0.137409 5
8 0.021968 0.000898 0.002162 0.000487 [0.7, 0.1] {'c': [0.7, 0.1]} 0.024604 -0.163848 -0.104591 -0.081278 0.078681 16
9 0.242157 0.083979 0.001883 0.000031 [0.7, 0.3] {'c': [0.7, 0.3]} 0.049719 0.135290 -0.142186 0.014274 0.116019 11
10 0.228890 0.034498 0.001878 0.000038 [0.7, 0.7] {'c': [0.7, 0.7]} 0.073814 0.044416 0.044439 0.054223 0.013853 7
11 0.111682 0.006369 0.001853 0.000035 [0.7, 0.9] {'c': [0.7, 0.9]} 0.068124 0.066506 0.044814 0.059815 0.010628 6
12 0.015750 0.000097 0.001756 0.000034 [0.9, 0.1] {'c': [0.9, 0.1]} 0.052391 -0.073953 -0.130107 -0.050556 0.076319 14
13 0.175317 0.070153 0.001858 0.000020 [0.9, 0.3] {'c': [0.9, 0.3]} 0.043383 0.149690 -0.219035 -0.008654 0.154963 12
14 0.185202 0.095532 0.001744 0.000219 [0.9, 0.7] {'c': [0.9, 0.7]} 0.075432 0.019472 0.052429 0.049111 0.022966 8
15 0.014733 0.001436 0.001726 0.000017 [0.9, 0.9] {'c': [0.9, 0.9]} 0.055706 0.037635 0.041380 0.044907 0.007787 9


# Fit an SCCA model with c=[1e-3, 1e-3] and compare its weights with the true
# weights.
scca = SCCA(c=[1e-3, 1e-3]).fit([X, Y])
plot_model_weights(scca.weights[0], scca.weights[1], tx, ty)

# Convergence: plot the objective values recorded in scca.track[0]["objective"]
# on a new figure, with iterations on the x-axis.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(scca.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."

Text(0.5, 23.52222222222222, '#iterations')
# Fit SCCA again with the same c values but positive=[True, True], then
# compare its weights with the true weights.
# NOTE(review): `positive` presumably constrains each view's weights to be
# non-negative -- confirm against the cca_zoo SCCA documentation.
scca_pos = SCCA(c=[1e-3, 1e-3], positive=[True, True]).fit([X, Y])
plot_model_weights(scca_pos.weights[0], scca_pos.weights[1], tx, ty)

# Convergence: plot the objective values recorded in
# scca_pos.track[0]["objective"] on a new figure.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(scca_pos.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."

Text(0.5, 23.52222222222222, '#iterations')
# Fit an ElasticCCA model with c=[10000, 10000] and
# l1_ratio=[0.000001, 0.000001], then compare its weights with the true
# weights.
elasticcca = ElasticCCA(c=[10000, 10000], l1_ratio=[0.000001, 0.000001]).fit([X, Y])
plot_model_weights(elasticcca.weights[0], elasticcca.weights[1], tx, ty)

# Convergence: plot the objective values recorded in
# elasticcca.track[0]["objective"] on a new figure.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(elasticcca.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."

Text(0.5, 23.52222222222222, '#iterations')
# Fit an SCCA_ADMM model with c=[1e-3, 1e-3] and compare its weights with the
# true weights.
scca_admm = SCCA_ADMM(c=[1e-3, 1e-3]).fit([X, Y])
plot_model_weights(scca_admm.weights[0], scca_admm.weights[1], tx, ty)

# Convergence: plot the objective values recorded in
# scca_admm.track[0]["objective"] on a new figure.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(scca_admm.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."

Text(0.5, 23.52222222222222, '#iterations')
# Fit a SpanCCA model with c=[10, 10], max_iter=2000 and rank=20, then compare
# its weights with the true weights.
spancca = SpanCCA(c=[10, 10], max_iter=2000, rank=20).fit([X, Y])
plot_model_weights(spancca.weights[0], spancca.weights[1], tx, ty)

# Convergence: plot the objective values recorded in
# spancca.track[0]["objective"] on a new figure.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(spancca.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.6/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
  f"Inner loop {k} not converged. Increase number of iterations."

Text(0.5, 23.52222222222222, '#iterations')

Total running time of the script: ( 0 minutes 16.417 seconds)

Gallery generated by Sphinx-Gallery