# Source code for cca_zoo.linear._iterative._spls

import itertools
import warnings
from typing import Union, Iterable

import numpy as np

from cca_zoo._utils._checks import _process_parameter
from cca_zoo._utils._cross_correlation import cross_corrcoef
from cca_zoo.linear._iterative._base import _BaseIterative
from cca_zoo.linear._iterative._deflation import _DeflationMixin
from cca_zoo.linear._pls import PLSMixin
from cca_zoo.linear._search import _delta_search

class SPLS(_DeflationMixin, _BaseIterative, PLSMixin):
    """Sparse PLS fitted by iterative PMD-style per-view weight updates.

    For each view in turn, the raw update ``views[i].T @ target`` (where
    ``target`` is the sum of the other views' scores) is optionally clipped to
    be non-negative and then passed to ``_delta_search`` together with a
    per-view bound ``t[i] = max(1, tau[i] * sqrt(p_i))``, ``p_i`` being the
    number of weights for view ``i``.

    Parameters
    ----------
    latent_dimensions, copy_data, random_state, tol, accept_sparse, epochs,
    initialization, early_stopping, verbose
        Forwarded unchanged to the base-class constructor.
    tau : float or sequence of float, optional
        Regularisation parameter for PMD, expanded to one value per view by
        ``_process_parameter``. Every value must lie in ``[0, 1]``. If
        ``None``, a warning is emitted and ``1`` is used.
    positive : bool or sequence of bool, default False
        Per-view flag. When True, negative entries of the raw weight update
        are set to zero before ``_delta_search`` is applied.
    """

    def __init__(
        self,
        latent_dimensions: int = 1,
        copy_data=True,
        random_state=None,
        tol=1e-3,
        accept_sparse=None,
        epochs=100,
        initialization: Union[str, callable] = "pls",
        early_stopping=False,
        verbose=True,
        tau=None,  # regularization parameter for PMD
        positive=False,
    ):
        super().__init__(
            latent_dimensions=latent_dimensions,
            copy_data=copy_data,
            random_state=random_state,
            tol=tol,
            accept_sparse=accept_sparse,
            epochs=epochs,
            initialization=initialization,
            early_stopping=early_stopping,
            verbose=verbose,
        )
        self.tau = tau
        self.positive = positive

    def _check_params(self):
        """Validate ``tau``/``positive`` and expand each to one value per view.

        Raises
        ------
        ValueError
            If any entry of ``tau`` lies outside ``[0, 1]``.
        """
        if self.tau is None:
            # NOTE(review): _update_weights uses t = max(1, tau * sqrt(p)), so
            # tau=1 yields the *largest* bound handed to _delta_search. If that
            # bound is an l1 budget, tau=1 is the weakest regularisation, not
            # the maximum — confirm and reword this warning.
            warnings.warn(
                "tau parameter not set. Setting to tau=1 i.e. maximum regularisation of l1 norm"
            )
        self.tau = _process_parameter("tau", self.tau, 1, self.n_views_)
        if any(tau < 0 or tau > 1 for tau in self.tau):
            # self.tau is already a list here, so format it directly.
            raise ValueError(
                "All regularisation parameters should be between 0 and 1. "
                f"tau={self.tau}"
            )
        self.positive = _process_parameter(
            "positive", self.positive, False, self.n_views_
        )

    def _update_weights(self, views: np.ndarray, i: int):
        """Return the new (PMD-updated) weights for view ``i``.

        Parameters
        ----------
        views
            The data views; indexed as ``views[i]`` and passed to
            ``self.transform`` (presumably a list of 2-D arrays — confirm).
        i : int
            Index of the view whose weights are being updated.

        Returns
        -------
        The result of ``_delta_search`` applied to the raw update.
        """
        # Recompute the per-view bounds on every call instead of caching them
        # behind ``hasattr``: a cached value goes stale if the estimator is
        # refitted with a different ``tau`` or a different number of features.
        # ``self.t`` is still stored for any code that reads it.
        self.t = [
            max(1, tau * np.sqrt(weight.shape[0]))
            for tau, weight in zip(self.tau, self.weights_)
        ]
        # Scores of every view under the current weights, stacked on axis 0.
        scores = np.stack(self.transform(views))
        # Sum the scores of all views except view i.
        others = np.arange(scores.shape[0]) != i
        target = np.sum(scores[others], axis=0)
        new_weights = views[i].T @ target
        if self.positive[i]:
            # Enforce non-negative weights by clipping negatives to zero.
            new_weights[new_weights < 0] = 0
        return _delta_search(new_weights, self.t[i], tol=self.tol)

    def _objective(self, views: Iterable[np.ndarray]):
        """Return the tracked objective value for the current weights.

        Sums ``np.diag(cross_corrcoef(x.T, y.T))`` over every ordered pair of
        transformed views (each view paired with itself included) and
        subtracts ``tau[i] * ||weights_[i]||_2`` for every view.
        """
        transformed_views = self.transform(views)
        # All ordered pairs, self-pairs included.
        pair_terms = [
            np.diag(cross_corrcoef(x.T, y.T))
            for x, y in itertools.product(transformed_views, repeat=2)
        ]
        # NOTE(review): the penalty uses the l2 norm although tau parameterises
        # an l1-type bound in _update_weights — confirm this is intended.
        penalty = np.sum(
            [
                tau * np.linalg.norm(weight)
                for tau, weight in zip(self.tau, self.weights_)
            ]
        )
        return np.sum(pair_terms) - penalty

    def _more_tags(self):
        """Indicate that this estimator takes multiview data."""
        return {"multiview": True}