import torch
from cca_zoo.deepmodels import objectives
from cca_zoo.deepmodels.architectures import Encoder
from cca_zoo.models import MCCA
from ._dcca_base import _DCCA_base
class DCCA(_DCCA_base):
    """
    A class used to fit a DCCA model.

    Each view is passed through its own encoder network and the encoders are
    trained by minimising the configured CCA objective on their outputs.

    :Citation:

    Andrew, Galen, et al. "Deep canonical correlation analysis." International conference on machine learning. PMLR, 2013.
    """

    def __init__(
        self,
        latent_dims: int,
        objective=objectives.MCCA,
        encoders=None,
        r: float = 0,
        eps: float = 1e-3,
    ):
        """
        Constructor class for DCCA

        :param latent_dims: # latent dimensions
        :param objective: CCA objective class, ``objectives.MCCA`` by default. It is instantiated here as ``objective(latent_dims, r=r, eps=eps)``
        :param encoders: list of encoder networks, one per view
        :param r: regularisation parameter forwarded to the objective (like ridge CCA). Needs to be VERY SMALL. If you get errors make this smaller
        :param eps: epsilon forwarded to the objective. Needs to be VERY SMALL. If you get errors make this smaller
        """
        super().__init__(latent_dims=latent_dims)
        if encoders is None:
            # NOTE(review): this default hands the Encoder *class* objects (not
            # instances) to ModuleList, which only accepts Module instances.
            # The Encoder constructor signature is not visible from this file,
            # so the default is left as-is -- confirm the intended construction.
            encoders = [Encoder, Encoder]
        self.encoders = torch.nn.ModuleList(encoders)
        self.objective = objective(latent_dims, r=r, eps=eps)

    def forward(self, *args):
        """
        Encode each view with its corresponding encoder.

        :param args: one input per encoder, in the same order as ``self.encoders``
        :return: list with the output of each encoder, in encoder order
        """
        # Indexing (rather than zip) keeps the IndexError when fewer views
        # than encoders are supplied instead of silently dropping encoders.
        return [encoder(args[i]) for i, encoder in enumerate(self.encoders)]

    def loss(self, *args):
        """
        Define the loss function for the model. This is used by the DeepWrapper class

        :param args: one input per encoder, forwarded to :meth:`forward`
        :return: value of ``self.objective.loss`` evaluated on the encoded views
        """
        encoded_views = self(*args)
        return self.objective.loss(*encoded_views)

    def post_transform(self, z_list, train: bool = False):
        """
        Apply a linear MCCA step to the encoded views.

        :param z_list: list of encoded views (presumably array-like, one per view -- confirm against caller)
        :param train: if True, fit a fresh MCCA model on ``z_list`` and keep it in ``self.cca``; if False, reuse the MCCA model fitted by an earlier ``train=True`` call
        :return: the views transformed by the MCCA model
        :raises AttributeError: if called with ``train=False`` before any ``train=True`` call
        """
        if train:
            self.cca = MCCA(latent_dims=self.latent_dims)
            return self.cca.fit_transform(z_list)
        fitted_cca = getattr(self, "cca", None)
        if fitted_cca is None:
            # Same exception type as the bare attribute lookup would raise,
            # but with a message that tells the caller how to recover.
            raise AttributeError(
                "post_transform(train=False) was called before the linear CCA "
                "was fitted; call post_transform(z_list, train=True) first"
            )
        return fitted_cca.transform(z_list)