Deep CCA for more than 2 views

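This example demonstrates how to train Deep CCA variants that generalize beyond two views: Deep MCCA, Deep GCCA and Deep TCCA. To keep it short, each model is trained here on two views (the two 392-feature halves of each MNIST image), but every model takes a list of encoders, one per view.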

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import Subset
from cca_zoo.data import Split_MNIST_Dataset
from cca_zoo.deepmodels import (
    DCCA,
    CCALightning,
    get_dataloaders,
    architectures,
    objectives,
    DTCCA,
)

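# Use a small slice of the MNIST training split to keep the example fast
n_train = 500
n_val = 100
train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True)
# Validation: the n_val samples that follow the first n_train (built before train_dataset is overwritten)
val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val))
# Training: the first n_train samples
train_dataset = Subset(train_dataset, np.arange(n_train))
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)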
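Both subsets come from the MNIST training split (train=True): validation uses the 100 samples that follow the first 500. The index arithmetic passed to Subset can be checked on its own with a small NumPy sketch (split_indices is a hypothetical helper, not part of cca_zoo):

```python
import numpy as np


def split_indices(n_train, n_val):
    """Index arrays that carve one dataset into disjoint train/validation subsets."""
    # Training uses samples 0 .. n_train - 1 ...
    train_idx = np.arange(n_train)
    # ... and validation the next n_val samples, so the two subsets never overlap.
    val_idx = np.arange(n_train, n_train + n_val)
    return train_idx, val_idx


train_idx, val_idx = split_indices(500, 100)
print(train_idx[0], train_idx[-1], val_idx[0], val_idx[-1])  # 0 499 500 599
```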

# Number of latent dimensions learned by every model
latent_dims = 2
# Number of training epochs for each deep model
epochs = 10

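# One encoder per view, each mapping a 392-feature half image to latent_dims outputs.
# Note: these two encoder objects are reused by all three models below, so each model
# starts from the weights left by the previous one. Re-create the encoders before each
# model if you want an independent comparison.
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392)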
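Both encoders take feature_size=392 because each view is half of a flattened 28x28 MNIST image (784 / 2 = 392). A minimal NumPy sketch of such a split, assuming a left/right cut into 14-column halves (Split_MNIST_Dataset may cut the image differently; split_image_into_views is a hypothetical helper):

```python
import numpy as np


def split_image_into_views(images):
    """Split a batch of 28x28 images into left and right halves, each flattened to 392 features."""
    n = len(images)
    left = images[:, :, :14].reshape(n, -1)   # columns 0..13  -> 28 * 14 = 392 features
    right = images[:, :, 14:].reshape(n, -1)  # columns 14..27 -> 28 * 14 = 392 features
    return left, right


rng = np.random.default_rng(0)
fake_images = rng.random((5, 28, 28))  # random stand-in for five MNIST digits
view_1, view_2 = split_image_into_views(fake_images)
print(view_1.shape, view_2.shape)  # (5, 392) (5, 392)
```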

Deep MCCA

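# Deep MCCA: the DCCA model trained with the multiset CCA (MCCA) objective
dcca = DCCA(
    latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.MCCA
)
# Wrap the model so that a pytorch_lightning Trainer can fit it
dcca = CCALightning(dcca)
# Checkpointing is disabled, so no model checkpoints are written to disk
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)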

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 2/2 [00:00<00:00, 52.66it/s, loss=-0.339, v_num=1]
Epoch 1: 100%|##########| 2/2 [00:00<00:00, 52.90it/s, loss=-0.774, v_num=1]
Epoch 2: 100%|##########| 2/2 [00:00<00:00, 52.86it/s, loss=-1.01, v_num=1]
Epoch 3: 100%|##########| 2/2 [00:00<00:00, 53.00it/s, loss=-1.16, v_num=1]
Epoch 4: 100%|##########| 2/2 [00:00<00:00, 47.56it/s, loss=-1.27, v_num=1]
Epoch 5: 100%|##########| 2/2 [00:00<00:00, 52.49it/s, loss=-1.35, v_num=1]
Epoch 6: 100%|##########| 2/2 [00:00<00:00, 52.03it/s, loss=-1.41, v_num=1]
Epoch 7: 100%|##########| 2/2 [00:00<00:00, 51.95it/s, loss=-1.46, v_num=1]
Epoch 8: 100%|##########| 2/2 [00:00<00:00, 53.55it/s, loss=-1.5, v_num=1]
Epoch 9: 100%|##########| 2/2 [00:00<00:00, 33.66it/s, loss=-1.53, v_num=1]
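objectives.MCCA trains the encoders so that their outputs are correlated across all views at once. A common textbook way to write multiset CCA is the generalized eigenvalue problem C w = λ D w, where C is the covariance of all views stacked side by side and D keeps only its within-view diagonal blocks. The NumPy sketch below scores fixed latent features this way. It illustrates that formulation only and is not cca_zoo's implementation, which may add regularization or other details; mcca_score and inv_sqrt are hypothetical names. For two views the shifted eigenvalues λ - 1 are the canonical correlations, so with k=2 a perfect score is 2; with m identical views each dimension contributes m - 1.

```python
import numpy as np


def inv_sqrt(m, eps=1e-9):
    """Inverse square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(m)
    # Clamp tiny eigenvalues so the inverse stays finite.
    vals = np.maximum(vals, eps)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T


def mcca_score(*views, k=2, eps=1e-9):
    """Sum of the top-k MCCA eigenvalues, each shifted down by one.

    Solves C w = lambda D w, where C is the covariance of all views stacked
    side by side and D keeps only the within-view (diagonal) blocks of C.
    Works for any number of views.
    """
    # Centre every view.
    views = [v - v.mean(axis=0) for v in views]
    z = np.concatenate(views, axis=1)
    c = z.T @ z
    # Block-diagonal D built from each view's own covariance.
    d = np.zeros_like(c)
    start = 0
    for v in views:
        p = v.shape[1]
        d[start:start + p, start:start + p] = v.T @ v
        start += p
    # Symmetric form of the generalized problem: D^{-1/2} C D^{-1/2}.
    r = inv_sqrt(d, eps)
    eigvals = np.linalg.eigvalsh(r @ c @ r)
    top = np.sort(eigvals)[::-1][:k]
    return float(np.sum(top - 1.0))


rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 2))
print(round(mcca_score(x, x, k=2), 6))     # two identical views -> 2.0
print(round(mcca_score(x, x, x, k=2), 6))  # three identical views -> 4.0
```

The second call shows the point of the multiview objectives: the same function accepts a third view without any change.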

Deep GCCA

dcca = DCCA(
    latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.GCCA
)
dcca = CCALightning(dcca)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 2/2 [00:00<00:00, 36.92it/s, loss=-1.85, v_num=2]
Epoch 1: 100%|##########| 2/2 [00:00<00:00, 36.12it/s, loss=-1.86, v_num=2]
Epoch 2: 100%|##########| 2/2 [00:00<00:00, 36.80it/s, loss=-1.88, v_num=2]
Epoch 3: 100%|##########| 2/2 [00:00<00:00, 37.20it/s, loss=-1.89, v_num=2]
Epoch 4: 100%|##########| 2/2 [00:00<00:00, 37.06it/s, loss=-1.9, v_num=2]
Epoch 5: 100%|##########| 2/2 [00:00<00:00, 37.41it/s, loss=-1.91, v_num=2]
Epoch 6: 100%|##########| 2/2 [00:00<00:00, 37.19it/s, loss=-1.91, v_num=2]
Epoch 7: 100%|##########| 2/2 [00:00<00:00, 37.67it/s, loss=-1.92, v_num=2]
Epoch 8: 100%|##########| 2/2 [00:00<00:00, 36.81it/s, loss=-1.92, v_num=2]
Epoch 9: 100%|##########| 2/2 [00:00<00:00, 26.61it/s, loss=-1.93, v_num=2]
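Deep GCCA swaps in objectives.GCCA. In the classical MAXVAR form of generalized CCA, the goal is a single shared representation G with orthonormal columns that every view can approximate linearly; G is given by the top eigenvectors of the sum of the views' projection matrices. A NumPy sketch of that linear problem on fixed features, assuming this textbook form (cca_zoo's objective may differ in its details; gcca_shared and its reg parameter are hypothetical):

```python
import numpy as np


def gcca_shared(*views, k=2, reg=1e-9):
    """MAXVAR-style GCCA on fixed features.

    Returns the top-k eigenvalues and eigenvectors of Q = sum_i P_i, where P_i
    projects onto the column space of centred view i. The eigenvectors form the
    shared representation G (n_samples x k) that all views try to reconstruct.
    """
    views = [v - v.mean(axis=0) for v in views]
    n = views[0].shape[0]
    q = np.zeros((n, n))
    for v in views:
        # Small ridge term keeps the inverse stable if a view is rank deficient.
        cov = v.T @ v + reg * np.eye(v.shape[1])
        q += v @ np.linalg.solve(cov, v.T)
    eigvals, eigvecs = np.linalg.eigh(q)
    order = np.argsort(eigvals)[::-1][:k]
    return eigvals[order], eigvecs[:, order]


rng = np.random.default_rng(0)
x = rng.standard_normal((200, 2))
vals, g = gcca_shared(x, x, x, k=2)
print(np.round(vals, 6))  # three identical views -> [3. 3.]
```

Because Q is n x n, this direct form only suits small batches. For two views its leading eigenvalues are 1 + ρ_i, where ρ_i are the canonical correlations, so it ranks directions exactly as two-view CCA does.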

Deep TCCA

dcca = DTCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
dcca = CCALightning(dcca)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_loader)

Out:

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/tensorly/backend/core.py:1106: UserWarning: In partial_svd: converting to NumPy. Check SVD_FUNS for available alternatives if you want to avoid this.
  warnings.warn('In partial_svd: converting to NumPy.'

/home/docs/checkouts/readthedocs.org/user_builds/cca-zoo/envs/v1.10.4/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:408: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Epoch 0: 100%|##########| 2/2 [00:00<00:00, 25.24it/s, loss=8.06e-08, v_num=3]
Epoch 1: 100%|##########| 2/2 [00:00<00:00, 24.89it/s, loss=8.24e-08, v_num=3]
Epoch 2: 100%|##########| 2/2 [00:00<00:00, 13.40it/s, loss=8.48e-08, v_num=3]
Epoch 3: 100%|##########| 2/2 [00:00<00:00, 24.89it/s, loss=8.46e-08, v_num=3]
Epoch 4: 100%|##########| 2/2 [00:00<00:00, 25.14it/s, loss=7.71e-08, v_num=3]
Epoch 5: 100%|##########| 2/2 [00:00<00:00, 13.49it/s, loss=7.69e-08, v_num=3]
Epoch 6: 100%|##########| 2/2 [00:00<00:00, 25.69it/s, loss=7.96e-08, v_num=3]
Epoch 7: 100%|##########| 2/2 [00:00<00:00, 25.29it/s, loss=7.78e-08, v_num=3]
Epoch 8: 100%|##########| 2/2 [00:00<00:00, 25.58it/s, loss=7.15e-08, v_num=3]
Epoch 9: 100%|##########| 2/2 [00:00<00:00, 17.36it/s, loss=7.18e-08, v_num=3]
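Unlike the two models above, DTCCA takes no objective argument. Tensor CCA whitens each view and builds one order-m cross-covariance tensor over all m views, which is then factorized; the tensorly partial_svd warnings in the log show that DTCCA calls into tensorly during training. The NumPy sketch below only constructs that tensor for fixed features, assuming this standard formulation. It does not reproduce cca_zoo's decomposition step, and whiten and covariance_tensor are hypothetical helpers.

```python
import string

import numpy as np


def whiten(v, eps=1e-9):
    """Centre a view and rescale it so its sample covariance is the identity."""
    v = v - v.mean(axis=0)
    vals, vecs = np.linalg.eigh(v.T @ v / len(v))
    return v @ vecs @ np.diag(np.maximum(vals, eps) ** -0.5) @ vecs.T


def covariance_tensor(*views):
    """Order-m tensor T[a, b, ...] = mean over samples of z1[:, a] * z2[:, b] * ..."""
    # Letters 'a'..'y' index the views' features; 'z' is reserved for samples.
    assert len(views) < 26, "one einsum letter per view"
    zs = [whiten(v) for v in views]
    letters = string.ascii_lowercase[: len(zs)]
    # e.g. three views -> "za,zb,zc->abc": sum over samples, keep one axis per view
    subscripts = ",".join("z" + letter for letter in letters) + "->" + letters
    return np.einsum(subscripts, *zs) / len(zs[0])


rng = np.random.default_rng(0)
x = rng.standard_normal((500, 2))
t3 = covariance_tensor(x, x, x)
print(t3.shape)  # (2, 2, 2)
t2 = covariance_tensor(x, x)
print(np.round(np.linalg.svd(t2, compute_uv=False), 6))  # identical views -> [1. 1.]
```

With two views the tensor is just the whitened cross-covariance matrix, whose singular values are the canonical correlations; with more views it gains one axis per view, which is why a tensor decomposition rather than an SVD is needed.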

Total running time of the script: ( 0 minutes 6.284 seconds)

Gallery generated by Sphinx-Gallery