Deep CCA

This example demonstrates how to train Deep CCA models and their variants with PyTorch Lightning.

import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from torch.utils.data import Subset
from cca_zoo.data import Split_MNIST_Dataset
from cca_zoo.deepmodels import (
    DCCA,
    CCALightning,
    get_dataloaders,
    architectures,
    DCCA_NOI,
    DCCA_SDL,
    BarlowTwins,
)


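def plot_latent_label(model, dataloader, num_batches=100):
    """Scatter view 1 against view 2 for each latent dimension, coloured by label."""
    fig, ax = plt.subplots(ncols=model.latent_dims)
    for j in range(model.latent_dims):
        ax[j].set_title(f'Dimension {j}')
        ax[j].set_xlabel('View 1')
        ax[j].set_ylabel('View 2')
    for i, (data, label) in enumerate(dataloader):
        # Encode both views of the batch and move the latents to NumPy
        z = model(*data)
        zx, zy = z
        zx = zx.to('cpu').detach().numpy()
        zy = zy.to('cpu').detach().numpy()
        for j in range(model.latent_dims):
            sc = ax[j].scatter(zx[:, j], zy[:, j], c=label.numpy(), cmap='tab10')
        if i > num_batches:
            # Pass the scatter explicitly: plt.colorbar() with no arguments
            # may not find a mappable that was created through ax.scatter
            plt.colorbar(sc)
            break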


n_train = 500
n_val = 100
train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True)
val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val))
train_dataset = Subset(train_dataset, np.arange(n_train))
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset, batch_size=128)

# Number of latent dimensions shared by all models
latent_dims = 2
# Number of training epochs for each deep model
epochs = 20

encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
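
The encoders above take `feature_size=392`, which is half of a flattened 28 x 28 MNIST image (784 pixels). The sketch below illustrates one way such a two-view split could be produced. It assumes that view 1 is the first half of the flattened pixels and view 2 the second half; the helper `split_views` is hypothetical and is not part of cca_zoo, whose exact split may differ.

```python
import numpy as np


def split_views(images):
    """Flatten each image and split it into two equal-length views.

    Assumption: view 1 is the first half of the flattened pixels and
    view 2 is the second half. The library's own split may differ.
    """
    flat = images.reshape(len(images), -1)  # (n, 784) for 28x28 inputs
    half = flat.shape[1] // 2               # 392 for MNIST-sized images
    return flat[:, :half], flat[:, half:]


# Stand-in data with MNIST's shape; real pixel values are not needed here.
rng = np.random.default_rng(0)
fake_images = rng.random((10, 28, 28))
view_1, view_2 = split_views(fake_images)
print(view_1.shape, view_2.shape)
```

Each view then has 392 features, matching the `feature_size` passed to both encoders.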

Deep CCA

dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
dcca = CCALightning(dcca)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)
plot_latent_label(dcca.model, train_loader)
plt.suptitle('DCCA')
plt.show()
Figure: DCCA (panels: Dimension 0, Dimension 1)

Out:

Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.1/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 89.61it/s, loss=-0.779, v_num=4]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 92.23it/s, loss=-1.16, v_num=4]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 90.77it/s, loss=-1.32, v_num=4]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 91.56it/s, loss=-1.42, v_num=4]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 93.45it/s, loss=-1.49, v_num=4]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 91.18it/s, loss=-1.55, v_num=4]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 93.81it/s, loss=-1.65, v_num=4]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 93.04it/s, loss=-1.77, v_num=4]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 91.58it/s, loss=-1.81, v_num=4]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 91.20it/s, loss=-1.84, v_num=4]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 92.80it/s, loss=-1.87, v_num=4]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 93.26it/s, loss=-1.89, v_num=4]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 90.83it/s, loss=-1.9, v_num=4]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 89.49it/s, loss=-1.91, v_num=4]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 91.13it/s, loss=-1.92, v_num=4]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 93.60it/s, loss=-1.93, v_num=4]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 90.65it/s, loss=-1.93, v_num=4]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 94.52it/s, loss=-1.94, v_num=4]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 93.46it/s, loss=-1.94, v_num=4]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 94.89it/s, loss=-1.95, v_num=4]
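
The logged loss falls towards about -1.95. One reading, assuming the loss is the negative sum of canonical correlations between the two encoded views (the usual Deep CCA objective, not confirmed here against the library source), is that with `latent_dims = 2` it cannot go below -2. The NumPy sketch below shows that quantity on synthetic data; `inv_sqrt` and `canonical_correlations` are illustrative helpers, not cca_zoo functions.

```python
import numpy as np


def inv_sqrt(mat, eps=1e-8):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(1.0 / np.sqrt(np.maximum(vals, eps))) @ vecs.T


def canonical_correlations(zx, zy, reg=1e-6):
    """Singular values of the whitened cross-covariance of two views."""
    zx = zx - zx.mean(axis=0)
    zy = zy - zy.mean(axis=0)
    n = zx.shape[0]
    # Small ridge term keeps the covariance estimates invertible
    sxx = zx.T @ zx / (n - 1) + reg * np.eye(zx.shape[1])
    syy = zy.T @ zy / (n - 1) + reg * np.eye(zy.shape[1])
    sxy = zx.T @ zy / (n - 1)
    t = inv_sqrt(sxx) @ sxy @ inv_sqrt(syy)
    return np.linalg.svd(t, compute_uv=False)


# Two synthetic 2-D "latent" views that share a signal plus a little noise
rng = np.random.default_rng(0)
shared = rng.normal(size=(500, 2))
zx = shared + 0.1 * rng.normal(size=(500, 2))
zy = shared @ np.array([[2.0, 0.5], [0.0, 1.0]]) + 0.1 * rng.normal(size=(500, 2))

corrs = canonical_correlations(zx, zy)
loss = -corrs.sum()  # negative total correlation, bounded below by -2 here
print(corrs, loss)
```

Because each canonical correlation is at most 1, a loss near -2 in this setting would mean both latent dimensions are almost perfectly correlated across views.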

Deep CCA by Non-Linear Orthogonal Iterations

dcca_noi = DCCA_NOI(
    latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2]
)
dcca_noi = CCALightning(dcca_noi)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca_noi, train_loader, val_loader)
plot_latent_label(dcca_noi.model, train_loader)
plt.suptitle('DCCA by Non-Linear Orthogonal Iterations')
plt.show()
Figure: DCCA by Non-Linear Orthogonal Iterations (panels: Dimension 0, Dimension 1)

Out:

Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.1/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 99.16it/s, loss=20.9, v_num=5]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 99.87it/s, loss=14.5, v_num=5]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 101.11it/s, loss=11.8, v_num=5]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 101.48it/s, loss=9.49, v_num=5]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 101.25it/s, loss=7.83, v_num=5]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 97.85it/s, loss=6.72, v_num=5]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 99.82it/s, loss=4.5, v_num=5]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 101.07it/s, loss=2.83, v_num=5]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 100.68it/s, loss=1.71, v_num=5]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 100.43it/s, loss=1.14, v_num=5]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 100.21it/s, loss=0.939, v_num=5]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 101.46it/s, loss=0.872, v_num=5]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 101.28it/s, loss=0.798, v_num=5]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 98.00it/s, loss=0.702, v_num=5]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 99.84it/s, loss=0.646, v_num=5]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 101.50it/s, loss=0.607, v_num=5]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 102.47it/s, loss=0.574, v_num=5]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 101.13it/s, loss=0.536, v_num=5]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 102.77it/s, loss=0.497, v_num=5]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 102.57it/s, loss=0.458, v_num=5]
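
Unlike the DCCA run, the loss here is positive and shrinks towards zero, which would be consistent with a squared-error objective. As a rough illustration of the idea usually associated with non-linear orthogonal iterations, the sketch below keeps a running covariance estimate for each view, whitens the latents with it, and regresses each view onto the whitened output of the other. It is a NumPy sketch under those assumptions; `RunningWhitener`, `noi_loss`, and the momentum value `rho` are hypothetical names and settings, not the cca_zoo implementation.

```python
import numpy as np


def inv_sqrt(mat, eps=1e-8):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(1.0 / np.sqrt(np.maximum(vals, eps))) @ vecs.T


class RunningWhitener:
    """Exponential moving average of the latent covariance, used for whitening."""

    def __init__(self, dims, rho=0.9):
        self.rho = rho            # assumed momentum of the running estimate
        self.cov = np.eye(dims)

    def update(self, z):
        batch_cov = z.T @ z / z.shape[0]  # assumes roughly zero-mean latents
        self.cov = self.rho * self.cov + (1 - self.rho) * batch_cov

    def whiten(self, z):
        return z @ inv_sqrt(self.cov)


def noi_loss(zx, zy, whitener_x, whitener_y):
    """Each view regresses onto the whitened output of the other view."""
    whitener_x.update(zx)
    whitener_y.update(zy)
    target_x = whitener_x.whiten(zx)  # would be treated as a constant in training
    target_y = whitener_y.whiten(zy)
    return (np.mean(np.sum((zx - target_y) ** 2, axis=1))
            + np.mean(np.sum((zy - target_x) ** 2, axis=1)))


rng = np.random.default_rng(0)
mix = np.array([[2.0, 0.0], [1.0, 0.5]])

# Feed many mini-batches so the running covariance settles
whitener = RunningWhitener(dims=2)
for _ in range(200):
    whitener.update(rng.normal(size=(128, 2)) @ mix)

sample = rng.normal(size=(5000, 2)) @ mix
whitened_cov = np.cov(whitener.whiten(sample), rowvar=False)
print(np.round(whitened_cov, 2))  # expected to be close to the identity

zx = rng.normal(size=(128, 2)) @ mix
zy = zx + 0.1 * rng.normal(size=(128, 2))
loss = noi_loss(zx, zy, RunningWhitener(2), RunningWhitener(2))
print(loss)
```

The running estimate is what lets the whitening step work with small mini-batches, since no single batch has to provide a reliable covariance on its own.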

Deep CCA by Stochastic Decorrelation Loss
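
As background for this variant, the sketch below illustrates the decorrelation penalty as it is commonly described: batch-normalise the latents, accumulate a running covariance across mini-batches, and penalise the absolute off-diagonal entries. It is a NumPy sketch under those assumptions. The class name `StochasticDecorrelation` and the `momentum` value are hypothetical, and the library's `DCCA_SDL` may differ in its details.

```python
import numpy as np


class StochasticDecorrelation:
    """Running-covariance penalty on the off-diagonal latent correlations."""

    def __init__(self, dims, momentum=0.9):
        self.momentum = momentum          # assumed forgetting factor
        self.cov = np.zeros((dims, dims))
        self.norm = 0.0                   # corrects the bias of the running sum

    def penalty(self, z):
        # Batch-normalise so that the covariance becomes a correlation matrix
        z = (z - z.mean(axis=0)) / (z.std(axis=0) + 1e-8)
        batch_cov = z.T @ z / z.shape[0]
        self.cov = self.momentum * self.cov + batch_cov
        self.norm = self.momentum * self.norm + 1.0
        estimate = self.cov / self.norm
        off_diag = estimate - np.diag(np.diag(estimate))
        return np.abs(off_diag).sum()


rng = np.random.default_rng(0)

# Independent latent dimensions: the penalty should be small
independent = rng.normal(size=(256, 2))
penalty_independent = StochasticDecorrelation(dims=2).penalty(independent)

# Nearly duplicated latent dimensions: the penalty should be large
base = rng.normal(size=(256, 1))
correlated = np.hstack([base, base + 0.1 * rng.normal(size=(256, 1))])
penalty_correlated = StochasticDecorrelation(dims=2).penalty(correlated)

print(penalty_independent, penalty_correlated)
```

In a full model this penalty would be added, per view, to a term that pulls the two views' latents together; the weighting between the two terms is a further assumption not shown here.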

dcca_sdl = DCCA_SDL(
    latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2]
)
dcca_sdl = CCALightning(dcca_sdl)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca_sdl, train_loader, val_loader)
plot_latent_label(dcca_sdl.model, train_loader)
plt.suptitle('DCCA by Stochastic Decorrelation')
plt.show()
Figure: DCCA by Stochastic Decorrelation (panels: Dimension 0, Dimension 1)

Out:

Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.1/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 100.81it/s, loss=29.5, v_num=6]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 102.58it/s, loss=22, v_num=6]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 102.48it/s, loss=21.3, v_num=6]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 103.52it/s, loss=18.4, v_num=6]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 104.99it/s, loss=17.1, v_num=6]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 102.97it/s, loss=15.3, v_num=6]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 104.38it/s, loss=13.4, v_num=6]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 100.91it/s, loss=12.1, v_num=6]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 101.17it/s, loss=10.6, v_num=6]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 102.69it/s, loss=8.88, v_num=6]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 101.34it/s, loss=8.64, v_num=6]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 102.08it/s, loss=7.32, v_num=6]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 100.52it/s, loss=7.64, v_num=6]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 97.48it/s, loss=6.27, v_num=6]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 100.05it/s, loss=6.29, v_num=6]
Epoch 15:   0%|          | 0/4 [00:00<?, ?it/s, loss=6.29, v_num=6]
Epoch 15:  25%|##5       | 1/4 [00:00<00:00, 76.87it/s, loss=6.21, v_num=6]
Epoch 15:  50%|#####     | 2/4 [00:00<00:00, 91.59it/s, loss=6.38, v_num=6]
Epoch 15:  75%|#######5  | 3/4 [00:00<00:00, 112.32it/s, loss=6.34, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 100.15it/s, loss=6.34, v_num=6]


Epoch 15:   0%|          | 0/4 [00:00<?, ?it/s, loss=6.34, v_num=6]
Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=6.34, v_num=6]
Epoch 16:  25%|##5       | 1/4 [00:00<00:00, 76.16it/s, loss=6.12, v_num=6]
Epoch 16:  50%|#####     | 2/4 [00:00<00:00, 87.90it/s, loss=6.09, v_num=6]
Epoch 16:  75%|#######5  | 3/4 [00:00<00:00, 108.52it/s, loss=6.42, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 97.61it/s, loss=6.42, v_num=6]


Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=6.42, v_num=6]
Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=6.42, v_num=6]
Epoch 17:  25%|##5       | 1/4 [00:00<00:00, 78.21it/s, loss=6.93, v_num=6]
Epoch 17:  50%|#####     | 2/4 [00:00<00:00, 93.21it/s, loss=6.99, v_num=6]
Epoch 17:  75%|#######5  | 3/4 [00:00<00:00, 114.43it/s, loss=7.17, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 102.31it/s, loss=7.17, v_num=6]


Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=7.17, v_num=6]
Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=7.17, v_num=6]
Epoch 18:  25%|##5       | 1/4 [00:00<00:00, 79.58it/s, loss=7.33, v_num=6]
Epoch 18:  50%|#####     | 2/4 [00:00<00:00, 94.63it/s, loss=7.6, v_num=6]
Epoch 18:  75%|#######5  | 3/4 [00:00<00:00, 116.21it/s, loss=7.5, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 103.11it/s, loss=7.5, v_num=6]


Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=7.5, v_num=6]
Epoch 19:   0%|          | 0/4 [00:00<?, ?it/s, loss=7.5, v_num=6]
Epoch 19:  25%|##5       | 1/4 [00:00<00:00, 79.44it/s, loss=7.35, v_num=6]
Epoch 19:  50%|#####     | 2/4 [00:00<00:00, 94.77it/s, loss=6.96, v_num=6]
Epoch 19:  75%|#######5  | 3/4 [00:00<00:00, 116.34it/s, loss=7.16, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 104.10it/s, loss=7.16, v_num=6]


Epoch 19: 100%|##########| 4/4 [00:00<00:00, 71.77it/s, loss=7.16, v_num=6]

Deep CCA by Barlow Twins

barlowtwins = BarlowTwins(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
barlowtwins = CCALightning(barlowtwins)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
# Train the Barlow Twins model and plot its latent representations
trainer.fit(barlowtwins, train_loader, val_loader)
plot_latent_label(barlowtwins.model, train_loader)
plt.suptitle('DCCA by Barlow Twins')
plt.show()
Figure: DCCA by Barlow Twins (panel title: Dimension 0)

Out:

Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.1/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/4 [00:00<?, ?it/s]
Epoch 0:  25%|##5       | 1/4 [00:00<00:00, 72.47it/s, loss=-1.26, v_num=7]
Epoch 0:  50%|#####     | 2/4 [00:00<00:00, 85.05it/s, loss=-1.4, v_num=7]
Epoch 0:  75%|#######5  | 3/4 [00:00<00:00, 102.92it/s, loss=-1.46, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 0: 100%|##########| 4/4 [00:00<00:00, 94.51it/s, loss=-1.46, v_num=7]


Epoch 0:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.46, v_num=7]
Epoch 1:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.46, v_num=7]
Epoch 1:  25%|##5       | 1/4 [00:00<00:00, 73.91it/s, loss=-1.51, v_num=7]
Epoch 1:  50%|#####     | 2/4 [00:00<00:00, 85.12it/s, loss=-1.55, v_num=7]
Epoch 1:  75%|#######5  | 3/4 [00:00<00:00, 103.04it/s, loss=-1.58, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 94.31it/s, loss=-1.58, v_num=7]


Epoch 1:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.58, v_num=7]
Epoch 2:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.58, v_num=7]
Epoch 2:  25%|##5       | 1/4 [00:00<00:00, 72.82it/s, loss=-1.6, v_num=7]
Epoch 2:  50%|#####     | 2/4 [00:00<00:00, 85.67it/s, loss=-1.62, v_num=7]
Epoch 2:  75%|#######5  | 3/4 [00:00<00:00, 103.30it/s, loss=-1.64, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 94.62it/s, loss=-1.64, v_num=7]


Epoch 2:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.64, v_num=7]
Epoch 3:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.64, v_num=7]
Epoch 3:  25%|##5       | 1/4 [00:00<00:00, 72.57it/s, loss=-1.65, v_num=7]
Epoch 3:  50%|#####     | 2/4 [00:00<00:00, 85.33it/s, loss=-1.67, v_num=7]
Epoch 3:  75%|#######5  | 3/4 [00:00<00:00, 102.73it/s, loss=-1.68, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 91.77it/s, loss=-1.68, v_num=7]


Epoch 3:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.68, v_num=7]
Epoch 4:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.68, v_num=7]
Epoch 4:  25%|##5       | 1/4 [00:00<00:00, 72.51it/s, loss=-1.69, v_num=7]
Epoch 4:  50%|#####     | 2/4 [00:00<00:00, 85.55it/s, loss=-1.7, v_num=7]
Epoch 4:  75%|#######5  | 3/4 [00:00<00:00, 103.15it/s, loss=-1.72, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 94.30it/s, loss=-1.72, v_num=7]


Epoch 4:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.72, v_num=7]
Epoch 5:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.72, v_num=7]
Epoch 5:  25%|##5       | 1/4 [00:00<00:00, 74.28it/s, loss=-1.72, v_num=7]
Epoch 5:  50%|#####     | 2/4 [00:00<00:00, 86.07it/s, loss=-1.73, v_num=7]
Epoch 5:  75%|#######5  | 3/4 [00:00<00:00, 103.63it/s, loss=-1.74, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 95.26it/s, loss=-1.74, v_num=7]


Epoch 5:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.74, v_num=7]
Epoch 6:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.74, v_num=7]
Epoch 6:  25%|##5       | 1/4 [00:00<00:00, 72.21it/s, loss=-1.75, v_num=7]
Epoch 6:  50%|#####     | 2/4 [00:00<00:00, 84.38it/s, loss=-1.76, v_num=7]
Epoch 6:  75%|#######5  | 3/4 [00:00<00:00, 102.10it/s, loss=-1.79, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 93.69it/s, loss=-1.79, v_num=7]


Epoch 6:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.79, v_num=7]
Epoch 7:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.79, v_num=7]
Epoch 7:  25%|##5       | 1/4 [00:00<00:00, 74.57it/s, loss=-1.8, v_num=7]
Epoch 7:  50%|#####     | 2/4 [00:00<00:00, 85.49it/s, loss=-1.82, v_num=7]
Epoch 7:  75%|#######5  | 3/4 [00:00<00:00, 103.38it/s, loss=-1.83, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 95.35it/s, loss=-1.83, v_num=7]


Epoch 7:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.83, v_num=7]
Epoch 8:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.83, v_num=7]
Epoch 8:  25%|##5       | 1/4 [00:00<00:00, 72.37it/s, loss=-1.84, v_num=7]
Epoch 8:  50%|#####     | 2/4 [00:00<00:00, 85.42it/s, loss=-1.85, v_num=7]
Epoch 8:  75%|#######5  | 3/4 [00:00<00:00, 102.90it/s, loss=-1.86, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 94.28it/s, loss=-1.86, v_num=7]


Epoch 8:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.86, v_num=7]
Epoch 9:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.86, v_num=7]
Epoch 9:  25%|##5       | 1/4 [00:00<00:00, 66.46it/s, loss=-1.87, v_num=7]
Epoch 9:  50%|#####     | 2/4 [00:00<00:00, 79.98it/s, loss=-1.88, v_num=7]
Epoch 9:  75%|#######5  | 3/4 [00:00<00:00, 97.32it/s, loss=-1.88, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 89.93it/s, loss=-1.88, v_num=7]


Epoch 9:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.88, v_num=7]
Epoch 10:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.88, v_num=7]
Epoch 10:  25%|##5       | 1/4 [00:00<00:00, 72.63it/s, loss=-1.89, v_num=7]
Epoch 10:  50%|#####     | 2/4 [00:00<00:00, 85.19it/s, loss=-1.9, v_num=7]
Epoch 10:  75%|#######5  | 3/4 [00:00<00:00, 102.64it/s, loss=-1.9, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 93.32it/s, loss=-1.9, v_num=7]


Epoch 10:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.9, v_num=7]
Epoch 11:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.9, v_num=7]
Epoch 11:  25%|##5       | 1/4 [00:00<00:00, 71.63it/s, loss=-1.91, v_num=7]
Epoch 11:  50%|#####     | 2/4 [00:00<00:00, 83.93it/s, loss=-1.91, v_num=7]
Epoch 11:  75%|#######5  | 3/4 [00:00<00:00, 101.01it/s, loss=-1.92, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 92.76it/s, loss=-1.92, v_num=7]


Epoch 11:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.92, v_num=7]
Epoch 12:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.92, v_num=7]
Epoch 12:  25%|##5       | 1/4 [00:00<00:00, 71.91it/s, loss=-1.92, v_num=7]
Epoch 12:  50%|#####     | 2/4 [00:00<00:00, 82.85it/s, loss=-1.92, v_num=7]
Epoch 12:  75%|#######5  | 3/4 [00:00<00:00, 100.06it/s, loss=-1.93, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 92.56it/s, loss=-1.93, v_num=7]


Epoch 12:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.93, v_num=7]
Epoch 13:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.93, v_num=7]
Epoch 13:  25%|##5       | 1/4 [00:00<00:00, 71.51it/s, loss=-1.93, v_num=7]
Epoch 13:  50%|#####     | 2/4 [00:00<00:00, 82.18it/s, loss=-1.94, v_num=7]
Epoch 13:  75%|#######5  | 3/4 [00:00<00:00, 99.48it/s, loss=-1.94, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 91.28it/s, loss=-1.94, v_num=7]


Epoch 13:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.94, v_num=7]
Epoch 14:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.94, v_num=7]
Epoch 14:  25%|##5       | 1/4 [00:00<00:00, 71.52it/s, loss=-1.94, v_num=7]
Epoch 14:  50%|#####     | 2/4 [00:00<00:00, 84.06it/s, loss=-1.94, v_num=7]
Epoch 14:  75%|#######5  | 3/4 [00:00<00:00, 101.18it/s, loss=-1.95, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 93.23it/s, loss=-1.95, v_num=7]


Epoch 14:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.95, v_num=7]
Epoch 15:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.95, v_num=7]
Epoch 15:  25%|##5       | 1/4 [00:00<00:00, 71.71it/s, loss=-1.95, v_num=7]
Epoch 15:  50%|#####     | 2/4 [00:00<00:00, 83.25it/s, loss=-1.95, v_num=7]
Epoch 15:  75%|#######5  | 3/4 [00:00<00:00, 100.30it/s, loss=-1.95, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 92.14it/s, loss=-1.95, v_num=7]


Epoch 15:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.95, v_num=7]
Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.95, v_num=7]
Epoch 16:  25%|##5       | 1/4 [00:00<00:00, 71.87it/s, loss=-1.95, v_num=7]
Epoch 16:  50%|#####     | 2/4 [00:00<00:00, 83.27it/s, loss=-1.96, v_num=7]
Epoch 16:  75%|#######5  | 3/4 [00:00<00:00, 99.88it/s, loss=-1.96, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 91.41it/s, loss=-1.96, v_num=7]


Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.96, v_num=7]
Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.96, v_num=7]
Epoch 17:  25%|##5       | 1/4 [00:00<00:00, 71.65it/s, loss=-1.96, v_num=7]
Epoch 17:  50%|#####     | 2/4 [00:00<00:00, 82.33it/s, loss=-1.96, v_num=7]
Epoch 17:  75%|#######5  | 3/4 [00:00<00:00, 99.62it/s, loss=-1.96, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 92.22it/s, loss=-1.96, v_num=7]


Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.96, v_num=7]
Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.96, v_num=7]
Epoch 18:  25%|##5       | 1/4 [00:00<00:00, 72.63it/s, loss=-1.96, v_num=7]
Epoch 18:  50%|#####     | 2/4 [00:00<00:00, 85.37it/s, loss=-1.96, v_num=7]
Epoch 18:  75%|#######5  | 3/4 [00:00<00:00, 102.81it/s, loss=-1.96, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 94.63it/s, loss=-1.96, v_num=7]


Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.96, v_num=7]
Epoch 19:   0%|          | 0/4 [00:00<?, ?it/s, loss=-1.96, v_num=7]
Epoch 19:  25%|##5       | 1/4 [00:00<00:00, 73.28it/s, loss=-1.97, v_num=7]
Epoch 19:  50%|#####     | 2/4 [00:00<00:00, 86.25it/s, loss=-1.97, v_num=7]
Epoch 19:  75%|#######5  | 3/4 [00:00<00:00, 103.51it/s, loss=-1.97, v_num=7]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 95.07it/s, loss=-1.97, v_num=7]


Epoch 19: 100%|##########| 4/4 [00:00<00:00, 67.95it/s, loss=-1.97, v_num=7]
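
For readers unfamiliar with the objective behind this section, the following is a minimal NumPy sketch of a Barlow Twins-style loss: it standardises the latent representations of the two views, forms their cross-correlation matrix, and penalises diagonal entries that differ from 1 and off-diagonal entries that differ from 0. The function name `barlow_twins_loss`, the trade-off weight `lam`, and the stabiliser `eps` are assumptions made for illustration; this is not the cca_zoo implementation, which may differ in normalisation and weighting.

```python
import numpy as np


def barlow_twins_loss(z1, z2, lam=5e-3, eps=1e-9):
    """Illustrative Barlow Twins-style loss for two (n_samples, latent_dims) arrays.

    `lam` and `eps` are assumed values chosen for this sketch.
    """
    n = z1.shape[0]
    # Standardise each latent dimension across the batch.
    z1 = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + eps)
    z2 = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + eps)
    # Cross-correlation matrix between the two views.
    c = z1.T @ z2 / n
    # Invariance term: matching dimensions should correlate perfectly.
    on_diag = np.sum((1.0 - np.diag(c)) ** 2)
    # Redundancy-reduction term: different dimensions should be uncorrelated.
    off_diag = np.sum(c ** 2) - np.sum(np.diag(c) ** 2)
    return on_diag + lam * off_diag


rng = np.random.default_rng(0)
shared = rng.normal(size=(256, 2))
# Two nearly identical views versus two unrelated views.
aligned = barlow_twins_loss(shared, shared + 0.01 * rng.normal(size=(256, 2)))
unrelated = barlow_twins_loss(shared, rng.normal(size=(256, 2)))
print(f"aligned views: {aligned:.4f}, unrelated views: {unrelated:.4f}")
```

A loss of this form is non-negative and shrinks towards zero as the two views become perfectly correlated dimension by dimension, so the aligned pair should score much lower than the unrelated pair.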

Total running time of the script: ( 0 minutes 10.423 seconds)
