# Source code for cca_zoo.deep._discriminative._dcca_barlow_twins

import torch

from ._dcca import DCCA


class BarlowTwins(DCCA):
    """
    Deep Barlow Twins model.

    Barlow Twins is a self-supervised learning method based on redundancy
    reduction. Each view is encoded by its own encoder and batch-normalised.
    The loss then pushes the cross-correlation matrix between the first two
    normalised representations towards the identity:

    * driving the diagonal towards one makes the representations invariant to
      the distortions that distinguish the views;
    * driving the off-diagonal entries towards zero decorrelates the latent
      dimensions.

    Parameters
    ----------
    latent_dimensions : int
        Size of the latent space. Each encoder's output is fed to a
        ``BatchNorm1d(latent_dimensions)`` layer, so it must have this many
        features.
    encoders : optional
        Per-view encoders, passed unchanged to :class:`DCCA`. How ``None`` is
        handled is decided there. One normalisation layer is created for every
        entry of ``self.encoders``.
    lam : float, default=1
        Weight of the off-diagonal (redundancy-reduction) term relative to the
        diagonal (invariance) term in :meth:`loss`.
    **kwargs
        Further keyword arguments forwarded to :class:`DCCA`.

    References
    ----------
    Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via
    redundancy reduction." arXiv preprint arXiv:2103.03230 (2021).
    """

    def __init__(
        self,
        latent_dimensions: int,
        encoders=None,
        lam=1,
        **kwargs,
    ):
        super().__init__(
            latent_dimensions=latent_dimensions, encoders=encoders, **kwargs
        )
        # Weight applied to the off-diagonal term of the objective in ``loss``.
        self.lam = lam
        # One normalisation layer per encoder. ``affine=False`` means there is
        # no learnable scale/shift, so the outputs stay standardised per feature.
        norm_layers = [
            torch.nn.BatchNorm1d(latent_dimensions, affine=False)
            for _encoder in self.encoders
        ]
        self.bns = torch.nn.ModuleList(norm_layers)

    def forward(self, views, **kwargs):
        """
        Encode every view with its matching encoder, then batch-normalise it.

        Parameters
        ----------
        views : sequence of torch.Tensor
            Indexed positionally: ``views[k]`` goes through the ``k``-th
            encoder and the ``k``-th normalisation layer. Presumably each
            element is one batch for one view (TODO: confirm against callers).
        **kwargs
            Ignored.

        Returns
        -------
        list of torch.Tensor
            One normalised representation per encoder, in encoder order.
        """
        return [
            norm(encoder(views[k]))
            for k, (encoder, norm) in enumerate(zip(self.encoders, self.bns))
        ]

    def loss(self, views, **kwargs):
        """
        Compute the Barlow Twins objective for a batch of views.

        Only the first two representations returned by :meth:`forward` are
        used. Any additional views are still encoded but do not contribute.

        Parameters
        ----------
        views : sequence of torch.Tensor
            Passed to :meth:`forward` via ``self(views)``. The forward pass
            must yield at least two representations.
        **kwargs
            Ignored.

        Returns
        -------
        dict
            ``"objective"``
                ``invariance + lam * covariance``, the combined loss.
            ``"invariance"``
                Sum of squared deviations of the diagonal of the
                cross-correlation matrix from one.
            ``"covariance"``
                Sum of the squared off-diagonal entries.
        """
        representations = self(views)
        first, second = representations[0], representations[1]
        batch_size = first.shape[0]

        # Both inputs come out of BatchNorm1d(affine=False). In training mode
        # their features are standardised over the batch, so this is
        # (approximately) the paper's cross-correlation matrix. Assumes 2-D
        # (batch, latent_dimensions) representations; TODO confirm.
        cross_corr = first.T @ second / batch_size

        # Invariance: pull each diagonal entry towards 1.
        invariance = torch.sum(torch.pow(1 - torch.diag(cross_corr), 2))

        # Redundancy reduction: the strict upper triangle plus the strict
        # lower triangle of the squared matrix, i.e. everything off the
        # diagonal.
        squared = torch.pow(cross_corr, 2)
        covariance = torch.sum(torch.triu(squared, diagonal=1)) + torch.sum(
            torch.tril(squared, diagonal=-1)
        )

        return {
            "objective": invariance + self.lam * covariance,
            "invariance": invariance,
            "covariance": covariance,
        }