Deep CCA with more customisation

This example demonstrates some more advanced DCCA functionality with pytorch-lightning, including a custom optimizer and learning-rate scheduler.

import numpy as np
import pytorch_lightning as pl
from torch import optim
from torch.utils.data import Subset

from cca_zoo.data import Split_MNIST_Dataset
from cca_zoo.deepmodels import DCCA, CCALightning, get_dataloaders, architectures

n_train = 500
n_val = 100
train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True)
val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val))
train_dataset = Subset(train_dataset, np.arange(n_train))
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)
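A plausible reading of what `get_dataloaders` does (its exact behaviour is an assumption here) is to wrap each dataset in a standard `torch.utils.data.DataLoader`. The sketch below reproduces the same train/validation split on a synthetic two-view dataset standing in for `Split_MNIST_Dataset`:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for Split_MNIST_Dataset: each item is a pair of
# 392-dimensional "views" (the two halves of a flattened 28x28 image)
data = TensorDataset(torch.randn(600, 392), torch.randn(600, 392))
train = Subset(data, range(500))
val = Subset(data, range(500, 600))

# One DataLoader per split; full-batch here to mirror the tiny example above
train_loader = DataLoader(train, batch_size=500, shuffle=True)
val_loader = DataLoader(val, batch_size=100)
views_a, views_b = next(iter(train_loader))
print(views_a.shape)  # torch.Size([500, 392])
```

With full-batch loading each epoch is a single optimizer step, which is why the progress bars in the output below show only two iterations (one train, one validation).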

# The number of latent dimensions across models
latent_dims = 2
# number of epochs for deep models
epochs = 10

# One encoder per view; feature_size=392 because each view is half of a
# flattened 28x28 MNIST image (784 / 2 pixels)
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
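The exact architecture behind `architectures.Encoder` is not shown here; as a rough, hypothetical sketch of the kind of module it could be, a small MLP mapping one 392-dimensional view into the latent space would look like this (`MLPEncoder` and its layer sizes are illustrative, not cca-zoo's actual implementation):

```python
import torch
from torch import nn

class MLPEncoder(nn.Module):
    """Illustrative view encoder: 392-d half-image -> latent_dims-d embedding."""

    def __init__(self, feature_size=392, latent_dims=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, latent_dims),
        )

    def forward(self, x):
        return self.net(x)

z = MLPEncoder()(torch.randn(8, 392))
print(z.shape)  # torch.Size([8, 2])
```

Any `nn.Module` with this input/output shape could in principle be substituted, which is the point of passing encoders in as a list.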

# Deep CCA
dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
optimizer = optim.Adam(dcca.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1)
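`CosineAnnealingLR` decays the learning rate along a half-cosine from the base rate down to `eta_min` over `T_max` scheduler steps. Its closed form can be checked directly (a minimal sketch of the formula, not PyTorch's internals):

```python
import math

def cosine_annealing_lr(base_lr, t, t_max, eta_min=0.0):
    """Learning rate at scheduler step t under cosine annealing:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2
    """
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))

# With T_max=1, as in the example above, the rate starts at 1e-3 and
# reaches eta_min after a single scheduler step
print(cosine_annealing_lr(1e-3, 0, 1))  # 0.001
print(cosine_annealing_lr(1e-3, 1, 1))  # 0.0
```

Note that `T_max=1` is an aggressive choice used here only for demonstration; in practice `T_max` is usually set to the number of epochs (or steps) you intend to train for.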
dcca = CCALightning(dcca, optimizer=optimizer, lr_scheduler=scheduler)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)

Out:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw

UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

Epoch 0: 100%|##########| 2/2 [00:00<00:00, 49.92it/s, loss=-0.22, v_num=0]
Epoch 1: 100%|##########| 2/2 [00:00<00:00, 45.92it/s, loss=-0.673, v_num=0]
Epoch 2: 100%|##########| 2/2 [00:00<00:00, 49.49it/s, loss=-0.824, v_num=0]
Epoch 3: 100%|##########| 2/2 [00:00<00:00, 47.55it/s, loss=-0.997, v_num=0]
Epoch 4: 100%|##########| 2/2 [00:00<00:00, 51.23it/s, loss=-1.1, v_num=0]
Epoch 5: 100%|##########| 2/2 [00:00<00:00, 50.75it/s, loss=-1.2, v_num=0]
Epoch 6: 100%|##########| 2/2 [00:00<00:00, 51.77it/s, loss=-1.27, v_num=0]
Epoch 7: 100%|##########| 2/2 [00:00<00:00, 51.39it/s, loss=-1.33, v_num=0]
Epoch 8: 100%|##########| 2/2 [00:00<00:00, 51.55it/s, loss=-1.38, v_num=0]
Epoch 9: 100%|##########| 2/2 [00:00<00:00, 52.24it/s, loss=-1.42, v_num=0]
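The reported loss is negative because DCCA maximises correlation: minimising the negative sum of canonical correlations between the two encoded views bounds the loss below by minus `latent_dims` (here, -2). A minimal numpy sketch of the linear-CCA version of this objective (an illustration of the idea, not cca-zoo's implementation, which computes it on the encoder outputs in torch):

```python
import numpy as np

def cca_loss(X, Y, k, eps=1e-9):
    """Negative sum of the top-k canonical correlations between views X and Y."""
    X = X - X.mean(0)
    Y = Y - Y.mean(0)
    n = X.shape[0]
    # Regularised covariance and cross-covariance matrices
    Sxx = X.T @ X / (n - 1) + eps * np.eye(X.shape[1])
    Syy = Y.T @ Y / (n - 1) + eps * np.eye(Y.shape[1])
    Sxy = X.T @ Y / (n - 1)

    def inv_sqrt(S):
        w, V = np.linalg.eigh(S)
        return V @ np.diag(1.0 / np.sqrt(w)) @ V.T

    # Canonical correlations are the singular values of the whitened
    # cross-covariance Sxx^{-1/2} Sxy Syy^{-1/2}
    T = inv_sqrt(Sxx) @ Sxy @ inv_sqrt(Syy)
    corrs = np.linalg.svd(T, compute_uv=False)
    return -np.sum(corrs[:k])

rng = np.random.default_rng(0)
Z = rng.standard_normal((500, 2))
print(cca_loss(Z, Z.copy(), k=2))  # close to -2: identical views correlate perfectly
```

The training log above shows the loss falling from -0.22 towards -1.42, i.e. the encoders are learning increasingly correlated two-dimensional representations of the two image halves.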

Total running time of the script: (0 minutes 4.226 seconds)

Gallery generated by Sphinx-Gallery