Source code for cca_zoo.visualisation.weights
import matplotlib.pyplot as plt
import seaborn as sns
[docs]
class WeightHeatmapDisplay:
"""Heatmap of the weights of a model.
Parameters
----------
model: CCA model
A fitted CCA model.
"""
def __init__(self, weights, view_labels=None, **kwargs):
self.weights = weights
self.view_labels = view_labels
self.kwargs = kwargs
@classmethod
def from_estimator(cls, model, view_labels=None, **kwargs):
weights = model.weights_
return cls.from_weights(weights, view_labels=view_labels, **kwargs)
@classmethod
def from_weights(cls, weights, view_labels=None, **kwargs):
return cls(weights, view_labels=view_labels, **kwargs)
[docs]
def plot(self, **kwargs):
"""Plot the heatmap.
Parameters
----------
ax: matplotlib axes, optional
Axes to plot on, by default None.
kwargs: dict
Keyword arguments to pass to seaborn.heatmap
Returns
-------
ax: matplotlib axes
Axes with the heatmap.
"""
fig, axs = plt.subplots(1, len(self.weights), figsize=(10, 5))
if self.view_labels is None:
self.view_labels = [f"View {i}" for i in range(len(self.weights))]
self.weights_cov = [w.T @ w for w in self.weights]
# loop through each view and have a heatmap of the covariance of the weights_
for i, view_weights_cov in enumerate(self.weights_cov):
sns.heatmap(view_weights_cov, ax=axs[i], annot=True, **self.kwargs)
axs[i].set_title(self.view_labels[i])
plt.tight_layout()
self.figure_ = fig
return self