Source code for cca_zoo.sequential

"""
Module for finding CCA effects sequentially by deflation.

Check if each effect is significant, and if so, remove it from the data and repeat.
"""
from abc import ABCMeta
from typing import Iterable

import numpy as np
from sklearn.base import MetaEstimatorMixin

from cca_zoo._base import _BaseModel
from cca_zoo.linear._iterative._deflation import deflate_views
from cca_zoo.model_selection._validation import permutation_test_score


class SequentialModel(MetaEstimatorMixin, _BaseModel, metaclass=ABCMeta):
    """Meta-estimator that extracts effects one at a time by deflation.

    A one-dimensional ``estimator`` is fitted to the views, the fitted effect
    is tested against the stopping rule, and, if accepted, it is removed from
    the views with ``deflate_views`` before the estimator is fitted again.
    The loop ends when ``latent_dimensions`` effects have been accepted or
    the first effect is rejected.

    Parameters
    ----------
    estimator :
        Estimator whose ``latent_dimensions`` attribute equals 1, or a search
        wrapper (e.g. GridSearchCV / RandomizedSearchCV) exposing such an
        estimator through its ``estimator`` attribute.
    estimator_hyperparams : dict, optional
        Passed to ``estimator.set_params`` before every fit. Defaults to ``{}``.
    permutation_test_params : dict, optional
        Extra keyword arguments forwarded to ``permutation_test_score``.
        Defaults to ``{}``.
    latent_dimensions : int, optional
        Maximum number of effects to extract. When ``None``, ``fit`` sets it
        to the smallest number of columns among the views.
    copy_data, accept_sparse, random_state :
        Forwarded unchanged to ``_BaseModel.__init__``.
    permutation_test : bool, default False
        Whether to compute a permutation-test p-value for every effect.
    p_threshold : float, default 1e-3
        With ``permutation_test=True``, an effect whose p-value is greater
        than or equal to this threshold stops the search.
    corr_threshold : float, default 0.0
        An effect whose ``score(views)`` is below this threshold stops the
        search. This check also runs after a passed permutation test.

    Attributes
    ----------
    weights_ : list of numpy.ndarray
        One array per view, holding the accepted effects' weights
        concatenated along axis 1. Set by ``fit``.
    p_values : list
        P-values of the accepted effects, in order. Empty when
        ``permutation_test`` is False. Set by ``fit``.

    Raises
    ------
    ValueError
        If the (wrapped) estimator does not have exactly 1 latent dimension.
    """

    def __init__(
        self,
        estimator,
        estimator_hyperparams=None,
        permutation_test_params=None,
        latent_dimensions=None,  # Maximum number of latent dimensions to fit
        copy_data=True,
        accept_sparse=False,
        random_state=None,
        permutation_test=False,  # Whether to use a permutation test
        p_threshold=1e-3,  # Threshold for the permutation test, if used
        corr_threshold=0.0,  # Threshold on the fitted effect's score
    ):
        super().__init__(
            latent_dimensions=latent_dimensions,
            copy_data=copy_data,
            accept_sparse=accept_sparse,
            random_state=random_state,
        )
        # Search wrappers keep the estimator they tune under ``.estimator``;
        # validate that one when present, otherwise the estimator itself.
        base_estimator = getattr(estimator, "estimator", estimator)
        if base_estimator.latent_dimensions != 1:
            raise ValueError(
                "The estimator must have 1 latent dimension, but has {}".format(
                    base_estimator.latent_dimensions
                )
            )
        self.estimator = estimator
        self.estimator_hyperparams = (
            {} if estimator_hyperparams is None else estimator_hyperparams
        )
        self.permutation_test = permutation_test
        self.permutation_test_params = (
            {} if permutation_test_params is None else permutation_test_params
        )
        self.p_threshold = p_threshold
        self.corr_threshold = corr_threshold

    def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """Fit effects sequentially, deflating the views after each one.

        Parameters
        ----------
        views : iterable of numpy.ndarray
            The data views. ``view.shape[1]`` is read for each view when
            ``latent_dimensions`` is ``None``.
        y : ignored
            Not used by this method.
        **kwargs : ignored
            Not used by this method.

        Returns
        -------
        self
            The fitted model, with ``weights_``, ``p_values`` and
            ``latent_dimensions`` (the number of accepted effects) set.

        Raises
        ------
        ValueError
            If no effect passes the stopping rule.
        """
        # Validate the input data and parameters
        self._validate_data(views)
        self._check_params()
        # Default the maximum number of effects to the narrowest view
        if self.latent_dimensions is None:
            self.latent_dimensions = min(view.shape[1] for view in views)
        # One list of per-effect weights for every view
        self.weights_ = [[] for _ in views]
        self.p_values = []
        k = 0
        while k < self.latent_dimensions:
            # Refit on the current (possibly deflated) views
            self.estimator.set_params(**self.estimator_hyperparams)
            self.estimator.fit(views)
            # Always unwrap a search wrapper: the wrapper itself is not the
            # fitted model, so ``weights_`` has to be read from
            # ``best_estimator_`` whether or not the permutation test runs.
            best_estimator = getattr(
                self.estimator, "best_estimator_", self.estimator
            )
            p_value = None
            if self.permutation_test:
                # Index 2 of the returned tuple is the p-value
                p_value = permutation_test_score(
                    best_estimator,
                    views,
                    y=None,
                    **self.permutation_test_params,
                )[2]
            # Stopping rule: the p-value check comes first so the score is
            # not computed for an effect already rejected as non-significant.
            if p_value is not None and p_value >= self.p_threshold:
                break
            if best_estimator.score(views) < self.corr_threshold:
                break
            # Effect accepted: record it and remove it from the views
            if p_value is not None:
                self.p_values.append(p_value)
            views = deflate_views(views, best_estimator.weights_)
            for i, weight in enumerate(best_estimator.weights_):
                self.weights_[i].append(weight)
            k += 1
        # Nothing was accepted, so there is nothing to concatenate
        if all(not view_weights for view_weights in self.weights_):
            raise ValueError("No significant latent dimensions found.")
        # Report the number of effects actually kept
        self.latent_dimensions = k
        # Stack each view's per-effect weights column-wise
        self.weights_ = [
            np.concatenate(view_weights, axis=1) for view_weights in self.weights_
        ]
        return self