Source code for cca_zoo.deep._discriminative._dcca_sdl
import torch
import torch.nn.functional as F
from ._dcca import DCCA
def sdl_loss(view):
"""Calculate SDL loss."""
cov = torch.cov(view.T)
sgn = torch.sign(cov)
sgn.fill_diagonal_(0)
return torch.mean(cov * sgn)
[docs]
class DCCA_SDL(DCCA):
"""
A class used to fit a Deep _CCALoss by Stochastic Decorrelation model.
References
----------
Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and effective deep _CCALoss via soft decorrelation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
"""
def __init__(
self,
latent_dimensions: int,
encoders=None,
r: float = 0,
eps: float = 1e-5,
shared_target: bool = False,
lam=0.5,
**kwargs
):
super().__init__(
latent_dimensions=latent_dimensions,
encoders=encoders,
r=r,
eps=eps,
shared_target=shared_target,
**kwargs
)
self.c = None
self.lam = lam
self.bns = torch.nn.ModuleList(
[
torch.nn.BatchNorm1d(latent_dimensions, affine=False)
for _ in self.encoders
]
)
[docs]
def forward(self, views, **kwargs):
z = []
for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)):
z.append(bn(encoder(views[i])))
return z
[docs]
def loss(self, batch, **kwargs):
z = self(batch["views"])
l2_loss = F.mse_loss(z[0], z[1])
SDL_loss = torch.sum(torch.stack([sdl_loss(z_) for z_ in z]))
loss = l2_loss + self.lam * SDL_loss
return {"objective": loss, "l2": l2_loss, "sdl": SDL_loss}
def _sdl_loss(self, covs):
loss = 0
for cov in covs:
sgn = torch.sign(cov)
sgn.fill_diagonal_(0)
loss += torch.mean(cov * sgn)
return loss