Deep CCA

This example shows how to train Deep CCA (DCCA) and some of its variants on a two-view version of MNIST using cca_zoo and PyTorch Lightning.

import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from torch.utils.data import Subset
from cca_zoo.data import Split_MNIST_Dataset
from cca_zoo.deepmodels import (
    DCCA,
    CCALightning,
    get_dataloaders,
    architectures,
    DCCA_NOI,
    DCCA_SDL,
    BarlowTwins,
)


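def plot_latent_label(model, dataloader, num_batches=100):
    """Scatter the latent codes of view 1 against view 2, one panel per
    latent dimension, coloured by digit label."""
    # squeeze=False keeps `ax` indexable even when latent_dims == 1
    fig, ax = plt.subplots(ncols=model.latent_dims, squeeze=False)
    ax = ax[0]
    for j in range(model.latent_dims):
        ax[j].set_title(f"Dimension {j}")
        ax[j].set_xlabel("View 1")
        ax[j].set_ylabel("View 2")
    points = None
    for i, (data, label) in enumerate(dataloader):
        if i >= num_batches:
            break
        zx, zy = model(*data)
        zx = zx.to("cpu").detach().numpy()
        zy = zy.to("cpu").detach().numpy()
        for j in range(model.latent_dims):
            # Fixed vmin/vmax so each digit keeps the same colour in every batch
            points = ax[j].scatter(
                zx[:, j], zy[:, j], c=label.numpy(), cmap="tab10", vmin=0, vmax=9
            )
    # One colour bar (digit classes 0-9) shared by all panels
    if points is not None:
        fig.colorbar(points, ax=ax.tolist())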
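The scatter plots show the correlation between the two views qualitatively. To put a number on each panel, one option is the Pearson correlation between each matching pair of latent dimensions. The helper `per_dimension_correlation` below is hypothetical (it is not part of cca_zoo), and it is exercised on synthetic data rather than on MNIST latents.

```python
import numpy as np


def per_dimension_correlation(zx, zy):
    """Pearson correlation between matching latent dimensions of two views.

    zx, zy: arrays of shape (n_samples, latent_dims), such as the NumPy
    arrays built inside plot_latent_label.
    """
    return np.array(
        [np.corrcoef(zx[:, j], zy[:, j])[0, 1] for j in range(zx.shape[1])]
    )


# Synthetic check: dimension 0 is shared between the views, dimension 1 is noise.
rng = np.random.default_rng(0)
shared = rng.normal(size=500)
zx = np.column_stack([shared, rng.normal(size=500)])
zy = np.column_stack([shared + 0.1 * rng.normal(size=500), rng.normal(size=500)])
print(per_dimension_correlation(zx, zy))
```

A well-trained model should give values close to 1 in every dimension; a panel whose points form a shapeless cloud corresponds to a value near 0.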


# Train on the first 500 MNIST training images and validate on the next 100
n_train = 500
n_val = 100
train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True)
val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val))
train_dataset = Subset(train_dataset, np.arange(n_train))
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset, batch_size=128)
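Each sample therefore carries two views of the same digit. As a sketch of why both encoders use `feature_size=392`: a 28x28 image has 784 pixels, and splitting it into two halves leaves 392 pixels per view. The left/right split and the random stand-in images below are assumptions for illustration; check `Split_MNIST_Dataset` for the exact split it uses.

```python
import numpy as np

# Hypothetical stand-in for a batch of MNIST digits: 8 images of 28x28 pixels
rng = np.random.default_rng(0)
images = rng.random((8, 28, 28))

# Assumed split: the left half of each image is view 1, the right half is view 2
view_1 = images[:, :, :14].reshape(len(images), -1)
view_2 = images[:, :, 14:].reshape(len(images), -1)

print(view_1.shape, view_2.shape)  # (8, 392) (8, 392)
```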

# The number of latent dimensions across models
latent_dims = 2
# Number of training epochs for each deep model
epochs = 20

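# Each view is half of a 28x28 MNIST image, i.e. 784 / 2 = 392 input features.
# Note: the same two encoder instances are passed to every model below, so each
# later model starts from the weights left by the previous one.
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)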

Deep CCA
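DCCA trains two networks so that their outputs are maximally correlated. A common way to write the objective is the sum of canonical correlations between the two latent batches, i.e. the singular values of T = Σxx^(-1/2) Σxy Σyy^(-1/2), negated so it can be minimised. The NumPy sketch below computes that quantity; the function names and the ridge value `r` are assumptions, and cca_zoo's own objective may differ in details. With `latent_dims = 2` the best achievable value is -2, which is consistent with the logged DCCA loss settling near -1.96.

```python
import numpy as np


def inv_sqrt(mat):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T


def dcca_loss(zx, zy, r=1e-3):
    """Negative sum of canonical correlations between two latent batches.

    zx, zy: arrays of shape (n_samples, latent_dims).
    r: small ridge term keeping the covariance matrices invertible (assumed value).
    """
    n = zx.shape[0]
    zx = zx - zx.mean(axis=0)
    zy = zy - zy.mean(axis=0)
    sxx = zx.T @ zx / (n - 1) + r * np.eye(zx.shape[1])
    syy = zy.T @ zy / (n - 1) + r * np.eye(zy.shape[1])
    sxy = zx.T @ zy / (n - 1)
    t = inv_sqrt(sxx) @ sxy @ inv_sqrt(syy)
    # The singular values of T are the canonical correlations
    return -np.linalg.svd(t, compute_uv=False).sum()


# Two views that are noisy linear transforms of the same 2-D signal
rng = np.random.default_rng(0)
shared = rng.normal(size=(500, 2))
zx = shared + 0.05 * rng.normal(size=(500, 2))
zy = shared @ np.array([[2.0, 1.0], [0.0, 1.0]]) + 0.05 * rng.normal(size=(500, 2))
print(dcca_loss(zx, zy))  # close to -2: both canonical correlations are near 1
```

Because canonical correlations are invariant to invertible linear maps of either view, the mixing matrix applied to `zy` does not lower the score; only the added noise does.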

dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
dcca = CCALightning(dcca)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)
plot_latent_label(dcca.model, train_loader)
plt.suptitle("DCCA")
plt.show()
[Figure: DCCA latent space, panels "Dimension 0" and "Dimension 1"]

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

Epoch 19: 100%|##########| 4/4 [00:00<00:00, 63.77it/s, loss=-1.96, v_num=4]

Deep CCA by Non-Linear Orthogonal Iterations
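Non-linear orthogonal iterations (NOI) avoid computing full-batch covariance inverses at every step. As we understand the general scheme, running (moving-average) covariance estimates are kept for each view, and each view's latent code is regressed onto the whitened latent code of the other view, so the loss is a positive squared error that shrinks towards zero, matching the trend in the log (38.9 down to 0.219). The class `NOISketch` and its parameters `rho` and `eps` are hypothetical; this is a toy NumPy sketch of the idea, not cca_zoo's `DCCA_NOI` implementation.

```python
import numpy as np


def inv_sqrt(mat):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T


class NOISketch:
    """Toy NOI-style loss with running covariance estimates (hypothetical)."""

    def __init__(self, latent_dims, rho=0.2, eps=1e-3):
        self.rho = rho  # weight kept on the previous covariance estimate
        self.eps = eps  # ridge term for numerical stability
        self.cov_x = np.eye(latent_dims)
        self.cov_y = np.eye(latent_dims)

    def loss(self, zx, zy):
        n = zx.shape[0]
        zx = zx - zx.mean(axis=0)
        zy = zy - zy.mean(axis=0)
        # Exponential moving average of each view's covariance
        self.cov_x = self.rho * self.cov_x + (1 - self.rho) * zx.T @ zx / n
        self.cov_y = self.rho * self.cov_y + (1 - self.rho) * zy.T @ zy / n
        eye = np.eye(zx.shape[1])
        # Whitened targets; in a network these would be held fixed (no gradient)
        target_x = zy @ inv_sqrt(self.cov_y + self.eps * eye)
        target_y = zx @ inv_sqrt(self.cov_x + self.eps * eye)
        return np.mean((zx - target_x) ** 2) + np.mean((zy - target_y) ** 2)


rng = np.random.default_rng(0)
z = rng.normal(size=(500, 2))
z = z - z.mean(axis=0)
z = z @ inv_sqrt(z.T @ z / len(z))  # exactly whitened latent codes
same = NOISketch(latent_dims=2).loss(z, z)
different = NOISketch(latent_dims=2).loss(z, rng.normal(size=(500, 2)))
print(same, different)  # near 0 for identical views, much larger for unrelated ones
```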

dcca_noi = DCCA_NOI(
    latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2]
)
dcca_noi = CCALightning(dcca_noi)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca_noi, train_loader, val_loader)
plot_latent_label(dcca_noi.model, train_loader)
plt.suptitle("DCCA by Non-Linear Orthogonal Iterations")
plt.show()
[Figure: latent space of DCCA by Non-Linear Orthogonal Iterations, one scatter panel per latent dimension]

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

Epoch 19: 100%|##########| 4/4 [00:00<00:00, 67.23it/s, loss=0.219, v_num=5]

Deep CCA by Stochastic Decorrelation Loss
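The stochastic (soft) decorrelation idea replaces the whitening constraint with a penalty: the two views are pulled together by a squared-error term, while a running estimate of each view's correlation matrix is pushed towards the identity by penalising its off-diagonal entries. The NumPy sketch below illustrates that idea; the names `SoftDecorrelation` and `sdl_loss` and the values of `alpha` and `lam` are assumptions, not cca_zoo's `DCCA_SDL` code. Note also that the logged loss values of the three models are not directly comparable, since each optimises a different objective.

```python
import numpy as np


class SoftDecorrelation:
    """Running-average decorrelation penalty for one view (hypothetical sketch)."""

    def __init__(self, latent_dims, alpha=0.9):
        self.alpha = alpha  # decay of the running estimate
        self.cov = np.zeros((latent_dims, latent_dims))
        self.norm = 0.0

    def __call__(self, z):
        n = z.shape[0]
        # Standardise each latent dimension so z.T @ z / n is a correlation matrix
        z = (z - z.mean(axis=0)) / (z.std(axis=0) + 1e-8)
        self.cov = self.alpha * self.cov + z.T @ z / n
        self.norm = self.alpha * self.norm + 1.0
        c = self.cov / self.norm
        # Sum of absolute off-diagonal entries
        return np.abs(c - np.diag(np.diag(c))).sum()


def sdl_loss(zx, zy, sdl_x, sdl_y, lam=1e-3):
    """Squared error between the views plus weighted decorrelation penalties."""
    return np.mean((zx - zy) ** 2) + lam * (sdl_x(zx) + sdl_y(zy))


rng = np.random.default_rng(0)
a = rng.normal(size=(500, 1))
redundant = SoftDecorrelation(2)(np.hstack([a, a]))  # two identical dimensions
independent = rng.normal(size=(500, 2))
decorrelated = SoftDecorrelation(2)(independent)
print(redundant, decorrelated)  # about 2.0 versus close to 0
```

The penalty is what stops both encoders from collapsing every latent dimension onto the same direction, which the squared-error term alone would allow.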

dcca_sdl = DCCA_SDL(
    latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2]
)
dcca_sdl = CCALightning(dcca_sdl)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca_sdl, train_loader, val_loader)
plot_latent_label(dcca_sdl.model, train_loader)
plt.suptitle("DCCA by Stochastic Decorrelation")
plt.show()
[Figure: latent space of DCCA by Stochastic Decorrelation, one scatter panel per latent dimension]

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 99.87it/s, loss=11.8, v_num=6]


Epoch 15:   0%|          | 0/4 [00:00<?, ?it/s, loss=11.8, v_num=6]
Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=11.8, v_num=6]
Epoch 16:  25%|##5       | 1/4 [00:00<00:00, 72.31it/s, loss=11.2, v_num=6]
Epoch 16:  50%|#####     | 2/4 [00:00<00:00, 87.29it/s, loss=11, v_num=6]
Epoch 16:  75%|#######5  | 3/4 [00:00<00:00, 107.84it/s, loss=10.7, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 98.71it/s, loss=10.7, v_num=6]


Epoch 16:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.7, v_num=6]
Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.7, v_num=6]
Epoch 17:  25%|##5       | 1/4 [00:00<00:00, 72.25it/s, loss=10.6, v_num=6]
Epoch 17:  50%|#####     | 2/4 [00:00<00:00, 87.20it/s, loss=10.5, v_num=6]
Epoch 17:  75%|#######5  | 3/4 [00:00<00:00, 107.88it/s, loss=10.3, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 99.11it/s, loss=10.3, v_num=6]


Epoch 17:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.3, v_num=6]
Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=10.3, v_num=6]
Epoch 18:  25%|##5       | 1/4 [00:00<00:00, 72.40it/s, loss=10.1, v_num=6]
Epoch 18:  50%|#####     | 2/4 [00:00<00:00, 87.40it/s, loss=9.88, v_num=6]
Epoch 18:  75%|#######5  | 3/4 [00:00<00:00, 108.22it/s, loss=9.21, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 97.06it/s, loss=9.21, v_num=6]


Epoch 18:   0%|          | 0/4 [00:00<?, ?it/s, loss=9.21, v_num=6]
Epoch 19:   0%|          | 0/4 [00:00<?, ?it/s, loss=9.21, v_num=6]
Epoch 19:  25%|##5       | 1/4 [00:00<00:00, 71.92it/s, loss=8.66, v_num=6]
Epoch 19:  50%|#####     | 2/4 [00:00<00:00, 83.88it/s, loss=7.76, v_num=6]
Epoch 19:  75%|#######5  | 3/4 [00:00<00:00, 103.86it/s, loss=6.95, v_num=6]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 95.70it/s, loss=6.95, v_num=6]


Epoch 19: 100%|##########| 4/4 [00:00<00:00, 65.68it/s, loss=6.95, v_num=6]

Deep CCA by Barlow Twins
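Barlow Twins trains the two encoders so that the cross-correlation matrix between the views' standardised latent vectors is close to the identity: diagonal entries near 1 align the views, and off-diagonal entries near 0 decorrelate the latent dimensions. The NumPy sketch below writes that objective out following the original Barlow Twins formulation. The function name `barlow_twins_loss` and the default `lambd=5e-3` (the weight used in the Barlow Twins paper) are assumptions for illustration; cca-zoo's `BarlowTwins` class may normalise or weight the terms differently.

```python
import numpy as np


def barlow_twins_loss(z1, z2, lambd=5e-3, eps=1e-8):
    """Sketch of the Barlow Twins objective for two batches of embeddings.

    Hypothetical helper, not part of cca-zoo.
    z1, z2: arrays of shape (batch_size, latent_dims).
    lambd: weight on the off-diagonal (redundancy-reduction) term; assumed value.
    """
    n = z1.shape[0]
    # Standardise each latent dimension over the batch
    z1 = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + eps)
    z2 = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + eps)
    # Cross-correlation matrix between the views, shape (latent_dims, latent_dims)
    c = z1.T @ z2 / n
    # Invariance term: push diagonal entries towards 1
    on_diag = np.sum((1.0 - np.diag(c)) ** 2)
    # Redundancy-reduction term: push off-diagonal entries towards 0
    off_diag = np.sum(c ** 2) - np.sum(np.diag(c) ** 2)
    return on_diag + lambd * off_diag
```

With identical views the loss is close to zero, while for unrelated views the diagonal term approaches `latent_dims`, since each diagonal correlation is then near 0.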

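# Fresh encoders so this model is trained from scratch rather than continuing
# from the weights already updated by the models above
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
barlowtwins = BarlowTwins(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
barlowtwins = CCALightning(barlowtwins)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
# Fit the Barlow Twins model (not the earlier dcca object)
trainer.fit(barlowtwins, train_loader, val_loader)
plot_latent_label(barlowtwins.model, train_loader)
# suptitle labels the whole figure; plt.title would overwrite the last panel's title
plt.suptitle("DCCA by Barlow Twins")
plt.show()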
Dimension 0, DCCA by Barlow Twins

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 4/4 [00:00<00:00, 88.84it/s, loss=-1.38, v_num=7]
Epoch 1: 100%|##########| 4/4 [00:00<00:00, 88.71it/s, loss=-1.56, v_num=7]
Epoch 2: 100%|##########| 4/4 [00:00<00:00, 88.55it/s, loss=-1.64, v_num=7]
Epoch 3: 100%|##########| 4/4 [00:00<00:00, 87.68it/s, loss=-1.7, v_num=7]
Epoch 4: 100%|##########| 4/4 [00:00<00:00, 88.35it/s, loss=-1.73, v_num=7]
Epoch 5: 100%|##########| 4/4 [00:00<00:00, 89.79it/s, loss=-1.76, v_num=7]
Epoch 6: 100%|##########| 4/4 [00:00<00:00, 89.73it/s, loss=-1.81, v_num=7]
Epoch 7: 100%|##########| 4/4 [00:00<00:00, 90.59it/s, loss=-1.87, v_num=7]
Epoch 8: 100%|##########| 4/4 [00:00<00:00, 90.88it/s, loss=-1.9, v_num=7]
Epoch 9: 100%|##########| 4/4 [00:00<00:00, 89.76it/s, loss=-1.92, v_num=7]
Epoch 10: 100%|##########| 4/4 [00:00<00:00, 89.68it/s, loss=-1.93, v_num=7]
Epoch 11: 100%|##########| 4/4 [00:00<00:00, 87.89it/s, loss=-1.94, v_num=7]
Epoch 12: 100%|##########| 4/4 [00:00<00:00, 86.48it/s, loss=-1.95, v_num=7]
Epoch 13: 100%|##########| 4/4 [00:00<00:00, 90.04it/s, loss=-1.95, v_num=7]
Epoch 14: 100%|##########| 4/4 [00:00<00:00, 89.01it/s, loss=-1.96, v_num=7]
Epoch 15: 100%|##########| 4/4 [00:00<00:00, 91.16it/s, loss=-1.96, v_num=7]
Epoch 16: 100%|##########| 4/4 [00:00<00:00, 91.26it/s, loss=-1.97, v_num=7]
Epoch 17: 100%|##########| 4/4 [00:00<00:00, 90.98it/s, loss=-1.97, v_num=7]
Epoch 18: 100%|##########| 4/4 [00:00<00:00, 89.77it/s, loss=-1.97, v_num=7]
Epoch 19: 100%|##########| 4/4 [00:00<00:00, 64.62it/s, loss=-1.97, v_num=7]
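The scatter plots show the alignment between the two views only qualitatively. A minimal NumPy sketch for putting a number on it is to compute the Pearson correlation between matching latent dimensions of the two views. `latent_correlations` is a hypothetical helper, not part of cca-zoo, and `zx`/`zy` are assumed to be NumPy arrays like those built inside `plot_latent_label`.

```python
import numpy as np


def latent_correlations(zx, zy):
    """Pearson correlation between matching latent dimensions of two views.

    Hypothetical helper, not part of cca-zoo.
    zx, zy: arrays of shape (n_samples, latent_dims).
    Returns an array of shape (latent_dims,).
    """
    # Centre each latent dimension
    zx = zx - zx.mean(axis=0)
    zy = zy - zy.mean(axis=0)
    # Per-dimension covariance divided by the product of standard deviations
    num = np.sum(zx * zy, axis=0)
    den = np.sqrt(np.sum(zx ** 2, axis=0) * np.sum(zy ** 2, axis=0))
    return num / den
```

Values close to 1 in every dimension indicate that the two encoders map paired samples to nearly the same latent coordinates, which is what the tight diagonal bands in the scatter plots suggest.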

Total running time of the script: ( 0 minutes 10.707 seconds)

Gallery generated by Sphinx-Gallery