Sparse CCA Methods

This script shows how regularised methods can be used to extract sparse solutions to the CCA problem.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from cca_zoo.data import generate_covariance_data
from cca_zoo.model_selection import GridSearchCV
from cca_zoo.models import PMD, SCCA, ElasticCCA, CCA, PLS, SCCA_ADMM, SpanCCA
# Seed NumPy's global generator so the simulated data (and every fit below)
# is reproducible from run to run.
np.random.seed(42)

# Problem size: n samples, p features in the first view (X), q in the second (Y).
n, p, q = 200, 100, 100
# Sparsity level handed to the simulator for each view's true weights
# (presumably the fraction of non-zero entries -- confirm in the cca_zoo docs).
view_1_sparsity, view_2_sparsity = 0.1, 0.1
# Number of latent components shared by the two simulated views.
latent_dims = 1

(X, Y), (tx, ty) = generate_covariance_data(
    n,
    correlation=[0.9],
    latent_dims=latent_dims,
    view_features=[p, q],
    view_sparsity=[view_1_sparsity, view_2_sparsity],
)

# Rescale the true weights by 1/sqrt(n); presumably this puts them on the same
# scale as the fitted model weights they are plotted against -- TODO confirm.
tx /= np.sqrt(n)
ty /= np.sqrt(n)
def plot_true_weights_coloured(ax, weights, true_weights, title="", component=0):
    """Scatter one component of ``weights``, coloured by the true sparsity pattern.

    Features whose true weight is non-zero are drawn in blue ("b"); features
    whose true weight is exactly zero are drawn in red ("r").

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    weights : array-like
        Weights to plot, either 1-D ``(n_features,)`` or 2-D. 2-D input is
        assumed to be ``(n_features, n_components)``.
    true_weights : array-like
        Ground-truth weights used to build the colour mask. Same layout as
        ``weights``.
    title : str, optional
        Axes title.
    component : int, optional
        Column to plot when an input is 2-D. The default (0) reproduces the
        single-latent-dimension behaviour. Previously, 2-D inputs with more
        than one column raised an IndexError.
    """
    weights = np.asarray(weights)
    true_weights = np.asarray(true_weights)
    # Reduce (n_features, n_components) inputs to a single 1-D column so the
    # boolean mask and the feature index have matching shapes.
    if true_weights.ndim > 1:
        true_weights = true_weights[:, component]
    if weights.ndim > 1:
        weights = weights[:, component]

    ind = np.arange(len(true_weights))
    mask = true_weights == 0  # True where the feature should be sparse
    ax.scatter(ind[~mask], weights[~mask], c="b")
    ax.scatter(ind[mask], weights[mask], c="r")
    ax.set_title(title)


def plot_model_weights(wx, wy, tx, ty):
    """Show true and model weights for both views on a shared-axis 2x2 grid.

    Top row: the true x / y weights, each coloured by its own sparsity
    pattern. Bottom row: the model's x / y weights, coloured by the true
    pattern of the matching view. The figure is laid out tightly and shown.
    """
    _fig, grid = plt.subplots(2, 2, sharex=True, sharey=True)
    # (axes position, weights to draw, reference weights for colouring, title)
    panels = (
        ((0, 0), tx, tx, "true x weights"),
        ((0, 1), ty, ty, "true y weights"),
        ((1, 0), wx, tx, "model x weights"),
        ((1, 1), wy, ty, "model y weights"),
    )
    for position, drawn, reference, label in panels:
        plot_true_weights_coloured(grid[position], drawn, reference, title=label)
    plt.tight_layout()
    plt.show()
# Plain CCA with default settings, as a baseline for the regularised methods below.
cca = CCA().fit([X, Y])
plot_model_weights(wx=cca.weights[0], wy=cca.weights[1], tx=tx, ty=ty)
true x weights, true y weights, model x weights, model y weights
# PLS with default settings, as a second baseline.
pls = PLS().fit([X, Y])
plot_model_weights(wx=pls.weights[0], wy=pls.weights[1], tx=tx, ty=ty)
true x weights, true y weights, model x weights, model y weights
# Penalised Matrix Decomposition with c=2 for both views (c presumably sets the
# strength of the sparsity constraint -- see the cca_zoo PMD docs).
pmd = PMD(c=[2, 2]).fit([X, Y])
plot_model_weights(wx=pmd.weights[0], wy=pmd.weights[1], tx=tx, ty=ty)
true x weights, true y weights, model x weights, model y weights
# Objective trace of the PMD fit on a fresh figure (track[0] presumably holds
# the first latent dimension -- confirm in cca_zoo).
plt.figure()
plt.plot(np.transpose(np.asarray(pmd.track[0]["objective"])))
plt.title("Objective Convergence")
plt.ylabel("Objective")
plt.xlabel("#iterations")
Objective Convergence

Out:

Text(0.5, 23.52222222222222, '#iterations')
# Candidate c values per view; every (x, y) pairing is tried, giving 4 x 4 = 16
# candidates, each evaluated with 3-fold cross-validation.
c1 = [1, 3, 7, 9]
c2 = [1, 3, 7, 9]
param_grid = dict(c=[c1, c2])
pmd = GridSearchCV(
    PMD(),
    param_grid=param_grid,
    cv=3,
    verbose=True,
).fit([X, Y])

Out:

