Source code for cca_zoo.visualisation.covariance
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from cca_zoo._utils._checks import check_seaborn_support
[docs]
class CovarianceHeatmapDisplay:
    """Covariance Heatmap Display

    Heatmap of the covariance matrix of the latent variables (scores) of the
    representations. Rows/columns are ordered view by view, each view
    contributing all of its latent dimensions, so the matrix contains both the
    within-view blocks (on the block diagonal) and the between-view blocks
    (off the block diagonal).

    Parameters
    ----------
    train_covariances : np.ndarray
        The train covariances between representations.
    test_covariances : np.ndarray, optional
        The test covariances between representations. If None, only the
        train heatmap is drawn.

    Attributes
    ----------
    figure_ : matplotlib.figure.Figure
        The figure of the plot. Set by :meth:`plot`.

    Examples
    --------
    >>> from cca_zoo.visualisation import CovarianceHeatmapDisplay
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> from cca_zoo.linear import MCCA
    >>>
    >>> # Generate Sample Data
    >>> # --------------------
    >>> X = np.random.rand(100, 10)
    >>> Y = np.random.rand(100, 10)
    >>>
    >>> # Splitting the data into training and testing sets
    >>> X_train, X_test = X[:50], X[50:]
    >>> Y_train, Y_test = Y[:50], Y[50:]
    >>>
    >>> train_views = [X_train, Y_train]
    >>> test_views = [X_test, Y_test]
    >>>
    >>> # Train an MCCA Model
    >>> # -------------------
    >>> mcca = MCCA(latent_dimensions=2)
    >>> mcca.fit(train_views)
    >>>
    >>> # Plotting the Covariance Heatmap
    >>> # -------------------------------
    >>> CovarianceHeatmapDisplay.from_estimator(mcca, train_views, test_views=test_views).plot()
    >>> plt.show()
    """

    def __init__(self, train_covariances, test_covariances=None):
        self.train_covariances = train_covariances
        self.test_covariances = test_covariances

    def _validate_plot_params(self):
        # Report this display's own name if the plotting backend is missing.
        check_seaborn_support("CovarianceHeatmapDisplay")

    @staticmethod
    def _scores_covariance(scores):
        """Return the covariance matrix of all views' latent scores.

        ``np.column_stack`` places each view's latent dimensions side by side
        (1-D score vectors become single columns), and ``rowvar=False`` treats
        columns as variables. For two views this is numerically the same as
        ``np.cov(scores[0].T, scores[1].T)``, and it extends to any number of
        views.
        """
        return np.cov(np.column_stack(scores), rowvar=False)

    @classmethod
    def from_estimator(cls, model, train_views, test_views=None):
        """Build the display from a fitted model and its input views.

        Parameters
        ----------
        model : estimator
            Fitted model whose ``transform(views)`` returns one array of
            latent scores per view (presumably shaped (n_samples,
            latent_dimensions) — confirm for the model in use).
        train_views : list of np.ndarray
            Views used to compute the train covariances.
        test_views : list of np.ndarray, optional
            Views used to compute the test covariances. If None, no test
            covariances are computed.

        Returns
        -------
        CovarianceHeatmapDisplay
        """
        train_covariances = cls._scores_covariance(model.transform(train_views))
        if test_views is None:
            test_covariances = None
        else:
            test_covariances = cls._scores_covariance(model.transform(test_views))
        return cls.from_covariances(train_covariances, test_covariances)

    @classmethod
    def from_covariances(cls, train_covariances, test_covariances=None):
        """Build the display from precomputed covariance matrices."""
        return cls(train_covariances, test_covariances)

    def plot(self):
        """Draw the train (and, if available, test) covariance heatmaps.

        Returns
        -------
        self : CovarianceHeatmapDisplay
            The display, with ``figure_`` set to the newly created figure.
        """
        self._validate_plot_params()
        panels = [("Train Covariances", self.train_covariances)]
        if self.test_covariances is not None:
            panels.append(("Test Covariances", self.test_covariances))
        # One panel per available matrix, so no empty "Test" axis is drawn
        # when test data is absent. squeeze=False keeps ``axs`` 2-D even for a
        # single panel, so the loop below is uniform.
        fig, axs = plt.subplots(
            1, len(panels), figsize=(5 * len(panels), 5), squeeze=False
        )
        for ax, (title, covariances) in zip(axs[0], panels):
            sns.heatmap(covariances, annot=True, ax=ax)
            ax.set_title(title)
        self.figure_ = fig
        return self