import itertools
from abc import abstractmethod
from typing import Union, Iterable
import numpy as np
from scipy.sparse import issparse
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.utils.sparsefuncs import mean_variance_axis
from sklearn.utils.validation import check_random_state, check_is_fitted
from cca_zoo.utils.check_values import _check_views
from cca_zoo.utils.plotting import plot_latent_train_test
class _CCA_Base(BaseEstimator, MultiOutputMixin, RegressorMixin):
    """
    Base class for the methods in the package. Allows methods to inherit fit_transform,
    predict_corr, and gridsearch_fit when only fit (and transform, where it differs
    from the default) is provided.

    Attributes
    ----------
    weights : list of weights for each view
    """

    def __init__(
        self,
        latent_dims: int = 1,
        scale=True,
        centre=True,
        copy_data=True,
        accept_sparse=False,
        random_state: Union[int, np.random.RandomState] = None,
    ):
"""
Constructor for _CCA_Base
:param latent_dims: number of latent dimensions to fit
:param scale: normalize variance in each column before fitting
:param centre: demean data by column before fitting (and before transforming out of sample
:param copy_data: If True, X will be copied; else, it may be overwritten
:param accept_sparse: Whether model can take sparse data as input
:param random_state: Pass for reproducible output across multiple function calls
"""
        self.latent_dims = latent_dims
        self.scale = scale
        self.centre = centre
        self.copy_data = copy_data
        self.accept_sparse = accept_sparse
        self.random_state = check_random_state(random_state)
        self.n_views = None
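
    # A minimal sketch of a concrete subclass (illustrative only; the class and
    # weights below are hypothetical, not part of the package). Subclasses need
    # only implement fit; the base class supplies the shared machinery:
    #
    #   class _IdentityCCA(_CCA_Base):
    #       def fit(self, views, y=None, **kwargs):
    #           views = self._centre_scale(views)
    #           self.n_views = len(views)
    #           # trivial "weights": project onto the first latent_dims columns
    #           self.weights = [
    #               np.eye(view.shape[1])[:, : self.latent_dims] for view in views
    #           ]
    #           return self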

    @abstractmethod
    def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """
        Fits the model.

        :param views: list/tuple of numpy arrays or array-likes with the same number of rows (samples)
        """
        raise NotImplementedError

    def get_loadings(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """
        Returns the model loadings for each view for the given data.

        :param views: list/tuple of numpy arrays or array-likes with the same number of rows (samples)
        :param kwargs: any additional keyword arguments required by the given model
        """
        transformed_views = self.transform(views, **kwargs)
        views = self._centre_scale_transform(views)
        # loadings are the covariances between each (centred) view and its latent scores
        loadings = [
            view.T @ transformed_view
            for view, transformed_view in zip(views, transformed_views)
        ]
        return loadings
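
    # Usage sketch (assuming a fitted subclass such as cca_zoo's CCA):
    #
    #   model = CCA(latent_dims=2).fit((X1, X2))
    #   loadings = model.get_loadings((X1, X2))
    #   # loadings[i] has shape (n_features_i, latent_dims)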

    def correlations(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """
        Predicts the correlations for the given data using the fitted model.

        :param views: list/tuple of numpy arrays or array-likes with the same number of rows (samples)
        :param kwargs: any additional keyword arguments required by the given model
        :return: all_corrs: an array of the pairwise correlations with shape (k, k, self.latent_dims), where k is the number of views
        :rtype: np.ndarray
        """
        transformed_views = self.transform(views, **kwargs)
        all_corrs = []
        # correlate every ordered pair of transformed views, dimension by dimension
        for x, y in itertools.product(transformed_views, repeat=2):
            all_corrs.append(
                np.diag(np.corrcoef(x.T, y.T)[: self.latent_dims, self.latent_dims :])
            )
        all_corrs = np.array(all_corrs).reshape(
            (len(views), len(views), self.latent_dims)
        )
        return all_corrs
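
    # The diagonal blocks all_corrs[i, i] are self-correlations (identically 1);
    # the off-diagonal entries all_corrs[i, j, d] give the correlation between
    # views i and j along latent dimension d.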

    def plot_latent(
        self,
        views: Iterable[np.ndarray],
        test_views: Iterable[np.ndarray] = None,
        title="",
    ):
        """
        Plots the latent scores for the given views (and, optionally, held-out test views).
        """
        scores = self.transform(views)
        if test_views is not None:
            test_scores = self.transform(test_views)
        else:
            test_scores = None
        plot_latent_train_test(scores, test_scores, title=title)

    def score(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """
        Returns the average pairwise correlation in each latent dimension
        (for 2 views, just the correlation).
        """
        pair_corrs = self.correlations(views, **kwargs)
        n_views = pair_corrs.shape[0]
        # sum the pairwise correlations for each dimension, subtract the n_views
        # self-correlations (each equal to 1), and divide by the number of
        # ordered off-diagonal pairs to get the average correlation
        dim_corrs = (
            pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views
        ) / (n_views ** 2 - n_views)
        return dim_corrs
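
    # Worked example (2 views, 1 latent dimension): if the cross-correlation is r,
    # pair_corrs for that dimension is [[1, r], [r, 1]], which sums to 2 + 2r;
    # subtracting n_views (= 2) leaves 2r, and dividing by 2**2 - 2 = 2 recovers r.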

    def _centre_scale(self, views: Iterable[np.ndarray]):
        """
        Removes the mean of the training data for each view and, if scaling,
        standardizes it, storing the means and standard deviations for use at
        transform time.

        :param views: list/tuple of numpy arrays or array-likes with the same number of rows (samples)
        :return: train_views: the demeaned numpy arrays used to fit the model
        """
        self.view_means = []
        self.view_stds = []
        transformed_views = []
        for view in views:
            if issparse(view):
                # mean_variance_axis returns the variance, not the standard deviation
                view_mean, view_var = mean_variance_axis(view, axis=0)
                self.view_means.append(view_mean)
                self.view_stds.append(np.sqrt(view_var))
                view = view - self.view_means[-1]
                view = view / self.view_stds[-1]
            else:
                if self.centre:
                    view_mean = view.mean(axis=0)
                    self.view_means.append(view_mean)
                    view = view - self.view_means[-1]
                if self.scale:
                    view_std = view.std(axis=0, ddof=1)
                    # avoid dividing constant columns by zero
                    view_std[view_std == 0.0] = 1.0
                    self.view_stds.append(view_std)
                    view = view / self.view_stds[-1]
            transformed_views.append(view)
        return transformed_views
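
    # Note: the stored view_means / view_stds are reused to apply the same centring
    # and scaling to out-of-sample data before projection. A sketch of what a helper
    # along the lines of _centre_scale_transform (used in get_loadings above) might
    # look like; this is an assumption, not the package's implementation:
    #
    #   def _centre_scale_transform(self, views):
    #       if self.centre:
    #           views = [view - mean for view, mean in zip(views, self.view_means)]
    #       if self.scale:
    #           views = [view / std for view, std in zip(views, self.view_stds)]
    #       return views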