Source code for cca_zoo.deepmodels.dvcca

from typing import Iterable

import torch
import torch.distributions as dist
from torch.nn import functional as F

from cca_zoo.deepmodels.architectures import BaseEncoder, Encoder, Decoder
from cca_zoo.deepmodels.dcca import _DCCA_base


class DVCCA(_DCCA_base):
    """
    A class used to fit a DVCCA model.

    :Citation:

    Wang, Weiran, et al. "Deep variational canonical correlation analysis."
    arXiv preprint arXiv:1610.03454 (2016).

    https://arxiv.org/pdf/1610.03454.pdf

    https://github.com/pytorch/examples/blob/master/vae/main.py
    """

    def __init__(
        self,
        latent_dims: int,
        encoders=None,
        decoders=None,
        private_encoders: Iterable[BaseEncoder] = None,
    ):
        """
        :param latent_dims: # latent dimensions
        :param encoders: list of encoder networks
        :param decoders: list of decoder networks
        :param private_encoders: list of private (view specific) encoder networks
        """
        super().__init__(latent_dims=latent_dims)
        # Defaults must be module instances, not classes, for
        # torch.nn.ModuleList to accept them. Encoders are assumed to be
        # variational, i.e. to return a (mu, logvar) pair.
        if decoders is None:
            decoders = [Decoder(latent_dims), Decoder(latent_dims)]
        if encoders is None:
            encoders = [
                Encoder(latent_dims, variational=True),
                Encoder(latent_dims, variational=True),
            ]
        self.encoders = torch.nn.ModuleList(encoders)
        self.decoders = torch.nn.ModuleList(decoders)
        if private_encoders:
            self.private_encoders = torch.nn.ModuleList(private_encoders)
        else:
            self.private_encoders = None
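    # Model structure (from Wang et al., 2016): DVCCA learns a shared latent
    # variable z for all views; the DVCCA-private variant additionally learns
    # a view-specific private variable z_p per view, and each view is then
    # reconstructed from the concatenation [z, z_p].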
    def forward(self, *args, mle=True):
        """
        :param args: batches for each view
        :param mle: whether to return the posterior means (maximum likelihood)
            or reparameterised samples
        :return: list of latent representations, one per view
        """
        # Used when we get reconstructions
        mu, logvar = self._encode(*args)
        if mle:
            z = mu
        else:
            # Sample each view's posterior with the reparameterisation trick
            z = [
                dist.Normal(mu_i, torch.exp(0.5 * logvar_i)).rsample()
                for mu_i, logvar_i in zip(mu, logvar)
            ]
        # If using a single encoder, repeat the shared representation n times
        if len(self.encoders) == 1:
            z = z * len(args)
        if self.private_encoders:
            mu_p, logvar_p = self._encode_private(*args)
            if mle:
                z_p = mu_p
            else:
                z_p = [
                    dist.Normal(mu_p_i, torch.exp(0.5 * logvar_p_i)).rsample()
                    for mu_p_i, logvar_p_i in zip(mu_p, logvar_p)
                ]
            # Concatenate the shared and private representations for each view
            z = [torch.cat((z_, z_p_), dim=-1) for z_, z_p_ in zip(z, z_p)]
        return z
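    # Reparameterisation note: sampling z ~ N(mu, diag(sigma^2)) as
    #   z = mu + sigma * eps,  eps ~ N(0, I),  sigma = exp(0.5 * logvar)
    # is what dist.Normal(mu, torch.exp(0.5 * logvar)).rsample() computes;
    # rsample() keeps the sample differentiable w.r.t. mu and logvar.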
    def _encode(self, *args):
        """
        :param args: batches for each view
        :return: lists of means and log variances, one per view
        """
        mu = []
        logvar = []
        for encoder, arg in zip(self.encoders, args):
            mu_i, logvar_i = encoder(arg)
            mu.append(mu_i)
            logvar.append(logvar_i)
        return mu, logvar

    def _encode_private(self, *args):
        """
        :param args: batches for each view
        :return: lists of private means and log variances, one per view
        """
        mu = []
        logvar = []
        for private_encoder, arg in zip(self.private_encoders, args):
            mu_i, logvar_i = private_encoder(arg)
            mu.append(mu_i)
            logvar.append(logvar_i)
        return mu, logvar

    def _decode(self, z):
        """
        :param z: latent representation
        :return: reconstructions of each view from z
        """
        x = []
        for decoder in self.decoders:
            # Squash outputs to (0, 1) for the BCE reconstruction terms
            x_i = torch.sigmoid(decoder(z))
            x.append(x_i)
        return x
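    # Because _decode applies a sigmoid, reconstructions lie in (0, 1); the
    # binary cross-entropy terms in the losses below therefore expect targets
    # scaled to [0, 1] (e.g. normalised pixel intensities).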
    def recon(self, *args):
        """
        :param args: batches for each view
        :return: reconstructions of each view from each latent representation
        """
        z = self(*args)
        return [self._decode(z_) for z_ in z]
    def loss(self, *args):
        """
        :param args: batches for each view
        :return: mean loss over views
        """
        mus, logvars = self._encode(*args)
        if self.private_encoders:
            mus_p, logvars_p = self._encode_private(*args)
            losses = [
                self.vcca_private_loss(
                    *args, mu=mu, logvar=logvar, mu_p=mu_p, logvar_p=logvar_p
                )
                for (mu, logvar, mu_p, logvar_p) in zip(mus, logvars, mus_p, logvars_p)
            ]
        else:
            losses = [
                self.vcca_loss(*args, mu=mu, logvar=logvar)
                for (mu, logvar) in zip(mus, logvars)
            ]
        return torch.stack(losses).mean()
    def vcca_loss(self, *args, mu, logvar):
        """
        :param args: batches for each view
        :param mu: mean of the approximate posterior
        :param logvar: log variance of the approximate posterior
        :return: KL divergence plus reconstruction error
        """
        z_dist = dist.Normal(mu, torch.exp(0.5 * logvar))
        z = z_dist.rsample()
        kl = torch.mean(
            -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0
        )
        recons = self._decode(z)
        bces = torch.stack(
            [
                F.binary_cross_entropy(recon, arg, reduction="mean")
                for recon, arg in zip(recons, args)
            ]
        ).sum()
        return kl + bces
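    # KL note: for a diagonal Gaussian posterior q(z|x) = N(mu, diag(sigma^2))
    # and a standard normal prior p(z) = N(0, I),
    #   KL(q || p) = -0.5 * sum(1 + log(sigma^2) - sigma^2 - mu^2)
    # which is exactly the expression computed from logvar in vcca_loss above
    # and in vcca_private_loss below.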
    def vcca_private_loss(self, *args, mu, logvar, mu_p, logvar_p):
        """
        :param args: batches for each view
        :param mu: mean of the shared approximate posterior
        :param logvar: log variance of the shared approximate posterior
        :param mu_p: mean of the private approximate posterior
        :param logvar_p: log variance of the private approximate posterior
        :return: shared and private KL divergences plus reconstruction error
        """
        batch_n = mu.shape[0]
        z_dist = dist.Normal(mu, torch.exp(0.5 * logvar))
        z = z_dist.rsample()
        z_p_dist = dist.Normal(mu_p, torch.exp(0.5 * logvar_p))
        z_p = z_p_dist.rsample()
        kl_p = torch.stack(
            [
                torch.mean(
                    -0.5
                    * torch.sum(1 + logvar_p - logvar_p.exp() - mu_p.pow(2), dim=1),
                    dim=0,
                )
                for _ in self.private_encoders
            ]
        ).sum()
        kl = torch.mean(
            -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0
        )
        z_combined = torch.cat([z, z_p], dim=-1)
        recon = self._decode(z_combined)
        bces = torch.stack(
            [
                F.binary_cross_entropy(recon_i, arg, reduction="sum") / batch_n
                for recon_i, arg in zip(recon, args)
            ]
        ).sum()
        return kl + kl_p + bces
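
A minimal usage sketch (not part of the module source): it assumes the Encoder and Decoder architectures in cca_zoo.deepmodels.architectures accept latent_dims and feature_size arguments and that Encoder exposes a variational flag making it return a (mu, logvar) pair; check that module for the exact signatures. The view sizes and random data below are hypothetical.

import torch

from cca_zoo.deepmodels.architectures import Decoder, Encoder
from cca_zoo.deepmodels.dvcca import DVCCA

# Hypothetical two-view setup: 784- and 256-dimensional views, 2 latent dims.
encoders = [
    Encoder(latent_dims=2, feature_size=784, variational=True),
    Encoder(latent_dims=2, feature_size=256, variational=True),
]
decoders = [
    Decoder(latent_dims=2, feature_size=784),
    Decoder(latent_dims=2, feature_size=256),
]
model = DVCCA(latent_dims=2, encoders=encoders, decoders=decoders)

# BCE reconstruction terms expect inputs in [0, 1].
view_1, view_2 = torch.rand(32, 784), torch.rand(32, 256)
z = model(view_1, view_2)          # list of per-view latent representations
loss = model.loss(view_1, view_2)  # KL + reconstruction objective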