# Source code for cca_zoo.deep._discriminative._barlow_twins

import torch

from ._dcca import DCCA


class BarlowTwins(DCCA):
    """
    Barlow Twins model for self-supervised representation learning.

    Each view is passed through its own encoder and then through a
    non-affine batch normalisation layer. The loss drives the matrix
    ``z_a.T @ z_b / n`` built from the first two normalised
    representations towards the identity. Diagonal entries are pushed
    towards 1 (invariance), and off-diagonal entries are pushed towards 0
    (redundancy reduction).

    Parameters
    ----------
    latent_dimensions : int
        Output size of each encoder; also the number of features each
        batch normalisation layer normalises.
    encoders : optional
        Per-view encoders, forwarded unchanged to ``DCCA``.
    lamb : float, optional
        Weight of the off-diagonal (redundancy) term in the objective.
        Defaults to 5e-3.
    **kwargs
        Any further keyword arguments, forwarded to ``DCCA``.

    References
    ----------
    Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via
    redundancy reduction." arXiv preprint arXiv:2103.03230 (2021).
    """

    def __init__(
        self,
        latent_dimensions: int,
        encoders=None,
        lamb=5e-3,
        **kwargs,
    ):
        super().__init__(
            latent_dimensions=latent_dimensions, encoders=encoders, **kwargs
        )
        # Weight applied to the off-diagonal term in ``loss``.
        self.lamb = lamb
        # One parameter-free BatchNorm1d per encoder (``self.encoders`` is
        # populated by the parent constructor above).
        normalisers = [
            torch.nn.BatchNorm1d(latent_dimensions, affine=False)
            for _encoder in self.encoders
        ]
        self.bns = torch.nn.ModuleList(normalisers)

    def forward(self, views, **kwargs):
        """
        Encode every view and batch-normalise each encoding.

        Parameters
        ----------
        views : sequence
            Indexable collection of inputs; ``views[i]`` is fed to the
            ``i``-th encoder and then the ``i``-th normalisation layer.

        Returns
        -------
        list
            One normalised representation per encoder, in encoder order.
        """
        return [
            norm(encoder(views[idx]))
            for idx, (encoder, norm) in enumerate(zip(self.encoders, self.bns))
        ]

    def loss(self, batch, **kwargs):
        """
        Compute the Barlow Twins objective for one batch.

        Only the first two representations produced by ``forward`` take
        part in the loss.

        Parameters
        ----------
        batch : dict
            Must provide ``batch["views"]``, which is passed to ``forward``.

        Returns
        -------
        dict
            ``"objective"``: ``invariance + lamb * covariance``.

            ``"invariance"``: sum of squared deviations of the diagonal
            from 1.

            ``"covariance"``: sum of squared off-diagonal entries.
        """
        representations = self(batch["views"])
        z_a, z_b = representations[0], representations[1]
        # NOTE(review): assumes each representation is 2-D
        # (n_samples, latent_dimensions), which ``.T`` and ``shape[0]`` imply.
        n_samples = z_a.shape[0]
        cross_cov = z_a.T @ z_b / n_samples

        # Invariance: each diagonal entry should be close to 1.
        diagonal = torch.diag(cross_cov)
        invariance = torch.sum(torch.pow(1 - diagonal, 2))

        # Redundancy: every entry strictly above or below the diagonal
        # should be close to 0.
        squared = torch.pow(cross_cov, 2)
        upper_part = torch.sum(torch.triu(squared, diagonal=1))
        lower_part = torch.sum(torch.tril(squared, diagonal=-1))
        covariance = upper_part + lower_part

        objective = invariance + self.lamb * covariance
        return {
            "objective": objective,
            "invariance": invariance,
            "covariance": covariance,
        }