Note
Click here to download the full example code
Sparse CCA Methods
This script shows how regularised methods can be used to extract sparse solutions to the CCA problem.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from cca_zoo.data import generate_covariance_data
from cca_zoo.model_selection import GridSearchCV
from cca_zoo.models import PMD, SCCA, ElasticCCA, CCA, PLS, SCCA_ADMM, SpanCCA
# Seed NumPy's global random number generator before the data is generated.
np.random.seed(42)
# Problem dimensions: n is passed first to generate_covariance_data, while p and q
# are the feature counts of the two views (see view_features below).
# NOTE(review): n is presumably the number of samples — confirm against cca_zoo docs.
n = 200
p = 100
q = 100
# Sparsity setting for each view's true weights.
# NOTE(review): presumably the fraction of non-zero weights — confirm against cca_zoo docs.
view_1_sparsity = 0.1
view_2_sparsity = 0.1
latent_dims = 1
# Generate the two views (X, Y) together with the true weights (tx, ty) that the
# fitted models are compared against in the plots below.
(X, Y), (tx, ty) = generate_covariance_data(
n,
view_features=[p, q],
latent_dims=latent_dims,
view_sparsity=[view_1_sparsity, view_2_sparsity],
correlation=[0.9],
)
# Rescale the true weights in place by 1/sqrt(n).
# NOTE(review): presumably so they are on a scale comparable to the model weights — confirm.
tx /= np.sqrt(n)
ty /= np.sqrt(n)
def plot_true_weights_coloured(ax, weights, true_weights, title=""):
    """Scatter ``weights`` by index, coloured by whether the true weight is zero.

    Entries whose corresponding ``true_weights`` value is non-zero are drawn
    in blue ("b"); entries whose true weight is exactly zero are drawn in
    red ("r").

    :param ax: object providing ``scatter`` and ``set_title`` (e.g. a matplotlib Axes)
    :param weights: values to plot, indexable by a boolean mask
    :param true_weights: reference weights; ``== 0`` decides the colour of each point
    :param title: title set on ``ax``
    """
    positions = np.arange(len(true_weights))
    # squeeze drops singleton axes so a column vector yields a 1-D mask
    is_zero = np.squeeze(true_weights == 0)
    is_nonzero = ~is_zero
    # Non-zero group first, then the zero group, matching the draw order.
    for selector, colour in ((is_nonzero, "b"), (is_zero, "r")):
        ax.scatter(positions[selector], weights[selector], c=colour)
    ax.set_title(title)
def plot_model_weights(wx, wy, tx, ty):
    """Show a 2x2 grid comparing true weights (top row) with model weights (bottom row).

    The left column is the x view and the right column is the y view. Every
    panel is coloured according to the true weights of its view, and all
    panels share both axes.

    :param wx: model weights for the x view
    :param wy: model weights for the y view
    :param tx: true weights for the x view
    :param ty: true weights for the y view
    """
    _fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
    # (grid position, weights to plot, true weights used for colouring, panel title)
    panels = (
        ((0, 0), tx, tx, "true x weights"),
        ((0, 1), ty, ty, "true y weights"),
        ((1, 0), wx, tx, "model x weights"),
        ((1, 1), wy, ty, "model y weights"),
    )
    for position, values, reference, heading in panels:
        plot_true_weights_coloured(axs[position], values, reference, title=heading)
    plt.tight_layout()
    plt.show()
