"""Deep canonical correlation analysis (DCCA) model — module ``cca_zoo.deep._discriminative._dcca``."""

import torch

from cca_zoo.deep import objectives
from cca_zoo.deep._base import BaseDeep
from cca_zoo.deep.metrics import MCCA


class DCCA(BaseDeep):
    """
    A class used to fit a DCCA model.

    Each view is passed through its own encoder network; the training loss is
    whatever ``self.objective.loss`` returns for the list of encoded views.

    Parameters
    ----------
    latent_dimensions : int
        Dimensionality of the latent space; passed to both ``BaseDeep`` and
        the objective factory.
    objective : callable, optional
        Factory called as ``objective(latent_dimensions, r=r, eps=eps)``; the
        resulting object must provide a ``loss(z)`` method.
        Defaults to ``objectives.MCCA``.
    encoders : iterable of torch.nn.Module
        One encoder per view, in view order. Required.
    r : float, optional
        Forwarded to the objective factory as ``r``. Defaults to 0.
    eps : float, optional
        Forwarded to the objective factory as ``eps``. Defaults to 1e-5.
    **kwargs
        Forwarded to ``BaseDeep.__init__``.

    Raises
    ------
    ValueError
        If ``encoders`` is None.

    References
    ----------
    Andrew, Galen, et al. "Deep canonical correlation analysis." International conference on machine learning. PMLR, 2013.
    """

    def __init__(
        self,
        latent_dimensions: int,
        objective=objectives.MCCA,
        encoders=None,
        r: float = 0,
        eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__(latent_dimensions=latent_dimensions, **kwargs)
        # Encoders are mandatory. Whether their count matches the number of
        # views can only be checked once data arrives (see ``forward``).
        if encoders is None:
            raise ValueError(
                "Encoders must be a list of torch.nn.Module with length equal to the number of views."
            )
        # ModuleList registers each encoder as a submodule so its parameters
        # are tracked by the model.
        self.encoders = torch.nn.ModuleList(encoders)
        self.objective = objective(latent_dimensions, r=r, eps=eps)
        # Metric used by ``score`` to evaluate the learned representations.
        self.correlation = MCCA()

    def forward(self, views, **kwargs):
        """
        Encode each view with its corresponding encoder.

        Parameters
        ----------
        views : sequence
            One input batch per view, in the same order as ``self.encoders``
            (typically ``torch.Tensor`` objects accepted by each encoder).
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        list
            ``[encoders[i](views[i]) for each i]``.

        Raises
        ------
        ValueError
            If the number of views differs from the number of encoders.
        """
        n_views = len(views)
        if n_views != len(self.encoders):
            # zip() would otherwise silently drop the surplus views/encoders.
            raise ValueError(
                f"Expected {len(self.encoders)} views (one per encoder), got {n_views}."
            )
        # Remember how many views the model was first called with.
        if not hasattr(self, "n_views_"):
            self.n_views_ = n_views
        return [encoder(view) for encoder, view in zip(self.encoders, views)]

    def loss(self, views, **kwargs):
        """
        Compute the training loss for a batch of views.

        Parameters
        ----------
        views : sequence
            One input batch per view; see ``forward``.
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        dict
            ``{"objective": value}``, where ``value`` is
            ``self.objective.loss`` evaluated on the encoded views.
        """
        z = self(views)
        return {"objective": self.objective.loss(z)}

    def pairwise_correlations(self, loader: torch.utils.data.DataLoader):
        """
        Return ``BaseDeep.pairwise_correlations(loader)`` unchanged.

        This override only delegates to the parent implementation.
        """
        return super().pairwise_correlations(loader)

    def score(self, loader: torch.utils.data.DataLoader, **kwargs):
        """
        Score the model's representations of ``loader`` with the MCCA metric.

        Parameters
        ----------
        loader : torch.utils.data.DataLoader
            Data to encode via ``self.get_representations``.
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        numpy.ndarray
            The value returned by ``self.correlation`` converted to NumPy.
        """
        z = self.get_representations(loader)
        corr = self.correlation(z)
        # Tensor.numpy() raises for tensors on an accelerator or that require
        # grad; detach and move to CPU first so scoring works on any device.
        return corr.detach().cpu().numpy()