Deep CCA

This example shows how to train Deep CCA models and several of their variants on a two-view (split) version of MNIST.

import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from torch.utils.data import Subset
from cca_zoo.data import Split_MNIST_Dataset
from cca_zoo.deepmodels import (
    DCCA,
    CCALightning,
    get_dataloaders,
    architectures,
    DCCA_NOI,
    DCCA_SDL,
    BarlowTwins,
)


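def plot_latent_label(model, dataloader, num_batches=100):
    """Scatter the two views' latent codes against each other, one panel per dimension."""
    fig, ax = plt.subplots(ncols=model.latent_dims)
    for j in range(model.latent_dims):
        ax[j].set_title(f"Dimension {j}")
        ax[j].set_xlabel("View 1")
        ax[j].set_ylabel("View 2")
    sc = None
    for i, (data, label) in enumerate(dataloader):
        if i >= num_batches:  # plot at most num_batches batches
            break
        zx, zy = model(*data)
        zx = zx.detach().cpu().numpy()
        zy = zy.detach().cpu().numpy()
        for j in range(model.latent_dims):
            # fix the color range to the ten digit classes so colors match across batches
            sc = ax[j].scatter(
                zx[:, j], zy[:, j], c=label.numpy(), cmap="tab10", vmin=0, vmax=9
            )
    if sc is not None:
        # one colorbar for the digit labels, attached to the figure's axes
        fig.colorbar(sc, ax=ax)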


n_train = 500
n_val = 100
train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True)
val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val))
train_dataset = Subset(train_dataset, np.arange(n_train))
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset, batch_size=128)

# The number of latent dimensions across models
latent_dims = 2
# number of epochs for deep models
epochs = 20

encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
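Both encoders use feature_size=392 because 392 is half of the 784 pixels in a 28 × 28 MNIST digit: the split dataset gives each view one half of every image. The NumPy sketch below illustrates such a split; cutting into left and right halves is an assumption about what Split_MNIST_Dataset does, and split_image is a hypothetical helper, not part of cca_zoo.

```python
import numpy as np


def split_image(img):
    """Split a 28x28 image into two flattened 392-pixel views (left/right halves assumed)."""
    assert img.shape == (28, 28)
    left = img[:, :14].reshape(-1)   # columns 0..13 -> view 1
    right = img[:, 14:].reshape(-1)  # columns 14..27 -> view 2
    return left, right


img = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
v1, v2 = split_image(img)
print(v1.shape, v2.shape)  # (392,) (392,)
```

Whichever way the image is cut, each view carries partial information about the same digit, which is the shared signal the CCA models try to extract.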

Deep CCA
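DCCA trains the two encoders so that their outputs have maximal total canonical correlation. A minimal NumPy sketch of that objective, assuming the standard formulation (centered batch covariances plus a small ridge term r, whose value here is illustrative) rather than cca_zoo's exact code: the canonical correlations are the singular values of T = Σxx^(-1/2) Σxy Σyy^(-1/2), and the loss is their negated sum.

```python
import numpy as np


def inv_sqrt(m):
    """Inverse square root of a symmetric positive-definite matrix."""
    w, v = np.linalg.eigh(m)
    return v @ np.diag(w ** -0.5) @ v.T


def dcca_loss(zx, zy, r=1e-3):
    """Negative sum of canonical correlations between two (n, k) latent batches.

    r is a small ridge term that keeps the covariance matrices invertible.
    """
    n, k = zx.shape
    zx = zx - zx.mean(axis=0)
    zy = zy - zy.mean(axis=0)
    sxx = zx.T @ zx / (n - 1) + r * np.eye(k)
    syy = zy.T @ zy / (n - 1) + r * np.eye(k)
    sxy = zx.T @ zy / (n - 1)
    t = inv_sqrt(sxx) @ sxy @ inv_sqrt(syy)
    # the singular values of T are the canonical correlations, each in [0, 1]
    return -np.linalg.svd(t, compute_uv=False).sum()


rng = np.random.default_rng(0)
shared = rng.normal(size=(128, 2))
zx = shared + 0.1 * rng.normal(size=(128, 2))
# view 2 sees the same signal with its dimensions swapped; CCA is invariant to this
zy = shared[:, ::-1] + 0.1 * rng.normal(size=(128, 2))
print(dcca_loss(zx, zy))  # close to -2
```

Because each correlation is at most 1, with latent_dims = 2 this loss cannot go below -2, which is consistent with the DCCA training log below levelling off around -1.95.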

dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
dcca = CCALightning(dcca)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)
plot_latent_label(dcca.model, train_loader)
plt.suptitle("DCCA")
plt.show()
[Figure: DCCA latent space, panels "Dimension 0" and "Dimension 1", view 1 vs view 2 colored by digit label]

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/dev/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:413: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 98.70it/s, loss=-0.717, v_num=4]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 102.12it/s, loss=-1.11, v_num=4]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 102.22it/s, loss=-1.3, v_num=4]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 101.97it/s, loss=-1.41, v_num=4]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 104.37it/s, loss=-1.49, v_num=4]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 100.26it/s, loss=-1.54, v_num=4]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 100.13it/s, loss=-1.65, v_num=4]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 102.20it/s, loss=-1.76, v_num=4]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 102.34it/s, loss=-1.81, v_num=4]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 101.47it/s, loss=-1.84, v_num=4]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 103.53it/s, loss=-1.86, v_num=4]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 102.58it/s, loss=-1.88, v_num=4]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 102.56it/s, loss=-1.89, v_num=4]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 103.50it/s, loss=-1.9, v_num=4]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 99.12it/s, loss=-1.91, v_num=4]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 100.79it/s, loss=-1.92, v_num=4]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 100.52it/s, loss=-1.93, v_num=4]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 100.62it/s, loss=-1.94, v_num=4]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 101.49it/s, loss=-1.94, v_num=4]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 71.14it/s, loss=-1.95, v_num=4]
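The UserWarning above reports 3 training batches even though 500 training samples are used. With batch_size=128 that is the number of full batches, which is consistent with the final partial batch being dropped; that get_dataloaders drops it is an assumption, not something this example states. The arithmetic, as a plain Python sketch:

```python
n_train, n_val, batch_size = 500, 100, 128

full_train_batches = n_train // batch_size  # batches of exactly 128 samples
dropped = n_train % batch_size              # samples left over in a partial batch
val_batches = -(-n_val // batch_size)       # ceiling division: 100 samples fit in one batch

print(full_train_batches, dropped, val_batches)  # 3 116 1
print(full_train_batches + val_batches)          # 4
```

The total of 4 matches the "4/4" in each progress bar, which counts the 3 training steps plus the single "Validating ... 0/1" step. Passing a smaller log_every_n_steps to pl.Trainer, as the warning suggests, would enable per-step training logs.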

Deep CCA by Non-Linear Orthogonal Iterations
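Rather than backpropagating through an exact whitening of each batch, the nonlinear orthogonal iterations approach regresses each view's output onto the other view's output after whitening it with a running covariance estimate. The NumPy sketch below illustrates that idea under stated assumptions: NOISketch and its rho parameter are hypothetical names, not cca_zoo's API, and the outputs are assumed to be roughly zero-mean. Because the loss is a mean squared error, its scale is not comparable to the DCCA loss.

```python
import numpy as np


def inv_sqrt(m, eps=1e-6):
    """Inverse square root of a symmetric PSD matrix, clipping tiny eigenvalues."""
    w, v = np.linalg.eigh(m)
    return v @ np.diag(np.maximum(w, eps) ** -0.5) @ v.T


class NOISketch:
    """Whitened cross-view regression targets with running covariances (illustrative only)."""

    def __init__(self, rho=0.2):
        self.rho = rho  # weight given to the newest batch covariance (assumed value)
        self.covs = None

    def loss(self, zx, zy):
        n = zx.shape[0]
        # second-moment estimates; assumes zx and zy are roughly zero-mean
        batch_covs = [zx.T @ zx / n, zy.T @ zy / n]
        if self.covs is None:
            self.covs = batch_covs
        else:
            self.covs = [
                (1 - self.rho) * c + self.rho * b
                for c, b in zip(self.covs, batch_covs)
            ]
        # whitened targets; a deep-learning version would stop gradients through these
        tx = zx @ inv_sqrt(self.covs[0])
        ty = zy @ inv_sqrt(self.covs[1])
        # each view is regressed onto the other view's whitened output
        return np.mean((zx - ty) ** 2) + np.mean((zy - tx) ** 2)


rng = np.random.default_rng(0)
z = rng.normal(size=(10000, 2))
print(NOISketch().loss(z, z))  # near 0: identical, already-white views
```

For independent views the two squared-error terms each approach 2, so the sketch's loss approaches 4; it falls towards 0 as the views become perfectly correlated.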

dcca_noi = DCCA_NOI(
    latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2]
)
dcca_noi = CCALightning(dcca_noi)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca_noi, train_loader, val_loader)
plot_latent_label(dcca_noi.model, train_loader)
plt.suptitle("DCCA by Non-Linear Orthogonal Iterations")  # keep the per-panel titles
plt.show()
[Figure: DCCA by Non-Linear Orthogonal Iterations latent space, view 1 vs view 2 per dimension, colored by digit label]

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/dev/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:413: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 107.00it/s, loss=14.9, v_num=5]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 106.96it/s, loss=9.62, v_num=5]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 109.59it/s, loss=7.27, v_num=5]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 109.89it/s, loss=5.89, v_num=5]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 109.99it/s, loss=4.99, v_num=5]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 108.74it/s, loss=4.33, v_num=5]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 108.91it/s, loss=2.75, v_num=5]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 108.18it/s, loss=1.64, v_num=5]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 106.78it/s, loss=1.16, v_num=5]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 108.01it/s, loss=0.911, v_num=5]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 107.96it/s, loss=0.731, v_num=5]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 110.73it/s, loss=0.592, v_num=5]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 110.18it/s, loss=0.518, v_num=5]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 106.47it/s, loss=0.455, v_num=5]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 107.49it/s, loss=0.39, v_num=5]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 108.31it/s, loss=0.338, v_num=5]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 104.23it/s, loss=0.299, v_num=5]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 105.45it/s, loss=0.273, v_num=5]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 105.06it/s, loss=0.249, v_num=5]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 75.93it/s, loss=0.224, v_num=5]

Deep CCA by Stochastic Decorrelation Loss
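DCCA_SDL relaxes the exact whitening constraint into a soft penalty: a squared-error term pulls the two views' outputs together, while a penalty on the off-diagonal entries of running covariance estimates keeps the latent dimensions decorrelated. The NumPy sketch below is an illustration under stated assumptions, not cca_zoo's implementation: SDLSketch, lam and alpha are hypothetical, and per-batch standardization is used to keep it simple.

```python
import numpy as np


class SDLSketch:
    """Squared-error alignment plus a soft decorrelation penalty (illustrative only)."""

    def __init__(self, lam=0.1, alpha=0.9):
        self.lam = lam      # weight of the decorrelation penalty (assumed value)
        self.alpha = alpha  # decay of the running covariance estimate (assumed value)
        self.covs = None

    def loss(self, zx, zy):
        # standardize each latent dimension within the batch
        zx = (zx - zx.mean(axis=0)) / (zx.std(axis=0) + 1e-8)
        zy = (zy - zy.mean(axis=0)) / (zy.std(axis=0) + 1e-8)
        n = zx.shape[0]
        batch_covs = [zx.T @ zx / n, zy.T @ zy / n]
        if self.covs is None:
            self.covs = batch_covs
        else:
            self.covs = [
                self.alpha * c + (1 - self.alpha) * b
                for c, b in zip(self.covs, batch_covs)
            ]
        # L1 norm of the off-diagonal covariance entries: zero for uncorrelated dimensions
        penalty = sum(np.abs(c - np.diag(np.diag(c))).sum() for c in self.covs)
        return np.mean((zx - zy) ** 2) + self.lam * penalty


rng = np.random.default_rng(0)
a = rng.normal(size=(256, 1))
redundant = np.hstack([a, a])  # two perfectly correlated latent dimensions
print(SDLSketch().loss(redundant, redundant))  # approximately 0.4
```

Here the views agree perfectly, so the squared-error term is 0, but each view's two identical dimensions contribute two off-diagonal entries of 1, giving a penalty of 4 and a loss of 0.1 × 4. As with NOI, the loss scale is not comparable to the DCCA loss.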

dcca_sdl = DCCA_SDL(
    latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2]
)
dcca_sdl = CCALightning(dcca_sdl)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca_sdl, train_loader, val_loader)
plot_latent_label(dcca_sdl.model, train_loader)
plt.suptitle("DCCA by Stochastic Decorrelation")  # keep the per-panel titles
plt.show()
[Figure: DCCA by Stochastic Decorrelation latent space, view 1 vs view 2 per dimension, colored by digit label]

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/dev/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:413: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 105.66it/s, loss=25, v_num=6]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 108.38it/s, loss=27, v_num=6]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 107.77it/s, loss=23.7, v_num=6]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 105.54it/s, loss=21.7, v_num=6]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 107.74it/s, loss=19.9, v_num=6]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 106.98it/s, loss=18.6, v_num=6]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 107.90it/s, loss=18, v_num=6]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 106.56it/s, loss=14.6, v_num=6]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 106.15it/s, loss=12, v_num=6]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 105.74it/s, loss=10.9, v_num=6]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 108.75it/s, loss=9.45, v_num=6]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 109.14it/s, loss=9.29, v_num=6]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 108.64it/s, loss=8.89, v_num=6]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 107.28it/s, loss=8.08, v_num=6]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 108.81it/s, loss=7.95, v_num=6]
Epoch 15:  75%|#######5  | 3/4 [00:00<00:00, 120.70it/s, loss=8.28, v_num=6]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 109.14it/s, loss=8.28, v_num=6]


Epoch 15:   0%|          | 0/4 [00:00<?, ?it/s, loss=8.28, v_num=6]
Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=8.28, v_num=6]
Epoch 16:  25%|##5       | 1/4 [00:00<00:00, 81.45it/s, loss=9.01, v_num=6]
Epoch 16:  50%|#####     | 2/4 [00:00<00:00, 96.56it/s, loss=9.53, v_num=6]
Epoch 16:  75%|#######5  | 3/4 [00:00<00:00, 114.91it/s, loss=10, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 105.60it/s, loss=10, v_num=6]


Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=10, v_num=6]
Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=10, v_num=6]
Epoch 17:  25%|##5       | 1/4 [00:00<00:00, 83.06it/s, loss=10.4, v_num=6]
Epoch 17:  50%|#####     | 2/4 [00:00<00:00, 99.20it/s, loss=10.4, v_num=6]
Epoch 17:  75%|#######5  | 3/4 [00:00<00:00, 121.33it/s, loss=10.3, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 109.68it/s, loss=10.3, v_num=6]


Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.3, v_num=6]
Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.3, v_num=6]
Epoch 18:  25%|##5       | 1/4 [00:00<00:00, 82.40it/s, loss=10.3, v_num=6]
Epoch 18:  50%|#####     | 2/4 [00:00<00:00, 98.19it/s, loss=10.4, v_num=6]
Epoch 18:  75%|#######5  | 3/4 [00:00<00:00, 119.78it/s, loss=10.5, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 108.46it/s, loss=10.5, v_num=6]


Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.5, v_num=6]
Epoch 19:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.5, v_num=6]
Epoch 19:  25%|##5       | 1/4 [00:00<00:00, 80.30it/s, loss=10.8, v_num=6]
Epoch 19:  50%|#####     | 2/4 [00:00<00:00, 95.92it/s, loss=11, v_num=6]
Epoch 19:  75%|#######5  | 3/4 [00:00<00:00, 117.65it/s, loss=11.3, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 107.03it/s, loss=11.3, v_num=6]


Epoch 19: 100%|##########| 4/4 [00:00<00:00, 74.42it/s, loss=11.3, v_num=6]
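DCCA trains the two encoders so that their latent outputs are maximally correlated. The quantity being pushed up can be sketched without any deep learning library: centre both views, whiten them with their (regularised) covariances, and take the singular values of the whitened cross-covariance, which are the canonical correlations. This is a minimal NumPy sketch of the standard formulation, not cca_zoo's own loss; the function name `total_canonical_correlation` and its `k` and `eps` parameters are hypothetical, and the library's loss may differ in sign, regularisation and scaling, so its values are not comparable to the `loss=` numbers in the log above.

```python
import numpy as np


def total_canonical_correlation(zx, zy, k=None, eps=1e-4):
    """Sum of the top-k canonical correlations between two views (hypothetical helper).

    zx, zy: arrays of shape (n_samples, latent_dims), e.g. encoder outputs.
    eps: small ridge added to each covariance for numerical stability
         (an assumption of this sketch).
    """
    n = zx.shape[0]
    # Centre each view
    zx = zx - zx.mean(axis=0)
    zy = zy - zy.mean(axis=0)
    # Regularised covariance and cross-covariance estimates
    sxx = zx.T @ zx / (n - 1) + eps * np.eye(zx.shape[1])
    syy = zy.T @ zy / (n - 1) + eps * np.eye(zy.shape[1])
    sxy = zx.T @ zy / (n - 1)

    def inv_sqrt(s):
        # Inverse square root of a symmetric positive-definite matrix
        w, v = np.linalg.eigh(s)
        return v @ np.diag(w ** -0.5) @ v.T

    # Singular values of Sxx^{-1/2} Sxy Syy^{-1/2} are the canonical correlations,
    # returned by np.linalg.svd in descending order
    t = inv_sqrt(sxx) @ sxy @ inv_sqrt(syy)
    corrs = np.linalg.svd(t, compute_uv=False)
    if k is not None:
        corrs = corrs[:k]
    return corrs.sum()


# Toy check: two noisy 2-D views of the same latent signal (second view permuted)
rng = np.random.default_rng(0)
shared = rng.standard_normal((500, 2))
zx = shared + 0.1 * rng.standard_normal((500, 2))
zy = shared @ np.array([[0.0, 1.0], [1.0, 0.0]]) + 0.1 * rng.standard_normal((500, 2))
print(total_canonical_correlation(zx, zy))  # close to 2, the number of latent dimensions
```

Because the canonical correlations are invariant to invertible linear maps of each view, the permutation applied to the second view does not lower the score.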

Deep CCA by Barlow Twins

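# Fresh encoders, so Barlow Twins does not start from weights trained by the previous model
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
barlowtwins = BarlowTwins(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
barlowtwins = CCALightning(barlowtwins)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
# Train and plot the Barlow Twins model itself
trainer.fit(barlowtwins, train_loader, val_loader)
plot_latent_label(barlowtwins.model, train_loader)
# suptitle labels the whole figure; plt.title would overwrite the last panel's title
plt.suptitle("DCCA by Barlow Twins")
plt.show()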
Dimension 0, DCCA by Barlow Twins

Out:

Validation sanity check: 0it [00:00, ?it/s]

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/dev/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:413: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 103.05it/s, loss=-1.34, v_num=7]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 102.11it/s, loss=-1.51, v_num=7]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 103.00it/s, loss=-1.59, v_num=7]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 100.91it/s, loss=-1.64, v_num=7]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 102.79it/s, loss=-1.68, v_num=7]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 103.47it/s, loss=-1.71, v_num=7]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 106.08it/s, loss=-1.76, v_num=7]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 102.48it/s, loss=-1.83, v_num=7]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 104.75it/s, loss=-1.86, v_num=7]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 104.26it/s, loss=-1.89, v_num=7]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 104.58it/s, loss=-1.91, v_num=7]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 101.52it/s, loss=-1.92, v_num=7]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 103.87it/s, loss=-1.94, v_num=7]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 104.10it/s, loss=-1.95, v_num=7]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 105.31it/s, loss=-1.95, v_num=7]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 103.65it/s, loss=-1.96, v_num=7]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 102.67it/s, loss=-1.96, v_num=7]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 103.61it/s, loss=-1.97, v_num=7]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 106.16it/s, loss=-1.97, v_num=7]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 102.66it/s, loss=-1.98, v_num=7]
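The Barlow Twins idea can be illustrated independently of cca_zoo: standardise each latent dimension over the batch, form the cross-correlation matrix C between the two views, then push its diagonal towards 1 (invariance) and its off-diagonal entries towards 0 (redundancy reduction). Below is a minimal NumPy sketch of that objective as usually stated; `barlow_twins_loss`, `lambd` and `eps` are hypothetical names with illustrative defaults, and cca_zoo's `BarlowTwins` class may weight, scale or sign its loss differently, so the numbers will not match the log above.

```python
import numpy as np


def barlow_twins_loss(z1, z2, lambd=5e-3, eps=1e-8):
    """Barlow Twins-style objective for two batches of shape (n, d) (hypothetical helper)."""
    n = z1.shape[0]
    # Standardise each dimension over the batch (zero mean, unit variance)
    z1 = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + eps)
    z2 = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + eps)
    # d x d cross-correlation matrix between the two views
    c = z1.T @ z2 / n
    diag = np.diag(c)
    on_diag = np.sum((1.0 - diag) ** 2)             # invariance term
    off_diag = np.sum(c ** 2) - np.sum(diag ** 2)   # redundancy-reduction term
    return on_diag + lambd * off_diag


rng = np.random.default_rng(0)
z = rng.standard_normal((256, 2))
print(barlow_twins_loss(z, z))                              # identical views: close to 0
print(barlow_twins_loss(z, rng.standard_normal((256, 2))))  # unrelated views: roughly d = 2
```

Standardising over the batch makes the objective insensitive to per-dimension shifts and rescaling of either view, so only the correlation structure between the two encoders' outputs is penalised.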

Total running time of the script: ( 0 minutes 9.459 seconds)

Gallery generated by Sphinx-Gallery