Source code for cca_zoo.visualisation.correlation
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from cca_zoo._utils._checks import check_seaborn_support
[docs]
class CorrelationHeatmapDisplay:
    """Correlation Heatmap Display

    Heatmap of the correlations between the latent variables of the
    representations.

    Parameters
    ----------
    train_correlations : np.ndarray
        Square correlation matrix computed on the training representations.
    test_correlations : np.ndarray or None
        Square correlation matrix computed on the test representations, or
        ``None`` when no test data is available. In that case only the train
        heatmap is drawn.

    Attributes
    ----------
    figure_ : matplotlib.figure.Figure
        The figure of the plot, set by :meth:`plot`.

    Examples
    --------
    >>> from cca_zoo.visualisation import CorrelationHeatmapDisplay
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> from cca_zoo.linear import MCCA
    >>>
    >>> # Generate Sample Data
    >>> # --------------------
    >>> X = np.random.rand(100, 10)
    >>> Y = np.random.rand(100, 10)
    >>>
    >>> # Splitting the data into training and testing sets
    >>> X_train, X_test = X[:50], X[50:]
    >>> Y_train, Y_test = Y[:50], Y[50:]
    >>>
    >>> train_views = [X_train, Y_train]
    >>> test_views = [X_test, Y_test]
    >>>
    >>> # Train an MCCA Model
    >>> # -------------------
    >>> mcca = MCCA(latent_dimensions=2)
    >>> _ = mcca.fit(train_views)
    >>>
    >>> # Plotting the Correlation Heatmap
    >>> # -------------------------------
    >>> _ = CorrelationHeatmapDisplay.from_estimator(
    ...     mcca, train_views, test_views=test_views
    ... ).plot()
    >>> plt.show()
    """

    def __init__(self, train_correlations, test_correlations):
        self.train_correlations = train_correlations
        self.test_correlations = test_correlations

    def _validate_plot_params(self):
        # Delegates to the project helper; presumably raises a helpful error
        # when seaborn is unavailable (behaviour defined outside this class).
        check_seaborn_support("CorrelationHeatmapDisplay")

    @staticmethod
    def _score_correlations(scores):
        """Return the correlation matrix of the first two views' scores.

        ``np.corrcoef`` treats each row of its arguments as a variable, so
        transposing puts each latent dimension (column) on a row. The rows of
        both arguments are stacked, giving a ``(k0 + k1) x (k0 + k1)`` matrix
        that holds both within-view and between-view correlations. Only
        ``scores[0]`` and ``scores[1]`` are used; any further views are
        ignored.
        """
        return np.corrcoef(scores[0].T, scores[1].T)

    @classmethod
    def from_estimator(cls, model, train_views, test_views=None):
        """Build the display from a fitted model and its input views.

        Parameters
        ----------
        model : object
            Fitted model exposing ``transform(views)``. It must return one
            score array per view, with samples in rows.
        train_views : list of np.ndarray
            Views used to compute the train correlations.
        test_views : list of np.ndarray, optional
            Views used to compute the test correlations. If omitted, no test
            heatmap is produced.

        Returns
        -------
        CorrelationHeatmapDisplay
        """
        train_correlations = cls._score_correlations(model.transform(train_views))
        if test_views is None:
            test_correlations = None
        else:
            test_correlations = cls._score_correlations(model.transform(test_views))
        return cls.from_correlations(train_correlations, test_correlations)

    @classmethod
    def from_correlations(cls, train_correlations, test_correlations=None):
        """Build the display from precomputed correlation matrices."""
        return cls(train_correlations, test_correlations)

    def plot(self):
        """Draw the train (and, if available, test) correlation heatmaps.

        Returns
        -------
        self : CorrelationHeatmapDisplay
            The display, with the created figure stored in ``figure_``.
        """
        self._validate_plot_params()
        # Allocate a panel only for matrices that exist. This avoids drawing
        # an empty axis titled "Test Correlations" when there is no test data.
        panels = [("Train Correlations", self.train_correlations)]
        if self.test_correlations is not None:
            panels.append(("Test Correlations", self.test_correlations))
        # squeeze=False keeps ``axs`` 2-D even for a single panel; the width
        # per panel matches the original 10x5 figure for two panels.
        fig, axs = plt.subplots(
            1, len(panels), figsize=(5 * len(panels), 5), squeeze=False
        )
        for ax, (title, correlations) in zip(axs[0], panels):
            sns.heatmap(
                correlations,
                annot=True,
                cmap="coolwarm",
                ax=ax,
                # Fixed colour range so train and test panels are comparable.
                vmin=-1,
                vmax=1,
            )
            ax.set_title(title)
        self.figure_ = fig
        return self