# Fit CCA with default settings on the two views and plot its weights for
# each view next to the true weights.
cca = CCA().fit([X, Y])
plot_model_weights(cca.weights[0], cca.weights[1], tx, ty)
# Same comparison for PLS with default settings.
pls = PLS().fit([X, Y])
plot_model_weights(pls.weights[0], pls.weights[1], tx, ty)
# PMD with c set to 0.5 twice.
# NOTE(review): presumably one regularisation value per view — confirm against cca_zoo PMD docs.
pmd = PMD(c=[0.5, 0.5]).fit([X, Y])
plot_model_weights(pmd.weights[0], pmd.weights[1], tx, ty)
# Convergence: plot the "objective" values stored in pmd.track[0], transposed.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(pmd.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Out:
Text(0.5, 23.52222222222222, '#iterations')
# Candidate values of c to search over.
# NOTE(review): presumably c1 is for the first view and c2 for the second — confirm.
c1 = [0.1, 0.3, 0.7, 0.9]
c2 = [0.1, 0.3, 0.7, 0.9]
param_grid = {"c": [c1, c2]}
# Grid search over param_grid using cv=3, with verbose=True.
# This rebinds `pmd`, replacing the single PMD fit from earlier in the script.
pmd = GridSearchCV(PMD(), param_grid=param_grid, cv=3, verbose=True).fit([X, Y])
Out:
Fitting 3 folds for each of 16 candidates, totalling 48 fits
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
# Tabulate the grid-search results stored in pmd.cv_results_.
pd.DataFrame(pmd.cv_results_)
# Fit SCCA with c set to 1e-3 twice and compare its weights with the true weights.
# NOTE(review): presumably one regularisation value per view — confirm against cca_zoo SCCA docs.
scca = SCCA(c=[1e-3, 1e-3]).fit([X, Y])
plot_model_weights(scca.weights[0], scca.weights[1], tx, ty)
# Convergence: plot the "objective" values stored in scca.track[0], transposed.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(scca.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Out:
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
Text(0.5, 23.52222222222222, '#iterations')
# Same SCCA configuration as the unconstrained fit, but with positive=[True, True].
# NOTE(review): presumably constrains each view's weights to be non-negative — confirm against cca_zoo SCCA docs.
scca_pos = SCCA(c=[1e-3, 1e-3], positive=[True, True]).fit([X, Y])
plot_model_weights(scca_pos.weights[0], scca_pos.weights[1], tx, ty)
# Convergence: plot the "objective" values stored in scca_pos.track[0], transposed.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(scca_pos.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Out:
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
Text(0.5, 23.52222222222222, '#iterations')
# Fit ElasticCCA with c=10000 and l1_ratio=1e-6, each given twice, and compare
# its weights with the true weights.
# NOTE(review): presumably c is the overall penalty strength and l1_ratio the L1 share,
# one value per view — confirm against cca_zoo ElasticCCA docs.
elasticcca = ElasticCCA(c=[10000, 10000], l1_ratio=[0.000001, 0.000001]).fit([X, Y])
plot_model_weights(elasticcca.weights[0], elasticcca.weights[1], tx, ty)
# Convergence: plot the "objective" values stored in elasticcca.track[0], transposed.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(elasticcca.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Out:
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
Text(0.5, 23.52222222222222, '#iterations')
# Fit SCCA_ADMM with c set to 1e-3 twice and compare its weights with the true weights.
# NOTE(review): presumably one regularisation value per view — confirm against cca_zoo SCCA_ADMM docs.
scca_admm = SCCA_ADMM(c=[1e-3, 1e-3]).fit([X, Y])
plot_model_weights(scca_admm.weights[0], scca_admm.weights[1], tx, ty)
# Convergence: plot the "objective" values stored in scca_admm.track[0], transposed.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(scca_admm.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Out:
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
Text(0.5, 23.52222222222222, '#iterations')
# Fit SpanCCA with c=[10, 10], max_iter=2000 and rank=20, then compare its
# weights with the true weights.
# NOTE(review): the meaning of c and rank for SpanCCA is not visible here — confirm against cca_zoo SpanCCA docs.
spancca = SpanCCA(c=[10, 10], max_iter=2000, rank=20).fit([X, Y])
plot_model_weights(spancca.weights[0], spancca.weights[1], tx, ty)
# Convergence: plot the "objective" values stored in spancca.track[0], transposed.
plt.figure()
plt.title("Objective Convergence")
plt.plot(np.array(spancca.track[0]["objective"]).T)
plt.ylabel("Objective")
plt.xlabel("#iterations")
Out:
/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/checkouts/v1.10.4/cca_zoo/models/iterative.py:107: UserWarning: Inner loop 0 not converged. Increase number of iterations.
f"Inner loop {k} not converged. Increase number of iterations."
Text(0.5, 23.52222222222222, '#iterations')
Total running time of the script: ( 0 minutes 16.344 seconds)