Source code for cca_zoo.visualisation.inference

from cca_zoo._utils._checks import check_arviz_support


[docs] class WeightInferenceDisplay: """ Class for displaying inference-related plots. Attributes ---------- idata : arviz.InferenceData The posterior samples. true_features: array-like, optional The true features for comparison in the plot, defaults to None. num_views: int, optional The number of representations, defaults to 2. """ def __init__(self, idata, num_views=2, true_features=None): """ Initialize the WeightInferenceDisplay object. Parameters ---------- idata : arviz.InferenceData The posterior samples. num_views : int, optional The number of representations, defaults to 2. true_features : array-like, optional The true features for comparison in the plot, defaults to None. """ self.idata = idata self.true_features = true_features self.num_views = num_views def _validate_plot_params(self): """ Internal method to validate plotting parameters. Currently, it checks if arviz is supported. """ check_arviz_support("CorrelationHeatmapDisplay")
[docs] @classmethod def from_estimator(cls, pcca_estimator, true_features=None): """ Class method to create an InferenceDisplay instance from an estimator. Parameters ---------- pcca_estimator : object The estimator object with an 'mcmc' attribute. true_features : array-like, optional The true features for comparison in the plot, defaults to None. Returns ------- WeightInferenceDisplay An InferenceDisplay instance. """ return cls.from_mcmc(pcca_estimator.mcmc, true_features)
[docs] @classmethod def from_mcmc(cls, mcmc, true_features=None): """ Class method to create an InferenceDisplay instance from mcmc samples. Parameters ---------- mcmc : object The mcmc samples. true_features : array-like, optional The true features for comparison in the plot, defaults to None. Returns ------- WeightInferenceDisplay An InferenceDisplay instance. """ import arviz as az idata = az.from_numpyro(mcmc) return cls(idata, 2, true_features)
[docs] def plot(self): """ Plot the posterior distributions of parameters and latent variables. Adds true values if they are provided. """ import arviz as az import matplotlib.pyplot as plt for view in range(self.num_views): # Plot the posterior distribution of W_0 parameter (for just the first latent variable). # Label the weights_ with their weight index. Make all parameters share x axis. trace_plot = az.plot_trace( self.idata, var_names=[f"W_{view}"], compact=False, divergences=None ) # For each w in W_0, plot the true value from data.true_features[0] for i, ax in enumerate(trace_plot[:, 0]): if self.true_features is not None: ax.axvline( self.true_features[view].ravel()[i], color="red", linestyle="--", label="True Value", ) ax.legend() plt.suptitle(f"Posterior Distribution of W_{view}") plt.tight_layout() plt.show()