# Source code for cca_zoo.visualisation.explained_covariance

# Import the necessary modules
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import seaborn as sns

from cca_zoo._utils._checks import check_seaborn_support


class ExplainedCovarianceDisplay:
    """
    Display the explained covariance of the latent variables of the representations.

    Parameters
    ----------
    explained_covariance_train : np.ndarray
        The explained covariance of the train data.
    explained_covariance_test : np.ndarray
        The explained covariance of the test data.
    ratio : bool
        Whether to plot the ratio of explained covariance or not.
    **kwargs : dict
        Keyword arguments to be passed to the seaborn lineplot. They are
        merged over the defaults ``style="Mode"`` and ``marker="o"``, so
        either default can be overridden. ``data``, ``x``, ``y`` and ``ax``
        are always supplied by :meth:`plot` and must not be given here.

    Attributes
    ----------
    figure_ : matplotlib.figure.Figure
        The figure of the plot. Only available after :meth:`plot` has run.

    Examples
    --------
    >>> from cca_zoo.visualisation import ExplainedCovarianceDisplay
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> from cca_zoo.linear import MCCA
    >>>
    >>> # Generate Sample Data
    >>> # --------------------
    >>> X = np.random.rand(100, 10)
    >>> Y = np.random.rand(100, 10)
    >>>
    >>> # Splitting the data into training and testing sets
    >>> X_train, X_test = X[:50], X[50:]
    >>> Y_train, Y_test = Y[:50], Y[50:]
    >>>
    >>> representations = [X_train, Y_train]
    >>> test_views = [X_test, Y_test]
    >>>
    >>> # Train an MCCA Model
    >>> # -------------------
    >>> mcca = MCCA(latent_dimensions=2)
    >>> mcca.fit(representations)
    >>>
    >>> # %%
    >>> # Plotting the Explained Covariance
    >>> # ---------------------------------
    >>> ExplainedCovarianceDisplay.from_estimator(mcca, representations, test_views=test_views).plot()
    >>> plt.show()
    """

    def __init__(
        self,
        explained_covariance_train,
        explained_covariance_test=None,
        ratio=True,
        **kwargs
    ):
        self.explained_covariance_train = explained_covariance_train
        self.explained_covariance_test = explained_covariance_test
        self.ratio = ratio
        # Forwarded to ``sns.lineplot`` by :meth:`plot`.
        self.kwargs = kwargs

    def _validate_plot_params(self):
        """Run the seaborn availability check on behalf of this display."""
        # The original passed "CorrelationHeatmapDisplay" here (copy-paste
        # from a sibling display); the check should be told about this class.
        check_seaborn_support("ExplainedCovarianceDisplay")

    @classmethod
    def from_estimator(cls, model, train_views, test_views=None, ratio=True, **kwargs):
        """
        Build a display by querying a fitted model.

        Parameters
        ----------
        model : object
            Object exposing ``explained_covariance_ratio(views)`` and
            ``explained_covariance(views)``.
        train_views : list
            Views handed to the model to compute the train curve.
        test_views : list, optional
            Views handed to the model to compute the test curve. When
            ``None`` no test curve is computed.
        ratio : bool
            If truthy use ``model.explained_covariance_ratio``, otherwise
            ``model.explained_covariance``.
        **kwargs : dict
            Stored on the display and forwarded to the seaborn lineplot.

        Returns
        -------
        ExplainedCovarianceDisplay
        """
        # Pick the model method and the matching constructor once instead of
        # re-testing ``ratio`` for every call.
        if ratio:
            compute = model.explained_covariance_ratio
            build = cls.from_explained_covariance_ratio
        else:
            compute = model.explained_covariance
            build = cls.from_explained_covariance

        # NOTE(review): an earlier comment described the result as having
        # shape (latent_dimensions, len(train_views)), but plot() loads it
        # into a single-column DataFrame -- confirm the actual shape.
        explained_covariance_train = compute(train_views)
        explained_covariance_test = (
            compute(test_views) if test_views is not None else None
        )
        return build(explained_covariance_train, explained_covariance_test, **kwargs)

    @classmethod
    def from_explained_covariance(
        cls, explained_covariance_train, explained_covariance_test=None, **kwargs
    ):
        """Build a display from precomputed values with ``ratio=False``."""
        return cls(
            explained_covariance_train, explained_covariance_test, ratio=False, **kwargs
        )

    @classmethod
    def from_explained_covariance_ratio(
        cls, explained_covariance_train, explained_covariance_test=None, **kwargs
    ):
        """Build a display from precomputed values with ``ratio=True``."""
        return cls(
            explained_covariance_train, explained_covariance_test, ratio=True, **kwargs
        )

    @staticmethod
    def _long_frame(values, mode):
        """Return ``values`` as a one-column frame tagged with ``mode``."""
        frame = pd.DataFrame(values, columns=["value"])
        frame["Mode"] = mode
        # seaborn can look up the x variable by index name, so the default
        # RangeIndex (0, 1, 2, ...) serves as the latent dimension axis.
        frame.index.name = "Latent dimension"
        return frame

    def plot(self, ax=None):
        """
        Draw the explained covariance per latent dimension as a line plot.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to draw on. When ``None`` a new 10x5 figure is created.

        Returns
        -------
        ExplainedCovarianceDisplay
            ``self``, with ``figure_`` set to the figure that owns ``ax``.
        """
        self._validate_plot_params()

        frames = [self._long_frame(self.explained_covariance_train, "Train")]
        if self.explained_covariance_test is not None:
            frames.append(self._long_frame(self.explained_covariance_test, "Test"))
        data = pd.concat(frames)

        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 5))
        else:
            fig = ax.get_figure()

        # User kwargs are documented as going to the lineplot; let them
        # override the styling defaults.
        lineplot_kwargs = {"style": "Mode", "marker": "o", **self.kwargs}
        sns.lineplot(
            data=data, x="Latent dimension", y="value", ax=ax, **lineplot_kwargs
        )

        ax.set_xlabel("Latent dimension")
        if self.ratio:
            ax.set_ylabel("Explained covariance %")
            # Values are fractions of 1, so 1.0 maps to 100%.
            ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
        else:
            ax.set_ylabel("Explained covariance")
        ax.set_title("Explained covariance")

        # Set x-ticks to integers
        ax.xaxis.set_major_locator(mtick.MaxNLocator(integer=True))

        # Lay out the figure that owns ``ax`` rather than whichever figure
        # pyplot considers current. The lookup is guarded because
        # ``ax.get_figure()`` may hand back an object without
        # ``tight_layout`` (presumably a sub-figure -- TODO confirm).
        tight_layout = getattr(fig, "tight_layout", None)
        if callable(tight_layout):
            tight_layout()

        self.figure_ = fig
        return self