# Source code for cca_zoo.deep._generative._splitae

import torch

from .._base import BaseDeep
from .._generative._base import _GenerativeMixin
from ..architectures import Encoder


class SplitAE(BaseDeep, _GenerativeMixin):
    """
    A class used to fit a Split Autoencoder model.

    A single encoder maps the first view into a shared latent space, and one
    decoder per view reconstructs every view from that shared encoding.

    References
    ----------
    Ngiam, Jiquan, et al. "Multimodal deep learning." ICML. 2011.
    """

    def __init__(
        self,
        latent_dimensions: int,
        encoder=Encoder,
        decoders=None,
        latent_dropout=0,
        recon_loss_type="mse",
        img_dim=None,
        **kwargs,
    ):
        """
        :param latent_dimensions: number of latent dimensions
        :param encoder: single encoder network applied to the first view
        :param decoders: list of decoder networks, one per view to reconstruct
        :param latent_dropout: dropout probability applied to the latent code
            before decoding
        :param recon_loss_type: reconstruction loss identifier passed to
            ``recon_loss`` (e.g. ``"mse"``)
        :param img_dim: image dimensions; stored as-is (presumably for
            image-based decoders/losses — not used within this class)
        """
        super().__init__(latent_dimensions=latent_dimensions, **kwargs)
        self.img_dim = img_dim
        self.encoder = encoder
        # Wrap in ModuleList so decoder parameters are registered with autograd.
        self.decoders = torch.nn.ModuleList(decoders)
        self.latent_dropout = torch.nn.Dropout(p=latent_dropout)
        self.recon_loss_type = recon_loss_type

    def forward(self, views, **kwargs):
        """
        Forward method for the model. Outputs latent encoding for each view.

        Only the first view is encoded; the single shared code is returned in
        a one-element list to match the multi-view interface.

        :param views: sequence of view tensors
        :return: single-element list containing the latent encoding
        """
        return [self.encoder(views[0])]

    def _decode(self, z, **kwargs):
        """
        Decode from the latent space to the best prediction of the original
        representations.

        :param z: single-element list holding the shared latent code
        :return: list of reconstructions, one per decoder
        """
        # Dropout is applied per decoder so each one sees an independent
        # dropout mask during training (matches the original behavior).
        return [decoder(self.latent_dropout(z[0])) for decoder in self.decoders]

    def loss(self, batch, **kwargs):
        """
        Compute the summed per-view reconstruction loss.

        :param batch: dict with key ``"views"`` holding the list of view tensors
        :return: dict with ``"reconstruction"`` and ``"objective"`` entries
            (identical values; ``objective`` is what the trainer minimizes)
        """
        z = self(batch["views"])
        recons = self._decode(z)
        loss = dict()
        loss["reconstruction"] = torch.stack(
            [
                self.recon_loss(x, recon, loss_type=self.recon_loss_type)
                for x, recon in zip(batch["views"], recons)
            ]
        ).sum()
        loss["objective"] = loss["reconstruction"]
        return loss