Source code for cca_zoo.linear._partialcca

from typing import Iterable

import numpy as np
from sklearn.utils.validation import check_is_fitted

from cca_zoo.linear._mcca import MCCA


class PartialCCA(MCCA):
    r"""
    A class used to fit a partial CCA model. This model extends CCA to account
    for confounding variables that may affect the correlation between
    representations.

    .. math::

        w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \}\\

        \text{subject to:}

        w_i^TX_i^TX_iw_i=1

        w_i^TX_i^TZ=0

    Example
    -------
    >>> import numpy as np
    >>> from cca_zoo.linear import PartialCCA
    >>> X1 = np.random.rand(10,5)
    >>> X2 = np.random.rand(10,5)
    >>> partials = np.random.rand(10,3)
    >>> model = PartialCCA()
    >>> model.fit((X1,X2),partials=partials).score((X1,X2))

    References
    ----------
    Rao, B. Raja. "Partial canonical correlations." Trabajos de estadistica y de investigación operativa 20.2-3 (1969): 211-219.
    """

    def fit(self, views: Iterable[np.ndarray], y=None, partials=None, **kwargs):
        """Fit the model, optionally regressing ``partials`` out of each view.

        Parameters
        ----------
        views : iterable of np.ndarray
            The views to fit on.
        y : optional
            Forwarded unchanged to the parent ``fit``.
        partials : np.ndarray, optional
            Confounding variables; forwarded to the parent ``fit`` (which is
            presumably what hands them to ``_process_data`` -- confirm against
            ``MCCA.fit``).
        **kwargs
            Forwarded unchanged to the parent ``fit``.

        Returns
        -------
        Whatever the parent ``fit`` returns.
        """
        # The PCA option of the parent class is always switched off here.
        self.pca = False
        return super().fit(views, y=y, partials=partials, **kwargs)

    def _process_data(self, views, partials=None, **kwargs):
        """Remove the linear effect of ``partials`` from every view.

        Without ``partials`` this defers entirely to the parent implementation.
        With ``partials`` the least-squares coefficients of each view on the
        partials are stored in ``self.confound_betas`` and the residuals
        ``view - partials @ beta`` are returned as a list.
        """
        if partials is None:
            return super()._process_data(views, **kwargs)

        # Materialise once so that one-shot iterables survive both passes.
        views = list(views)
        # The pseudo-inverse does not depend on the view: compute it once.
        partials_pinv = np.linalg.pinv(partials)
        self.confound_betas = [partials_pinv @ view for view in views]
        # Residualise with the stored betas -- the same expression transform()
        # uses -- instead of forming the n x n projection matrix per view.
        return [
            view - partials @ confound_beta
            for view, confound_beta in zip(views, self.confound_betas)
        ]

    def transform(self, views: Iterable[np.ndarray], partials=None, **kwargs):
        """Project the views with the fitted weights.

        Parameters
        ----------
        views : iterable of np.ndarray
            The views to transform, in the same order as at fit time.
        partials : np.ndarray, optional
            Confounding variables. When given, ``partials @ confound_betas[i]``
            is subtracted from view ``i`` before it is multiplied by
            ``weights_[i]``. When omitted, the parent ``transform`` is used.
        **kwargs
            Forwarded to the parent ``transform`` when ``partials`` is None.

        Returns
        -------
        list of np.ndarray
            One transformed array per view (when ``partials`` is given).

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If the model is not fitted, or was fitted without ``partials`` so
            that no confounding betas are available.
        """
        if partials is None:
            return super().transform(views, **kwargs)

        check_is_fitted(self)
        # confound_betas only exists after fitting with partials; fail with a
        # clear message instead of a bare AttributeError further down.
        check_is_fitted(
            self,
            attributes=["confound_betas"],
            msg=(
                "This %(name)s instance was not fitted with partials; call "
                "fit(views, partials=...) before passing partials to transform."
            ),
        )
        return [
            (view - partials @ self.confound_betas[i]) @ self.weights_[i]
            for i, view in enumerate(views)
        ]

    def _more_tags(self):
        """Return estimator tags marking this model as handling multiview data."""
        return {"multiview": True}