Working with Custom Datasets in CCA-Zoo#

This script provides a guide on how to leverage custom multiview datasets with CCA-Zoo. It walks through various methods, including the use of provided utilities and the creation of a bespoke dataset class.

Key Features:

- Transforming numpy arrays into CCA-Zoo compatible datasets.
- Validating custom datasets.
- Creating a custom dataset class from scratch.
- Training a Deep CCA model on custom datasets.

import torch
import numpy as np
import lightning.pytorch as pl

Converting Numpy Arrays into Datasets#

For those looking for a straightforward method, the NumpyDataset class from CCA-Zoo is a convenient way to convert numpy arrays into valid datasets. It accepts multiple numpy arrays, each representing a distinct view, and an optional list of labels. Subsequently, these datasets can be converted into dataloaders for use in CCA-Zoo models.

from cca_zoo.deep import DCCA, architectures
from cca_zoo.deep.data import NumpyDataset, check_dataset, get_dataloaders

X = np.random.normal(size=(100, 10))
Y = np.random.normal(size=(100, 10))
Z = np.random.normal(size=(100, 10))

numpy_dataset = NumpyDataset([X, Y, Z])
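NumpyDataset can also carry the optional labels mentioned above, and get_dataloaders (used again further down) turns a dataset directly into a dataloader. A minimal sketch follows; the labels array is synthetic and the labels keyword is an assumption that should be checked against the installed CCA-Zoo version:

labels = np.random.randint(0, 2, size=(100,))  # synthetic labels, purely illustrative

# labels are optional; views alone are enough for unsupervised CCA models
labelled_dataset = NumpyDataset([X, Y, Z], labels=labels)
numpy_loader = get_dataloaders(labelled_dataset, batch_size=16)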

Dataset Validation#

Before proceeding, it’s a good practice to validate the constructed dataset. The check_dataset function ensures that the dataset adheres to CCA-Zoo’s expected format.

check_dataset(numpy_dataset)

Creating a Custom Dataset Class#

For advanced users or specific requirements, one can create a custom dataset class. The new class should inherit from the native torch.utils.data.Dataset class and implement the __len__ and __getitem__ methods. __getitem__ must return a dictionary whose "views" entry is a tuple of torch tensors, one per view, optionally accompanied by an associated label.

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return {"views": (torch.rand(10), torch.rand(10))}


custom_dataset = CustomDataset()
check_dataset(custom_dataset)
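The same pattern extends to real data. Below is a minimal sketch, assuming the views are pre-aligned numpy arrays held in memory; the ArrayDataset name and its details are illustrative rather than part of CCA-Zoo. It wraps the arrays and converts one row per view into torch tensors on access.

class ArrayDataset(torch.utils.data.Dataset):
    """Hypothetical wrapper serving rows of pre-aligned numpy views."""

    def __init__(self, views):
        # every view must contain the same number of rows (samples)
        self.views = [np.asarray(view, dtype=np.float32) for view in views]

    def __len__(self):
        return self.views[0].shape[0]

    def __getitem__(self, index):
        # one row from each view, converted to a torch tensor
        return {"views": tuple(torch.from_numpy(view[index]) for view in self.views)}


array_dataset = ArrayDataset([X, Y, Z])
check_dataset(array_dataset)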

Convert Custom Dataset into DataLoader#

The get_dataloaders function can now be used to transform the custom dataset into dataloaders suitable for CCA-Zoo.

train_loader = get_dataloaders(custom_dataset, batch_size=2)
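It can be useful to inspect one batch to see the structure CCA-Zoo models receive; this uses only the custom dataset defined above:

batch = next(iter(train_loader))
# each entry in batch["views"] is a (batch_size, feature_size) tensor, one per view
print([view.shape for view in batch["views"]])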

Training with Deep CCA#

Once the dataloaders are set, it’s time to configure and train a Deep CCA model.

LATENT_DIMS = 1
EPOCHS = 10

encoder_1 = architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=10)
encoder_2 = architectures.Encoder(latent_dimensions=LATENT_DIMS, feature_size=10)

dcca = DCCA(latent_dimensions=LATENT_DIMS, encoders=[encoder_1, encoder_2])
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    enable_checkpointing=False,
    enable_model_summary=False,
    enable_progress_bar=False,
)
trainer.fit(dcca, train_loader)
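After fitting, the encoders inside the model are ordinary torch modules, so one quick sanity check is to push a batch through them directly; this is only a sketch, and CCA-Zoo provides its own utilities for extracting representations from trained deep models (consult its documentation for the supported API).

dcca.eval()
with torch.no_grad():
    batch = next(iter(train_loader))
    # project each view into the shared latent space with its trained encoder
    z1 = encoder_1(batch["views"][0])
    z2 = encoder_2(batch["views"][1])
print(z1.shape, z2.shape)  # both (batch_size, LATENT_DIMS)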
