Source code for cca_zoo.visualisation.tsne_scores

import seaborn as sns

from cca_zoo._utils._checks import check_tsne_support, check_seaborn_support
from cca_zoo.visualisation import ScoreScatterDisplay


class TSNEScoreDisplay(ScoreScatterDisplay):
    """Scatter display of scores after embedding them with openTSNE.

    The train scores are embedded with ``openTSNE.TSNE`` and the first two
    embedding coordinates are drawn with ``seaborn.scatterplot``. When test
    scores are available they are projected into the *same* embedding and
    drawn on the same axes.

    NOTE(review): only ``self.scores[0]`` / ``self.test_scores[0]`` are
    embedded -- presumably the scores of the first view; confirm against
    ``ScoreScatterDisplay``.
    """

    def _validate_plot_params(self):
        """Run the openTSNE and seaborn support checks for this display."""
        check_tsne_support("TSNEScoreDisplay")
        check_seaborn_support("TSNEScoreDisplay")

    def plot(self):
        """Embed the scores with t-SNE and draw them as a scatter plot.

        Train points are drawn faded (``alpha=0.1``) and labelled "Train"
        only when test scores are present; otherwise they are drawn fully
        opaque without a label. Test points, if any, are labelled "Test".
        The created figure is stored on ``self.figure_``.

        Returns
        -------
        self
            The display instance, to allow call chaining.
        """
        self._validate_plot_params()
        # Imported lazily, only after the support checks above have run.
        import openTSNE
        import matplotlib.pyplot as plt

        has_test = self.test_scores is not None

        reducer = openTSNE.TSNE()
        train_embedding = reducer.fit(self.scores[0])

        fig, ax = plt.subplots()
        sns.scatterplot(
            x=train_embedding[:, 0],
            y=train_embedding[:, 1],
            hue=self.labels,
            ax=ax,
            alpha=0.1 if has_test else 1.0,
            label="Train" if has_test else None,
            **self.kwargs,
        )

        if has_test:
            # Project the test points into the fitted train embedding rather
            # than fitting a second, independent t-SNE: coordinates coming
            # from two separate fits are not comparable, so overlaying them
            # on one set of axes would be meaningless.
            test_embedding = train_embedding.transform(self.test_scores[0])
            sns.scatterplot(
                x=test_embedding[:, 0],
                y=test_embedding[:, 1],
                hue=self.test_labels,
                ax=ax,
                label="Test",
                **self.kwargs,
            )

        plt.tight_layout()
        self.figure_ = fig
        return self