Multiview Deep CCA Extensions#
This script shows how to train extensions of Deep Canonical Correlation Analysis (Deep CCA) that handle more than two representations of the data, using CCA-Zoo.

Features:

- Deep MCCA (Multiset CCA)
- Deep GCCA (Generalized CCA)
- Deep TCCA (Tensor CCA)
import lightning.pytorch as pl
from cca_zoo.deep import DCCA, DTCCA, architectures, objectives
from examples import example_mnist_data
Data Preparation#
Here, we use a segmented MNIST dataset as an example of multiview data: each 28x28 image is split into two halves, giving two 392-dimensional representations of the same digit.
LATENT_DIMS = 2
EPOCHS = 10
N_TRAIN = 500
N_VAL = 100
train_loader, val_loader, train_labels, val_labels = example_mnist_data(N_TRAIN, N_VAL)
encoder_1 = architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392)
encoder_2 = architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=392)
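The helper `example_mnist_data` hides the view construction. As a hypothetical sketch (the exact splitting done by the helper is an assumption, but it is consistent with the two `feature_size=392` encoders above), each 784-pixel image can be cut into a left half and a right half:

```python
import numpy as np

# Hypothetical illustration of how a "segmented" MNIST view pair can be built:
# each 28x28 image (784 pixels) is split into a left half and a right half,
# giving two 392-dimensional views of the same underlying digit.
rng = np.random.default_rng(0)
images = rng.random((500, 28, 28))  # stand-in for 500 MNIST images

left_view = images[:, :, :14].reshape(500, -1)   # first 14 columns -> 392 features
right_view = images[:, :, 14:].reshape(500, -1)  # last 14 columns -> 392 features

print(left_view.shape, right_view.shape)  # (500, 392) (500, 392)
```

Because both views come from the same digit, a shared latent structure exists for the CCA objectives below to recover.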
Deep MCCA (Multiset CCA)#
A multiview extension of CCA, aiming to find latent spaces that are maximally correlated across multiple representations.
dcca_mcca = DCCA(
latent_dimensions=LATENT_DIMS,
encoders=[encoder_1, encoder_2],
objective=objectives._MCCALoss,
)
trainer_mcca = pl.Trainer(
max_epochs=EPOCHS,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
trainer_mcca.fit(dcca_mcca, train_loader, val_loader)
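To see what the MCCA objective optimizes, here is a minimal sketch of its classical linear form (a toy example with synthetic data, not part of CCA-Zoo): stack all views and solve the generalized eigenvalue problem C w = rho * D w, where C is the covariance of the concatenated views and D is its block-diagonal (within-view) part. The deep variant replaces the linear maps with the neural encoders trained above.

```python
import numpy as np
from scipy.linalg import eigh

rng = np.random.default_rng(0)
z = rng.standard_normal((200, 2))  # shared latent signal
# three noisy 5-dimensional views of the same latent signal
views = [z @ rng.standard_normal((2, 5)) + 0.1 * rng.standard_normal((200, 5))
         for _ in range(3)]

X = np.hstack([v - v.mean(axis=0) for v in views])
C = X.T @ X / len(X)                    # full covariance of concatenated views
D = np.zeros_like(C)                    # block-diagonal (within-view) part
for i in range(3):
    D[i * 5:(i + 1) * 5, i * 5:(i + 1) * 5] = C[i * 5:(i + 1) * 5, i * 5:(i + 1) * 5]
D += 1e-6 * np.eye(D.shape[0])          # small ridge for numerical stability

rho, w = eigh(C, D)                     # generalized eigenproblem, rho ascending
print(rho[-1])                          # top multiset correlation, at most n_views
```

With strongly shared signal, the top generalized eigenvalue approaches the number of views (here 3); for unrelated views it stays near 1.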
Deep GCCA (Generalized CCA)#
A method that finds projections of multiple representations such that the variance explained by the canonical components is maximized.
dcca_gcca = DCCA(
latent_dimensions=LATENT_DIMS,
encoders=[encoder_1, encoder_2],
objective=objectives._GCCALoss,
)
trainer_gcca = pl.Trainer(
max_epochs=EPOCHS,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
trainer_gcca.fit(dcca_gcca, train_loader, val_loader)
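The GCCA (MAXVAR) objective can likewise be sketched in its classical linear form on synthetic data (a toy illustration, not CCA-Zoo code): find a shared representation g that every view can reconstruct well. g is the top eigenvector of the sum of the views' projection matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal((100, 1))  # shared latent signal
# three noisy 4-dimensional views of the same latent signal
views = [z @ rng.standard_normal((1, 4)) + 0.1 * rng.standard_normal((100, 4))
         for _ in range(3)]

P_sum = np.zeros((100, 100))
for X in views:
    X = X - X.mean(axis=0)
    # projector onto the column space of X (pinv keeps it numerically safe)
    P_sum += X @ np.linalg.pinv(X)

vals, vecs = np.linalg.eigh(P_sum)
g = vecs[:, -1]  # shared canonical component (top eigenvector)
# g should be strongly correlated with the true latent signal z
corr = abs(np.corrcoef(g, z.ravel())[0, 1])
print(round(corr, 2))
```

Deep GCCA keeps this "one shared representation, many reconstructions" structure but lets each view's mapping be a neural encoder.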
Deep TCCA (Tensor CCA)#
A tensor-based extension of CCA that maximizes the higher-order correlation across all representations simultaneously by decomposing their joint covariance tensor.
dcca_tcca = DTCCA(latent_dimensions=LATENT_DIMS, encoders=[encoder_1, encoder_2])
trainer_tcca = pl.Trainer(
max_epochs=EPOCHS,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
trainer_tcca.fit(dcca_tcca, train_loader, val_loader)
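The statistic behind TCCA can be sketched as follows (a toy construction with random data, not CCA-Zoo code): with more than two views, the pairwise cross-covariance matrix is replaced by a higher-order covariance tensor, here T[i, j, k] = E[x1_i * x2_j * x3_k] for three views. TCCA obtains canonical weights by decomposing this tensor (e.g. via a rank-one/CP decomposition), and deep TCCA applies the same idea to encoder outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.standard_normal((1000, 3))
x2 = rng.standard_normal((1000, 4))
x3 = rng.standard_normal((1000, 2))

# center each view, then form the third-order covariance tensor
x1, x2, x3 = (v - v.mean(axis=0) for v in (x1, x2, x3))
T = np.einsum('ni,nj,nk->ijk', x1, x2, x3) / 1000
print(T.shape)  # (3, 4, 2)
```

Note that the tensor's size grows with the product of the view dimensions, which is why TCCA is typically applied to low-dimensional latent spaces such as the encoder outputs here.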