# Source code for cca_zoo.deep._discriminative._dcca_sdl

import torch
import torch.nn.functional as F

from ._dcca_noi import DCCA_NOI


class DCCA_SDL(DCCA_NOI):
    """
    Deep CCA trained with a Stochastic Decorrelation Loss (SDL).

    Each view is encoded and then passed through a ``BatchNorm1d`` layer
    without learnable affine parameters. The training objective is the
    mean-squared distance between the first two views' latents plus ``lam``
    times a penalty on the off-diagonal entries of a running covariance
    estimate kept for every view.

    Parameters
    ----------
    latent_dimensions : int
        Size of each latent representation; also the feature count of the
        per-view batch-norm layers.
    N : int
        Factor applied to every mini-batch covariance estimate (presumably
        the training-set size -- confirm against ``DCCA_NOI``).
    encoders : optional
        Forwarded to ``DCCA_NOI``. One batch-norm layer is created for each
        entry of ``self.encoders``.
    r, eps, shared_target
        Forwarded unchanged to ``DCCA_NOI``.
    rho : float
        Forwarded to ``DCCA_NOI``; also the weight kept on the previous
        running covariance when a new training batch is folded in.
    lam : float
        Weight of the SDL term in the total objective.
    **kwargs
        Forwarded to ``DCCA_NOI``.

    References
    ----------
    Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and
    effective deep CCA via soft decorrelation." Proceedings of the IEEE
    Conference on Computer Vision and Pattern Recognition. 2018.
    """

    def __init__(
        self,
        latent_dimensions: int,
        N: int,
        encoders=None,
        r: float = 0,
        rho: float = 0.2,
        eps: float = 1e-5,
        shared_target: bool = False,
        lam=0.5,
        **kwargs
    ):
        # The parent must be initialised before any submodule is attached.
        super().__init__(
            latent_dimensions=latent_dimensions,
            N=N,
            encoders=encoders,
            r=r,
            rho=rho,
            eps=eps,
            shared_target=shared_target,
            **kwargs
        )
        # Not read anywhere in this class; kept for attribute compatibility.
        self.c = None
        self.cross_cov = None
        self.lam = lam
        norm_layers = [
            torch.nn.BatchNorm1d(latent_dimensions, affine=False)
            for _ in self.encoders
        ]
        self.bns = torch.nn.ModuleList(norm_layers)

    def forward(self, views, **kwargs):
        """
        Encode every view and batch-normalise each latent.

        Parameters
        ----------
        views : sequence
            Inputs indexed by position; ``views[k]`` is fed to the k-th
            encoder.

        Returns
        -------
        list
            One normalised latent per (encoder, batch-norm) pair.
        """
        return [
            norm(enc(views[idx]))
            for idx, (enc, norm) in enumerate(zip(self.encoders, self.bns))
        ]

    def loss(self, views, **kwargs):
        """
        Compute the combined L2 + SDL objective for a batch.

        The running covariance estimate is refreshed as a side effect (and
        only averaged in while ``self.training`` is true). It is then stored
        detached so later steps do not backpropagate into earlier batches.

        Returns
        -------
        dict
            ``"objective"`` (total loss), ``"l2"`` (MSE between the first two
            latents) and ``"sdl"`` (decorrelation penalty).
        """
        latents = self(views)
        l2_term = F.mse_loss(latents[0], latents[1])
        self._update_covariances(latents, train=self.training)
        sdl_term = self._sdl_loss(self.covs)
        objective = l2_term + self.lam * sdl_term
        self.covs = [running_cov.detach() for running_cov in self.covs]
        return {"objective": objective, "l2": l2_term, "sdl": sdl_term}

    def _sdl_loss(self, covs):
        """
        Sum, over views, of the mean of |cov| with the diagonal masked out.

        ``cov * sign(cov)`` equals ``|cov|``; zeroing the sign's diagonal drops
        the variances so only cross-feature covariances are penalised. The
        mean is taken over all entries, including the zeroed diagonal.
        Returns the integer ``0`` when ``covs`` is empty.
        """
        total = 0
        for cov in covs:
            off_diag_sign = torch.sign(cov)
            off_diag_sign.fill_diagonal_(0)
            total = total + torch.mean(cov * off_diag_sign)
        return total

    def _update_covariances(self, z, train=True):
        """
        Fold the current batch into ``self.covs``.

        Each batch estimate is ``N * z_k^T z_k / batch_size`` (an uncentred
        second-moment matrix of the batch-normalised latent). When no estimate
        exists yet, the batch estimate is used directly, in training or not.
        Otherwise it is blended in as ``rho * old + (1 - rho) * new``, but
        only when ``train`` is truthy.
        """
        batch_size = z[0].shape[0]
        batch_covs = [
            self.N * latent.T @ latent / batch_size for latent in z
        ]
        if self.covs is None:
            # pytorch-lightning runs a validation pass before training, so the
            # estimate may first be seeded here with train=False.
            self.covs = batch_covs
        elif train:
            self.covs = [
                self.rho * self.covs[idx] + (1 - self.rho) * fresh_cov
                for idx, fresh_cov in enumerate(batch_covs)
            ]