Multiview Deep CCA Extensions

This script shows how to use CCA-Zoo to train extensions of Deep Canonical Correlation Analysis (Deep CCA) that handle more than two representations (views) of the data.

Features:
- Deep MCCA (Multiset CCA)
- Deep GCCA (Generalized CCA)
- Deep TCCA (Tensor CCA)

import lightning.pytorch as pl
from cca_zoo.deep import DCCA, DTCCA, architectures, objectives
from examples import example_mnist_data  # local helper module, not part of the cca_zoo package

Data Preparation

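We use MNIST as an example of multiview data: each 28×28 image is split into two halves of 392 pixels, and each half is treated as a separate representation. For simplicity the script uses only two views; every model below takes one encoder per view, so further views are added the same way.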

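LATENT_DIMS = 2  # size of each learned representation
EPOCHS = 10
N_TRAIN = 500  # number of training samples
N_VAL = 100  # number of validation samples

# Training and validation loaders, plus the digit labels (not used below)
train_loader, val_loader, train_labels, val_labels = example_mnist_data(N_TRAIN, N_VAL)

# One encoder per view; each half-image has 392 pixels
encoder_1 = architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392)
encoder_2 = architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392)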
On the first run, the MNIST files are downloaded from http://yann.lecun.com/exdb/mnist/ and extracted into MNIST/raw.
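The split into views can be sketched in plain NumPy. This is an illustration under an assumption: `split_into_views` is a hypothetical helper that cuts each image into left and right halves, which may not match how `example_mnist_data` actually segments the digits. Either way, each half of a 28×28 image has 392 pixels, which is where `feature_size=392` comes from.

```python
import numpy as np


def split_into_views(images):
    """Hypothetical helper: split (n, 28, 28) images into two flattened views.

    Assumes a left/right split; each view has shape (n, 28 * 14) = (n, 392).
    """
    images = np.asarray(images, dtype=np.float32)
    n = images.shape[0]
    left = images[:, :, :14].reshape(n, -1)   # pixel columns 0-13
    right = images[:, :, 14:].reshape(n, -1)  # pixel columns 14-27
    return left, right


# Random arrays stand in for MNIST digits here
rng = np.random.default_rng(0)
fake_images = rng.random((5, 28, 28))
view_1, view_2 = split_into_views(fake_images)
print(view_1.shape, view_2.shape)  # (5, 392) (5, 392)
```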

Deep MCCA (Multiset CCA)

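A multiview extension of CCA that learns latent representations which are maximally correlated across all views at once: MCCA maximizes the sum of the pairwise covariances between views, subject to a constraint on their total within-view variance.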
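For intuition, here is a linear MCCA sketch in NumPy. Deep MCCA applies the same criterion to encoder outputs instead of raw features; the helper `linear_mcca` is hypothetical, and `_MCCALoss` may differ in its regularization and normalization details. Linear MCCA solves the generalized eigenproblem C w = λ D w, where C is the covariance of all views stacked together and D keeps only its within-view blocks. The largest possible λ equals the number of views, reached when all projections are perfectly correlated.

```python
import numpy as np


def linear_mcca(views, k, eps=1e-5):
    """Hypothetical helper: linear MCCA via the eigenproblem C w = lambda D w."""
    views = [v - v.mean(axis=0) for v in views]  # center each view
    z = np.hstack(views)
    c = z.T @ z / (z.shape[0] - 1)  # covariance of all views stacked together
    d = np.zeros_like(c)
    start = 0
    for v in views:  # D keeps only the within-view (diagonal) blocks of C
        stop = start + v.shape[1]
        d[start:stop, start:stop] = c[start:stop, start:stop]
        start = stop
    d += eps * np.eye(len(d))  # small ridge keeps D positive definite
    # With D = L L^T the problem becomes symmetric: (L^-1 C L^-T) u = lambda u
    l_inv = np.linalg.inv(np.linalg.cholesky(d))
    vals, u = np.linalg.eigh(l_inv @ c @ l_inv.T)
    order = np.argsort(vals)[::-1][:k]  # eigh sorts ascending; keep the top k
    weights = l_inv.T @ u[:, order]  # map back to the original features: w = L^-T u
    return vals[order], weights


# Three views that are invertible linear maps of the same 3-D signal
rng = np.random.default_rng(0)
x = rng.standard_normal((200, 3))
views = [x, x @ np.diag([2.0, -1.0, 0.5]), 3.0 * x[:, ::-1]]
vals, weights = linear_mcca(views, k=2)
print(vals)  # both values are close to 3, the number of views
```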

dcca_mcca = DCCA(
    latent_dimensions=LATENT_DIMS,
    encoders=[encoder_1, encoder_2],
    objective=objectives._MCCALoss,
)
trainer_mcca = pl.Trainer(
    max_epochs=EPOCHS,
    enable_checkpointing=False,
    enable_model_summary=False,
    enable_progress_bar=False,
)
trainer_mcca.fit(dcca_mcca, train_loader, val_loader)
During training, Lightning warns that the data loaders use few worker processes (`num_workers`) and that the 10 training batches per epoch are fewer than the default logging interval (`log_every_n_steps=50`). Both warnings concern speed and logging only.

Deep GCCA (Generalized CCA)

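A multiview extension of CCA that learns a shared latent representation together with one projection per view, chosen so that each projected view is as close as possible (in squared distance) to the shared representation; equivalently, the variance that the shared components explain across views is maximized.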
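A linear sketch of this idea (the MAXVAR form of GCCA) fits in a few lines of NumPy. The shared representation G (n × k, orthonormal columns) is taken as the top-k eigenvectors of the sum over views of P_i = X_i (X_iᵀX_i + εI)⁻¹ X_iᵀ, the regularized projection onto view i's column space; each view's weights are then a ridge regression of G on that view. `linear_gcca` is a hypothetical helper, and `_GCCALoss` works on encoder outputs with its own regularization, so treat this as intuition rather than the library's implementation.

```python
import numpy as np


def linear_gcca(views, k, eps=1e-5):
    """Hypothetical helper: MAXVAR GCCA with a shared representation G."""
    views = [v - v.mean(axis=0) for v in views]  # center each view
    n = views[0].shape[0]
    q = np.zeros((n, n))
    for v in views:
        gram = v.T @ v + eps * np.eye(v.shape[1])
        q += v @ np.linalg.solve(gram, v.T)  # P_i = X_i (X_i^T X_i + eps I)^-1 X_i^T
    q = (q + q.T) / 2  # remove round-off asymmetry before the symmetric eigensolver
    vals, vecs = np.linalg.eigh(q)
    order = np.argsort(vals)[::-1][:k]  # top-k eigenvectors of the summed projections
    g = vecs[:, order]  # shared representation with orthonormal columns
    # Per-view weights: ridge regression of G on each view
    weights = [np.linalg.solve(v.T @ v + eps * np.eye(v.shape[1]), v.T @ g) for v in views]
    return vals[order], g, weights


rng = np.random.default_rng(0)
x = rng.standard_normal((200, 3))
views = [x, x @ np.diag([2.0, -1.0, 0.5]), 3.0 * x[:, ::-1]]
vals, g, weights = linear_gcca(views, k=2)
print(vals)  # close to 3: every view spans the same subspace
```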

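# Fresh encoders, so this model does not start from the weights already trained by Deep MCCA
dcca_gcca = DCCA(
    latent_dimensions=LATENT_DIMS,
    encoders=[
        architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392),
        architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392),
    ],
    objective=objectives._GCCALoss,
)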
trainer_gcca = pl.Trainer(
    max_epochs=EPOCHS,
    enable_checkpointing=False,
    enable_model_summary=False,
    enable_progress_bar=False,
)
trainer_gcca.fit(dcca_gcca, train_loader, val_loader)

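Deep TCCA (Tensor CCA)

A multiview extension of CCA that, instead of summing pairwise correlations, maximizes a single higher-order correlation across all representations jointly, defined through the covariance tensor of their whitened outputs.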
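The tensor idea can be sketched linearly for three views in NumPy: whiten each view, form the third-order covariance tensor, and find one component as its best rank-1 fit with alternating (higher-order power) updates. The helpers `whiten` and `tcca_first_component` are hypothetical, and `DTCCA` applies its loss to encoder outputs with its own regularization, so this only illustrates the criterion. Because odd-order moments of symmetric noise vanish, the toy data share a skewed (Bernoulli) signal.

```python
import numpy as np


def whiten(v, eps=1e-5):
    """Center a view and rescale it to (approximately) identity covariance."""
    v = v - v.mean(axis=0)
    cov = v.T @ v / (len(v) - 1) + eps * np.eye(v.shape[1])
    vals, vecs = np.linalg.eigh(cov)
    return v @ vecs @ np.diag(vals ** -0.5) @ vecs.T  # v @ cov^(-1/2)


def tcca_first_component(views, n_iter=50, seed=0):
    """Hypothetical helper: rank-1 fit of the covariance tensor of three whitened views."""
    z1, z2, z3 = (whiten(v) for v in views)
    t = np.einsum("ni,nj,nk->ijk", z1, z2, z3) / len(z1)  # third-order covariance tensor
    rng = np.random.default_rng(seed)
    u2 = rng.standard_normal(t.shape[1])
    u3 = rng.standard_normal(t.shape[2])
    u2 /= np.linalg.norm(u2)
    u3 /= np.linalg.norm(u3)
    for _ in range(n_iter):  # update one direction at a time, holding the other two fixed
        u1 = np.einsum("ijk,j,k->i", t, u2, u3)
        u1 /= np.linalg.norm(u1)
        u2 = np.einsum("ijk,i,k->j", t, u1, u3)
        u2 /= np.linalg.norm(u2)
        u3 = np.einsum("ijk,i,j->k", t, u1, u2)
        u3 /= np.linalg.norm(u3)
    # Higher-order correlation: mean product of the three unit-variance projections
    corr = np.einsum("ijk,i,j,k->", t, u1, u2, u3)
    return corr, (u1, u2, u3)


rng = np.random.default_rng(0)
n = 2000
s = (rng.random(n) < 0.1).astype(float)  # skewed signal shared by all three views
views = [np.column_stack([s + 0.1 * rng.standard_normal(n), rng.standard_normal(n)]) for _ in range(3)]
corr, directions = tcca_first_component(views)
print(corr)
```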

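# Fresh encoders, so this model does not start from weights already trained by the models above
dcca_tcca = DTCCA(
    latent_dimensions=LATENT_DIMS,
    encoders=[
        architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392),
        architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392),
    ],
)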
trainer_tcca = pl.Trainer(
    max_epochs=EPOCHS,
    enable_checkpointing=False,
    enable_model_summary=False,
    enable_progress_bar=False,
)
trainer_tcca.fit(dcca_tcca, train_loader, val_loader)

Total running time of the script: (0 minutes 7.327 seconds)

Gallery generated by Sphinx-Gallery