Source code for cca_zoo.visualisation.umap_scores
import seaborn as sns
from cca_zoo._utils._checks import check_umap_support, check_seaborn_support
from cca_zoo.visualisation import ScoreScatterDisplay
[docs]
class UMAPScoreDisplay(ScoreScatterDisplay):
    """Scatter display of scores after embedding them with UMAP.

    A ``umap.UMAP`` reducer is fitted on ``self.scores[0]`` and the first two
    columns of the resulting embedding are drawn with seaborn. When
    ``self.test_scores`` is set, ``self.test_scores[0]`` is projected with the
    same fitted reducer and overlaid on the same axes.

    The attributes read here (``scores``, ``labels``, ``test_scores``,
    ``test_labels``, ``kwargs``) are presumably populated by
    ``ScoreScatterDisplay`` -- confirm against the base class.
    """

    def _validate_plot_params(self):
        """Run the optional-dependency checks for UMAP and seaborn.

        NOTE(review): the string argument is presumably the caller name the
        helpers use in their error message -- confirm in
        ``cca_zoo._utils._checks``.
        """
        check_umap_support("UMAPScoreDisplay")
        # Previously reported "TSNEScoreDisplay" (copy-paste from the t-SNE
        # display); the name must identify this class.
        check_seaborn_support("UMAPScoreDisplay")

    def plot(self, **kwargs):
        """Draw the UMAP embedding of the scores and store the figure.

        Parameters
        ----------
        **kwargs
            Accepted for signature compatibility but not used by this method;
            the extra seaborn keyword arguments come from ``self.kwargs``.

        Returns
        -------
        self
            The display itself, with the created matplotlib figure stored in
            ``self.figure_``.
        """
        self._validate_plot_params()
        # Imported after validation so the dependency checks run first.
        import umap
        import matplotlib.pyplot as plt

        has_test = self.test_scores is not None

        reducer = umap.UMAP()
        train_embedding = reducer.fit_transform(self.scores[0])

        fig, ax = plt.subplots()
        sns.scatterplot(
            x=train_embedding[:, 0],
            y=train_embedding[:, 1],
            hue=self.labels,
            ax=ax,
            # Fade the training points when test points are overlaid.
            alpha=0.1 if has_test else 1.0,
            label="Train" if has_test else None,
            **self.kwargs,
        )

        if has_test:
            # Reuse the reducer fitted on the training scores so both point
            # sets share one embedding space.
            test_embedding = reducer.transform(self.test_scores[0])
            sns.scatterplot(
                x=test_embedding[:, 0],
                y=test_embedding[:, 1],
                hue=self.test_labels,
                ax=ax,
                label="Test",
                **self.kwargs,
            )

        # Lay out the figure created above explicitly, not whichever figure
        # pyplot considers current.
        fig.tight_layout()
        self.figure_ = fig
        return self