# Source code for cca_zoo.deep._discriminative._dcca_noi

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module

from ._dcca import DCCA


class BatchWhiten(Module):
    """Whiten a batch of features with a (running) second-moment matrix.

    Plays the role of ``torch.nn.BatchNorm1d`` but, instead of rescaling each
    feature independently, right-multiplies a 2-D ``(batch, num_features)``
    input by the inverse square root of the ``num_features x num_features``
    matrix ``input.T @ input / batch``. The input is not mean-centred before
    this product.

    Args:
        num_features: Size of the feature (second) dimension; ``running_covar``
            is ``num_features x num_features``.
        eps: Ridge added to the diagonal of the matrix before inversion; also
            passed to ``inv_sqrtm`` as the floor on its singular values.
        momentum: Weight kept on the *previous* running estimate for each new
            batch: ``running = momentum * running + (1 - momentum) * batch``.
            This is the opposite convention to ``BatchNorm`` (where
            ``momentum`` weights the new batch). ``None`` gives a cumulative
            average over all batches seen since the last reset.
        track_running_stats: If ``True``, keep ``running_covar`` and
            ``num_batches_tracked`` buffers; otherwise every call whitens with
            the current batch's matrix.
        device: Device for the buffers.
        dtype: Dtype of ``running_covar``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats

        if self.track_running_stats:
            self.register_buffer(
                "running_covar", torch.eye(num_features, **factory_kwargs)
            )
            self.running_covar: Optional[Tensor]
            # The counter is always an integer tensor, so only the device is
            # forwarded (not the dtype).
            self.register_buffer(
                "num_batches_tracked",
                torch.tensor(
                    0,
                    dtype=torch.long,
                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
                ),
            )
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_covar", None)
            self.register_buffer("num_batches_tracked", None)

        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset ``running_covar`` to the identity and the batch counter to 0."""
        if self.track_running_stats:
            # Clear off-diagonal entries too, so the buffer is exactly the
            # identity again rather than keeping stale cross-covariances.
            self.running_covar.zero_()  # type: ignore[union-attr]
            self.running_covar.fill_diagonal_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        """Reset running statistics (the module has no learnable parameters)."""
        self.reset_running_stats()

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` whitened by the current second-moment estimate.

        * Training mode with tracked statistics: increments
          ``num_batches_tracked``, blends the batch matrix into
          ``running_covar`` and whitens with the updated running estimate.
        * Eval mode: whitens with ``running_covar`` without updating it.
        * No running buffers (or training with tracking switched off): whitens
          with the batch matrix.

        The computation runs under ``torch.no_grad()``, so the output does not
        propagate gradients back to ``input``.

        Args:
            input: 2-D tensor of shape ``(batch, num_features)``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        # Weight on the previous running estimate:
        #   running <- factor * running + (1 - factor) * batch
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
                if self.momentum is None:
                    # Cumulative moving average: keep (n - 1) / n of the old
                    # estimate and weight the n-th batch by 1 / n, so after n
                    # batches the buffer is the mean of all n batch matrices.
                    exponential_average_factor = 1.0 - 1.0 / float(
                        self.num_batches_tracked
                    )

        # Running buffers are read in eval mode and, when tracking is enabled,
        # in training mode; otherwise the batch matrix is used directly.
        running_covar = (
            self.running_covar
            if not self.training or self.track_running_stats
            else None
        )

        with torch.no_grad():
            # Batch second-moment matrix (not mean-centred).
            covar = torch.matmul(input.T, input) / input.shape[0]

            if running_covar is not None:
                if self.training:
                    running_covar.mul_(exponential_average_factor).add_(
                        covar, alpha=1 - exponential_average_factor
                    )
                covar = running_covar

            # Ridge regularisation keeps the matrix positive definite without
            # distorting its off-diagonal structure. The addition creates a new
            # tensor, so the running buffer itself is left untouched.
            identity = torch.eye(
                covar.shape[-1], dtype=covar.dtype, device=covar.device
            )
            covar = covar + self.eps * identity

            whitening = inv_sqrtm(covar, self.eps)
            return torch.matmul(input, whitening)


def inv_sqrtm(A, eps=1e-9):
    """Compute the inverse square root of a positive definite matrix.

    Uses the reduced singular value decomposition ``A = U diag(S) Vh`` and
    returns ``U diag(max(S, eps) ** -0.5) Vh``. Flooring the singular values
    at ``eps`` keeps near-singular inputs from producing infinities. For a
    symmetric positive definite ``A``, ``U == Vh.T`` and the result is the
    usual symmetric inverse square root. Leading batch dimensions
    ``(..., n, n)`` are supported.

    Args:
        A: Matrix or batch of matrices.
        eps: Lower bound applied to the singular values.

    Returns:
        Tensor with the same shape and dtype as ``A``.
    """
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    # clamp keeps S's dtype (no float32 scalar tensor as in torch.max).
    S = S.clamp(min=eps)
    # Scaling the columns of U by S^-1/2 equals U @ diag(S^-1/2) without
    # materialising the diagonal matrix.
    return torch.matmul(U * S.rsqrt().unsqueeze(-2), Vh)


class DCCA_NOI(DCCA):
    """
    Deep CCA fitted by non-linear orthogonal iterations (NOI).

    Each view's representation is regressed, with a summed squared error,
    onto the whitened and detached representation of the other view. One
    ``BatchWhiten`` layer per encoder supplies the whitening; its running
    second-moment estimate is updated with momentum ``rho``.

    Parameters
    ----------
    latent_dimensions : int
        Width of every whitening layer, so each representation is expected to
        have this many features. Also forwarded to ``DCCA``.
    encoders : optional
        Forwarded unchanged to ``DCCA``.
    r : float, default 0
        Forwarded unchanged to ``DCCA``.
    rho : float, default 0.1
        Momentum of the whitening layers; must lie in ``[0, 1]``.
    eps : float, default 1e-9
        Forwarded to ``DCCA`` and stored as ``self.eps``.
    **kwargs
        Extra keyword arguments for ``DCCA``.

    Raises
    ------
    ValueError
        If ``rho`` is outside ``[0, 1]``.

    References
    ----------
    Wang, Weiran, et al. "Stochastic optimization for deep CCA via nonlinear
    orthogonal iterations." 2015 53rd Annual Allerton Conference on
    Communication, Control, and Computing (Allerton). IEEE, 2015.
    """

    def __init__(
        self,
        latent_dimensions: int,
        encoders=None,
        r: float = 0,
        rho: float = 0.1,
        eps: float = 1e-9,
        **kwargs,
    ):
        super().__init__(
            latent_dimensions=latent_dimensions,
            encoders=encoders,
            r=r,
            eps=eps,
            **kwargs,
        )
        rho_out_of_range = rho < 0 or rho > 1
        if rho_out_of_range:
            raise ValueError(f"rho should be between 0 and 1. rho={rho}")
        self.eps = eps
        self.rho = rho
        # Summed (not averaged) squared error between a view and its target.
        self.mse = torch.nn.MSELoss(reduction="sum")
        # One whitening layer per encoder. NOTE(review): these use
        # BatchWhiten's own default eps, not the ``eps`` argument above.
        whiteners = []
        for _encoder in self.encoders:
            whiteners.append(BatchWhiten(latent_dimensions, momentum=rho))
        self.bws = torch.nn.ModuleList(whiteners)

    def loss(self, batch, **kwargs):
        """
        NOI objective for one batch.

        Parameters
        ----------
        batch : dict
            Must contain ``"views"``, which is passed to ``self(...)``; the
            result is presumably one representation per view (TODO confirm
            against ``DCCA.forward``).

        Returns
        -------
        dict
            ``{"objective": value}``, where ``value`` is the summed squared
            error of view 0 against whitened view 1 plus that of view 1
            against whitened view 0. Only the first two views enter the
            objective.
        """
        representations = self(batch["views"])
        whitened = [
            whitener(representation)
            for representation, whitener in zip(representations, self.bws)
        ]
        # Targets are detached so gradients flow only through the
        # un-whitened representations.
        first_term = self.mse(representations[0], whitened[1].detach())
        second_term = self.mse(representations[1], whitened[0].detach())
        return {"objective": first_term + second_term}