Source code for cca_zoo.deep._generative._dccae
import torch
from .. import objectives
from .._discriminative._dcca import DCCA
from .._generative._base import _GenerativeMixin
[docs]
class DCCAE(DCCA, _GenerativeMixin):
    """
    A class used to fit a DCCAE model.

    Each view is encoded to a latent representation and decoded back; the
    training objective is a convex combination of a reconstruction loss and
    a correlation loss, weighted by ``lam`` (see :meth:`loss`).

    Parameters
    ----------
    latent_dimensions : int
        Forwarded to the ``DCCA`` initializer.
    objective : callable
        Loss class; instantiated here as ``objective(eps=eps)`` and also
        forwarded to the ``DCCA`` initializer.
    encoders : optional
        Forwarded to the ``DCCA`` initializer. ``forward`` iterates over
        ``self.encoders`` (presumably set by ``DCCA`` -- confirm there).
    decoders : optional
        Iterable of modules, one per view, wrapped in a
        ``torch.nn.ModuleList``.
    eps : float
        Passed to the objective constructor and to ``DCCA``.
    lam : float
        Weight in ``[0, 1]`` on the reconstruction term; the correlation
        term is weighted by ``1 - lam``.
    latent_dropout : float
        Dropout probability applied to each latent before decoding.
    img_dim : optional
        Stored on the instance as ``self.img_dim``; not used in this class.
    recon_loss_type : str
        Passed as ``loss_type`` to ``self.recon_loss``.
    **kwargs
        Forwarded to the ``DCCA`` initializer.

    Raises
    ------
    ValueError
        If ``lam`` is not within ``[0, 1]``.

    References
    ----------
    Wang, Weiran, et al. "On deep multi-view representation learning." International conference on machine learning. PMLR, 2015.
    """

    def __init__(
        self,
        latent_dimensions: int,
        objective=objectives._MCCALoss,
        encoders=None,
        decoders=None,
        eps: float = 1e-5,
        lam: float = 0.5,
        latent_dropout: float = 0,
        img_dim=None,
        recon_loss_type: str = "mse",
        **kwargs,
    ):
        super().__init__(
            latent_dimensions=latent_dimensions,
            objective=objective,
            encoders=encoders,
            eps=eps,
            **kwargs,
        )
        self.img_dim = img_dim
        self.decoders = torch.nn.ModuleList(decoders)
        # Chained comparison also rejects NaN, which `lam < 0 or lam > 1`
        # would silently accept.
        if not 0 <= lam <= 1:
            raise ValueError(f"lam should be between 0 and 1. lam={lam}")
        self.lam = lam
        self.objective = objective(eps=eps)
        self.latent_dropout = torch.nn.Dropout(p=latent_dropout)
        self.recon_loss_type = recon_loss_type

    def forward(self, views, **kwargs):
        """
        Forward method for the model. Outputs latent encoding for each view

        :param views: indexable collection of inputs; ``views[i]`` is fed to the i-th encoder
        :param kwargs: accepted for interface compatibility; unused here
        :return: list with one encoder output per encoder
        """
        # Index by position (rather than zip) so that supplying fewer views
        # than encoders raises instead of silently truncating.
        return [encoder(views[i]) for i, encoder in enumerate(self.encoders)]

    def _decode(self, z, **kwargs):
        """
        This method is used to decode from the latent space to the best prediction of the original representations

        :param z: indexable collection of latents; ``z[i]`` is fed to the i-th decoder
        :param kwargs: accepted for interface compatibility; unused here
        :return: list with one reconstruction per decoder
        """
        # Dropout is applied to each latent before it reaches its decoder.
        return [
            decoder(self.latent_dropout(z[i]))
            for i, decoder in enumerate(self.decoders)
        ]

    def loss(self, batch, **kwargs):
        """
        Compute the loss terms for a batch.

        :param batch: mapping whose ``"views"`` entry holds the input views
        :param kwargs: accepted for interface compatibility; unused here
        :return: dict with keys ``"reconstruction"`` (sum of per-view
            reconstruction losses), ``"correlation"`` (objective evaluated on
            the latents) and ``"objective"``
            (``lam * reconstruction + (1 - lam) * correlation``)
        """
        views = batch["views"]
        z = self(views)
        recons = self._decode(z)
        loss = dict()
        # NOTE(review): `recon_loss` is not defined in this class; presumably
        # provided by _GenerativeMixin -- confirm.
        loss["reconstruction"] = torch.stack(
            [
                self.recon_loss(x, recon, loss_type=self.recon_loss_type)
                for x, recon in zip(views, recons)
            ]
        ).sum()
        loss["correlation"] = self.objective(z)
        loss["objective"] = (
            self.lam * loss["reconstruction"] + (1 - self.lam) * loss["correlation"]
        )
        return loss