# Source code for cca_zoo.model_selection._search
import itertools
from typing import Iterable
import numpy as np
from sklearn import clone
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection._search import (
BaseSearchCV,
)
from sklearn.model_selection._search import ParameterSampler
from sklearn.pipeline import Pipeline
from sklearn.utils import check_random_state
from cca_zoo._utils._splitter import SimpleSplitter
def param2grid(params):
    """Build a ``ParameterGrid`` from a dict of (possibly nested) candidates.

    Every value in *params* is first normalised to a list of candidates: a
    non-string iterable is materialised with ``list``; anything else (a
    scalar or a string) becomes a one-element list.

    If any entry of that list is itself a non-string iterable, the list is
    expanded into the Cartesian product of its entries, and each resulting
    combination (as a list) becomes one candidate value for the parameter.
    Entries that are not collections are treated as a single fixed option
    in that product.
    NOTE(review): the nested form presumably means "one entry per view";
    confirm against the estimators that consume these values.

    Parameters
    ----------
    params : dict
        Maps a parameter name to a scalar, an iterable of candidates, or an
        iterable whose entries are themselves iterables of candidates.

    Returns
    -------
    ParameterGrid
        Grid over the normalised/expanded candidate lists.
    """

    def _is_collection(obj):
        # Strings are iterable but are meant as atomic parameter values.
        return hasattr(obj, "__iter__") and not isinstance(obj, str)

    grid = {}
    for name, value in params.items():
        candidates = list(value) if _is_collection(value) else [value]
        if any(_is_collection(entry) for entry in candidates):
            # itertools.product requires every argument to be iterable.
            # Wrapping plain entries keeps a scalar from raising TypeError
            # and keeps a string from being split into single characters.
            options = [
                entry if _is_collection(entry) else [entry]
                for entry in candidates
            ]
            candidates = [list(combo) for combo in itertools.product(*options)]
        grid[name] = candidates
    return ParameterGrid(grid)
class ParameterSampler_(ParameterSampler):
    """Parameter sampler that draws element-wise from iterable values.

    For each of ``n_iter`` candidates one distribution dict is chosen from
    ``self.param_distributions``. For every key, a non-string iterable value
    yields a list with one draw per element; any other value yields a single
    draw. See :meth:`return_param` for how a single draw is made.
    """

    def __iter__(self):
        """Yield ``n_iter`` parameter dicts sampled with one shared generator."""
        # One generator for the whole iteration: re-seeding per draw would
        # make every draw from the same distribution identical whenever
        # ``random_state`` is an integer.
        rng = check_random_state(self.random_state)
        for _ in range(self.n_iter):
            dist = rng.choice(self.param_distributions)
            params = {}
            # Always sort the keys of a dictionary, for reproducibility
            for name, spec in sorted(dist.items()):
                if isinstance(spec, Iterable) and not isinstance(spec, str):
                    params[name] = [
                        self.return_param(entry, rng) for entry in spec
                    ]
                else:
                    params[name] = self.return_param(spec, rng)
            yield params

    def return_param(self, v, rng=None):
        """Draw one value from *v*.

        Parameters
        ----------
        v : object
            If it has an ``rvs`` method, ``v.rvs(random_state=rng)`` is
            returned. If it is a non-string iterable, a uniformly chosen
            element ``v[i]`` is returned (so it must support ``len`` and
            integer indexing). Otherwise *v* itself is returned.
        rng : random generator, optional
            Generator to draw from. When omitted, one is derived from
            ``self.random_state`` as before; callers that draw repeatedly
            should pass a shared generator to avoid repeated values.

        Returns
        -------
        object
            The sampled (or passed-through) value.
        """
        if rng is None:
            rng = check_random_state(self.random_state)
        if hasattr(v, "rvs"):
            return v.rvs(random_state=rng)
        if isinstance(v, Iterable) and not isinstance(v, str):
            return v[rng.randint(len(v))]
        return v

    def __len__(self):
        """Number of points that will be sampled."""
        return self.n_iter
