Source code for cca_zoo.deep._discriminative._dcca
import torch
from cca_zoo.deep._base import BaseDeep
from cca_zoo.deep.objectives import _CCALoss
from cca_zoo.linear._mcca import MCCA
[docs]
class DCCA(BaseDeep):
"""
A class used to fit a DCCA model.
References
----------
Andrew, Galen, et al. "Deep canonical correlation analysis." International conference on machine learning. PMLR, 2013.
"""
objective = _CCALoss()
def __init__(
self,
latent_dimensions: int,
encoders=None,
**kwargs,
):
super().__init__(latent_dimensions=latent_dimensions, **kwargs)
# Check if encoders are provided and have the same length as the number of representations
if encoders is None:
raise ValueError(
"Encoders must be a list of torch.nn.Module with length equal to the number of representations."
)
self.encoders = torch.nn.ModuleList(encoders)
[docs]
def forward(self, views, **kwargs):
if not hasattr(self, "n_views_"):
self.n_views_ = len(views)
# Use list comprehension to encode each view
z = [encoder(view) for encoder, view in zip(self.encoders, views)]
return z
[docs]
def loss(self, batch, **kwargs):
representations = self(batch["views"])
return {"objective": self.objective(representations)}
[docs]
def pairwise_correlations(self, loader: torch.utils.data.DataLoader):
# Call the parent class method
return super().pairwise_correlations(loader)
def correlation_captured(self, z):
# Remove mean from each view
z = [zi - zi.mean(0) for zi in z]
return MCCA(latent_dimensions=self.latent_dimensions).fit(z).score(z).sum()
[docs]
def score(self, loader: torch.utils.data.DataLoader, **kwargs):
z = self.transform(loader)
corr = self.correlation_captured(z)
return corr