Fitting 3 folds for each of 16 candidates, totalling 48 fits
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")
pd.DataFrame(data=pmd.cv_results_)  # tabulate the grid-search results
mean_fit_time std_fit_time mean_score_time std_score_time param_c params split0_test_score split1_test_score split2_test_score mean_test_score std_test_score rank_test_score
0 0.040332 0.010897 0.001755 0.000008 [1, 1] {'c': [1, 1]} 0.052058 -0.033554 -0.071958 -0.017818 0.051838 13
1 0.045717 0.010863 0.001674 0.000027 [1, 3] {'c': [1, 3]} 0.153789 0.360482 0.036318 0.183530 0.134000 2
2 0.049288 0.001907 0.001821 0.000027 [1, 7] {'c': [1, 7]} 0.160803 0.331418 -0.013043 0.159726 0.140628 4
3 0.040549 0.005486 0.001686 0.000068 [1, 9] {'c': [1, 9]} 0.164070 0.330148 -0.001263 0.164318 0.135298 3
4 0.046483 0.010384 0.001790 0.000027 [3, 1] {'c': [3, 1]} 0.073034 -0.208602 -0.024809 -0.053459 0.116749 15
5 0.192965 0.073421 0.001773 0.000040 [3, 3] {'c': [3, 3]} 0.046186 0.600012 0.414416 0.353538 0.230160 1
6 0.263225 0.042769 0.001782 0.000077 [3, 7] {'c': [3, 7]} 0.014920 -0.091810 0.163879 0.028996 0.104858 10
7 0.264951 0.021462 0.001764 0.000044 [3, 9] {'c': [3, 9]} -0.060565 0.275957 0.113212 0.109535 0.137409 5
8 0.040816 0.007507 0.001681 0.000025 [7, 1] {'c': [7, 1]} 0.024604 -0.163847 -0.104591 -0.081278 0.078681 16
9 0.218284 0.073807 0.001795 0.000046 [7, 3] {'c': [7, 3]} 0.049719 0.135290 -0.142186 0.014275 0.116019 11
10 0.204761 0.030259 0.001771 0.000042 [7, 7] {'c': [7, 7]} 0.073814 0.044416 0.044439 0.054223 0.013853 7
11 0.112398 0.004850 0.001760 0.000063 [7, 9] {'c': [7, 9]} 0.068125 0.066507 0.044814 0.059815 0.010628 6
12 0.040103 0.009469 0.001669 0.000003 [9, 1] {'c': [9, 1]} 0.052391 -0.073953 -0.130107 -0.050556 0.076319 14
13 0.171223 0.064289 0.001754 0.000020 [9, 3] {'c': [9, 3]} 0.043383 0.149690 -0.219036 -0.008654 0.154964 12
14 0.180746 0.089290 0.001752 0.000034 [9, 7] {'c': [9, 7]} 0.075432 0.019489 0.052429 0.049117 0.022958 8
15 0.049870 0.008860 0.001643 0.000011 [9, 9] {'c': [9, 9]} 0.055716 0.037696 0.041400 0.044937 0.007770 9


# Sparse CCA with penalty c=1e-3 on both views.
scca = SCCA(c=[1e-3, 1e-3]).fit([X, Y])
plot_model_weights(wx=scca.weights[0], wy=scca.weights[1], tx=tx, ty=ty)

# Objective trace of the fit, to check how the iterative solver progressed.
plt.figure()
plt.plot(np.transpose(np.asarray(scca.track[0]["objective"])))
plt.title("Objective Convergence")
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")

Text(0.5, 23.52222222222222, '#iterations')
# Same sparse CCA penalty as above, with the positive option enabled for both
# views (presumably constraining the weights to be non-negative).
scca_pos = SCCA(c=[1e-3, 1e-3], positive=[True, True]).fit([X, Y])
plot_model_weights(wx=scca_pos.weights[0], wy=scca_pos.weights[1], tx=tx, ty=ty)

# Objective trace of the fit.
plt.figure()
plt.plot(np.transpose(np.asarray(scca_pos.track[0]["objective"])))
plt.title("Objective Convergence")
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")

Text(0.5, 23.52222222222222, '#iterations')
# Elastic-net CCA: large overall penalty c with a tiny l1_ratio for both views.
elasticcca = ElasticCCA(c=[10000, 10000], l1_ratio=[0.000001, 0.000001]).fit([X, Y])
plot_model_weights(
    wx=elasticcca.weights[0], wy=elasticcca.weights[1], tx=tx, ty=ty
)

# Objective trace of the fit.
plt.figure()
plt.plot(np.transpose(np.asarray(elasticcca.track[0]["objective"])))
plt.title("Objective Convergence")
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")

Text(0.5, 23.52222222222222, '#iterations')
# Sparse CCA solved with the ADMM variant, penalty c=1e-3 on both views.
scca_admm = SCCA_ADMM(c=[1e-3, 1e-3]).fit([X, Y])
plot_model_weights(wx=scca_admm.weights[0], wy=scca_admm.weights[1], tx=tx, ty=ty)

# Objective trace of the fit.
plt.figure()
plt.plot(np.transpose(np.asarray(scca_admm.track[0]["objective"])))
plt.title("Objective Convergence")
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")

Text(0.5, 23.52222222222222, '#iterations')
# SpanCCA with c=10 per view, up to 2000 iterations and rank=20.
spancca = SpanCCA(c=[10, 10], max_iter=2000, rank=20).fit([X, Y])
plot_model_weights(wx=spancca.weights[0], wy=spancca.weights[1], tx=tx, ty=ty)

# Objective trace of the fit.
plt.figure()
plt.plot(np.transpose(np.asarray(spancca.track[0]["objective"])))
plt.title("Objective Convergence")
plt.ylabel("Objective")
plt.xlabel("#iterations")
  • true x weights, true y weights, model x weights, model y weights
  • Objective Convergence

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.1/cca_zoo/models/iterative.py:103: UserWarning: Inner loop 0 did not converge or converged to nans
  warnings.warn(f"Inner loop {k} did not converge or converged to nans")

Text(0.5, 23.52222222222222, '#iterations')

Total running time of the script: ( 0 minutes 16.919 seconds)

Gallery generated by Sphinx-Gallery