class BaseSearchCV(BaseSearchCV):
    """Search base class whose ``fit`` accepts a sequence of views.

    The views are stacked column-wise into one matrix and the wrapped
    estimator is placed behind a ``SimpleSplitter`` step constructed with
    each view's column count.
    NOTE(review): ``SimpleSplitter`` presumably splits the stacked matrix
    back into the original views; confirm in ``cca_zoo._utils._splitter``.
    """

    def fit(self, views, y=None, *, groups=None, **fit_params):
        """Run the search on *views* and unwrap the results.

        Parameters
        ----------
        views : sequence of array-like
            Each element must expose ``shape[1]``; all are passed to
            ``np.hstack``. Iterated twice, so a one-shot iterator will not
            work.
        y, groups, **fit_params
            Forwarded unchanged to the parent ``fit``.

        Returns
        -------
        self
        """
        prefix = "estimator__"
        self.estimator = Pipeline(
            [
                ("splitter", SimpleSplitter([view.shape[1] for view in views])),
                ("estimator", clone(self.estimator)),
            ]
        )
        try:
            super().fit(np.hstack(views), y=y, groups=groups, **fit_params)
        finally:
            # Unwrap even when the search fails; otherwise a later fit()
            # would wrap the pipeline in yet another pipeline.
            self.estimator = self.estimator[1]
        # With refit=False the parent may not set these attributes; do not
        # turn a successful search into an AttributeError.
        if hasattr(self, "best_estimator_"):
            self.best_estimator_ = self.best_estimator_[1]
        if hasattr(self, "best_params_"):
            # Strip only the leading pipeline-step prefix so that nested
            # names which themselves contain "estimator__" survive intact.
            self.best_params_ = {
                (key[len(prefix):] if key.startswith(prefix) else key): val
                for key, val in self.best_params_.items()
            }
        return self
# [docs]
class GridSearchCV(GridSearchCV, BaseSearchCV):
    """Grid search whose candidates are addressed to the ``estimator`` step.

    Parameter names from ``param_grid`` are prefixed with ``estimator__``
    before evaluation, matching the pipeline step name used by
    ``BaseSearchCV.fit``.
    """

    def _run_search(self, evaluate_candidates):
        """Search all candidates in param_grid"""
        grid = self.param_grid
        if isinstance(grid, ParameterGrid):
            subgrids = grid.param_grid
        elif hasattr(grid, "items"):
            # A single mapping of parameter name -> candidates.
            subgrids = param2grid(grid).param_grid
        else:
            # A sequence of mappings: expand each and take the union.
            subgrids = [
                sub for part in grid for sub in param2grid(part).param_grid
            ]
        # Build a fresh grid rather than rewriting ``param_grid`` in place:
        # when the caller passed a ParameterGrid, mutating it would leave
        # their object prefixed and a second fit() would prefix it again.
        evaluate_candidates(
            ParameterGrid(
                [
                    {f"estimator__{key}": val for key, val in sub.items()}
                    for sub in subgrids
                ]
            )
        )
# [docs]
class RandomizedSearchCV(RandomizedSearchCV, BaseSearchCV):
    """Randomized search whose candidates are addressed to the ``estimator`` step.

    Parameter names from ``param_distributions`` are prefixed with
    ``estimator__`` before sampling, matching the pipeline step name used by
    ``BaseSearchCV.fit``.
    """

    def _run_search(self, evaluate_candidates):
        """Search n_iter candidates from param_distributions"""
        distributions = self.param_distributions
        # Prefix into a local copy. Overwriting ``self.param_distributions``
        # would change a constructor parameter during fit and make a second
        # fit() produce "estimator__estimator__..." names.
        if hasattr(distributions, "items"):
            prefixed = {
                f"estimator__{key}": val for key, val in distributions.items()
            }
        else:
            # A sequence of distribution dicts: prefix each one.
            prefixed = [
                {f"estimator__{key}": val for key, val in dist.items()}
                for dist in distributions
            ]
        evaluate_candidates(
            ParameterSampler_(
                prefixed, self.n_iter, random_state=self.random_state
            )
        )