from typing import Union, Iterable
import numpy as np
from sklearn.metrics import pairwise_kernels
from sklearn.neighbors import NearestNeighbors
from sklearn.utils.validation import check_is_fitted
from cca_zoo.models._cca_base import _CCA_Base
from cca_zoo.utils.check_values import _process_parameter, _check_views
class NCCA(_CCA_Base):
    """
    A class used to fit a nonparametric canonical correlation analysis (NCCA) model.

    :Citation:

    Michaeli, Tomer, Weiran Wang, and Karen Livescu. "Nonparametric canonical correlation analysis." International conference on machine learning. PMLR, 2016.

    :Example:

    >>> from cca_zoo.models import NCCA
    >>> X1 = np.random.rand(10,5)
    >>> X2 = np.random.rand(10,5)
    >>> model = NCCA()
    >>> model.fit((X1,X2)).score((X1,X2))
    array([1.])
    """

    def __init__(
        self,
        latent_dims: int = 1,
        scale=True,
        centre=True,
        copy_data=True,
        accept_sparse=False,
        random_state: Union[int, np.random.RandomState] = None,
        nearest_neighbors=None,
        gamma: Iterable[float] = None,
    ):
        """
        Constructor for NCCA

        :param latent_dims: number of latent dimensions to fit
        :param scale: normalize variance in each column before fitting
        :param centre: demean data by column before fitting (and before transforming out of sample)
        :param copy_data: If True, X will be copied; else, it may be overwritten
        :param accept_sparse: Whether model can take sparse data as input
        :param random_state: Pass for reproducible output across multiple function calls
        :param nearest_neighbors: Number of nearest neighbors (l2 distance) to consider when constructing affinity
        :param gamma: Bandwidth parameter for rbf kernel
        """
        super().__init__(
            latent_dims, scale, centre, copy_data, accept_sparse, random_state
        )
        self.nearest_neighbors = nearest_neighbors
        self.gamma = gamma

    def _check_params(self):
        """
        Normalise the hyperparameters through ``_process_parameter``.

        Requires ``self.n_views`` to be set. Overwrites ``self.nearest_neighbors``
        and ``self.gamma`` and sets ``self.kernel``; ``fit`` afterwards indexes
        each of them by view number.
        """
        # NOTE(review): the third argument looks like the fallback used when the
        # value is None (1 neighbour, gamma=None, "rbf" kernel) - confirm
        # against _process_parameter.
        self.nearest_neighbors = _process_parameter(
            "nearest_neighbors", self.nearest_neighbors, 1, self.n_views
        )
        self.gamma = _process_parameter("gamma", self.gamma, None, self.n_views)
        self.kernel = _process_parameter("kernel", None, "rbf", self.n_views)

    def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """
        Fit the model to the training views.

        Only the first two views enter the computation: the affinity product
        below is built from ``Ws[0]`` and ``Ws[1]``.

        :param views: iterable of data arrays, one per view
        :param y: unused
        :param kwargs: unused
        :return: the fitted estimator (``self``)
        """
        views = _check_views(
            *views, copy=self.copy_data, accept_sparse=self.accept_sparse
        )
        views = self._centre_scale(views)
        self.n_views = len(views)
        self.n = views[0].shape[0]
        self._check_params()
        self.train_views = views
        # One nearest-neighbour index per view, fitted on that view's samples.
        self.knns = [
            NearestNeighbors(n_neighbors=self.nearest_neighbors[i]).fit(view)
            for i, view in enumerate(views)
        ]
        # kneighbors returns a (distances, indices) pair for every view.
        neighbours = [
            self.knns[i].kneighbors(view, self.nearest_neighbors[i])
            for i, view in enumerate(views)
        ]
        kernels = [
            self._get_kernel(i, view) for i, view in enumerate(self.train_views)
        ]
        # Keep only the kernel entries between each sample and its neighbours.
        affinities = [
            fill_w(kernel, inds) for kernel, (_, inds) in zip(kernels, neighbours)
        ]
        # First view is normalised over rows, second view over columns.
        self.Ws = [
            affinities[0] / affinities[0].sum(axis=1, keepdims=True),
            affinities[1] / affinities[1].sum(axis=0, keepdims=True),
        ]
        product = self.Ws[0] @ self.Ws[1]
        U, singular_values, Vt = np.linalg.svd(product)
        # The leading singular triplet (index 0) is skipped and the vectors are
        # scaled by sqrt(n).
        # NOTE(review): presumably index 0 is the trivial constant component -
        # confirm against the cited paper.
        self.f = U[:, 1 : self.latent_dims + 1] * np.sqrt(self.n)
        self.g = Vt[1 : self.latent_dims + 1, :].T * np.sqrt(self.n)
        self.S = singular_values[1 : self.latent_dims + 1]
        return self

    def _get_kernel(self, view, X, Y=None):
        """
        Evaluate the kernel configured for one view.

        :param view: index of the view, used to look up ``self.gamma`` and ``self.kernel``
        :param X: first set of samples
        :param Y: optional second set of samples, passed straight to ``pairwise_kernels``
        :return: the result of ``pairwise_kernels(X, Y, ...)``
        """
        params = {
            "gamma": self.gamma[view],
        }
        return pairwise_kernels(
            X, Y, metric=self.kernel[view], filter_params=True, **params
        )
def fill_w(kernels, inds):
    """
    Keep only the kernel entries selected by ``inds`` and return the transpose.

    :param kernels: 2-D array of kernel values, indexed as ``kernels[row, column]``
    :param inds: iterable whose ``i``-th item holds the row indices to keep in column ``i``
    :return: transpose of a zero-filled array shaped like ``kernels`` into which the
        selected entries were copied; entry ``[i, j]`` equals ``kernels[j, i]`` when
        ``j`` is listed in ``inds[i]`` and 0 otherwise
    """
    selected = np.zeros_like(kernels)
    for column, rows in enumerate(inds):
        # Copy just the chosen rows of this column; everything else stays 0.
        selected[rows, column] = kernels[rows, column]
    return selected.T