# Source code for cca_zoo.probabilistic._cca

from typing import Iterable

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.random import PRNGKey
from numpyro.infer import SVI

from cca_zoo._base import _BaseModel
from cca_zoo._utils._checks import check_graphviz_support


class ProbabilisticCCA(_BaseModel):
    """
    Probabilistic Canonical Correlation Analysis (CCA) fitted with stochastic
    variational inference (SVI).

    The view-specific parameters (W, mu, psi) are point-estimated as numpyro
    params by maximising the evidence lower bound (ELBO), while each sample's
    latent variable z gets a mean-field Gaussian variational posterior.

    Probabilistic CCA is a generative model that makes the following
    assumptions:

    1. A latent variable z exists that influences both views (X1, X2).
    2. Each observed view is generated via its own set of parameters:
       W (loading matrix), mu (mean), and psi (noise covariance).

    The generative model can be described as follows::

        z      ~ N(0, I)
        X1 | z ~ N(W1 z + mu1, psi1)
        X2 | z ~ N(W2 z + mu2, psi2)

    Parameters
    ----------
    latent_dimensions : int, optional
        The dimensionality of the latent space, by default 1.
    copy_data : bool, optional
        Whether to copy the data, by default True.
    random_state : int, optional
        Seed used to build the JAX PRNG key, by default 0.
    learning_rate : float, optional
        Learning rate for the Adam optimizer, by default 1e-1.
    n_iter : int, optional
        Number of SVI optimisation steps, by default 20000.
    num_samples : int, optional
        Number of MCMC samples, by default 5000. Stored on the instance;
        the SVI-based ``fit`` of this class does not use it.
    num_warmup : int, optional
        Number of MCMC warmup steps, by default 5000. Stored on the
        instance; the SVI-based ``fit`` of this class does not use it.

    References
    ----------
    [1] Bach, Francis R., and Michael I. Jordan. "A probabilistic
        interpretation of canonical correlation analysis." (2005).
    [2] Wang, Chong. "Variational Bayesian approach to canonical correlation
        analysis." IEEE Transactions on Neural Networks 18.3 (2007): 905-910.
    """

    # Name of the latent sample site.
    # NOTE(review): its consumer is not visible in this file.
    return_sites = ["z"]

    def __init__(
        self,
        latent_dimensions: int = 1,
        copy_data=True,
        random_state: int = 0,
        learning_rate=1e-1,
        n_iter=20000,
        num_samples=5000,
        num_warmup=5000,
    ):
        super().__init__(
            latent_dimensions=latent_dimensions,
            copy_data=copy_data,
            accept_sparse=False,
            random_state=random_state,
        )
        self.num_samples = num_samples
        self.num_warmup = num_warmup
        self.rng_key = PRNGKey(random_state)
        self.learning_rate = learning_rate
        self.n_iter = n_iter
        # Filled with the optimised numpyro params by ``fit``.
        self.params = None

    def fit(self, views: Iterable[np.ndarray], y=None):
        """
        Infer the parameters and latent variables of the Probabilistic
        Canonical Correlation Analysis (CCA) model by running SVI.

        Parameters
        ----------
        views : Iterable[np.ndarray]
            A list or tuple of numpy arrays representing different views of
            the same samples. Each numpy array must have the same number of
            rows.
        y : Any, optional
            Ignored in this implementation.

        Returns
        -------
        self : object
            The instance itself, with ``svi_result`` and ``params`` set.

        Notes
        -----
        - The data in each view should be normalized for optimal
          performance.
        """
        views = self._validate_data(views)
        self._check_params()
        svi = SVI(
            self._model,
            self._guide,
            numpyro.optim.Adam(self.learning_rate),
            loss=numpyro.infer.Trace_ELBO(),
        )
        self.svi_result = svi.run(self.rng_key, self.n_iter, views)
        self.params = self.svi_result.params
        return self

    def _model(self, views):
        """
        Define the generative model for Probabilistic CCA.

        Parameters
        ----------
        views : tuple of np.ndarray
            A tuple ``(X1, X2)`` holding the two observed views.
        """
        X1, X2 = views
        n_features_1 = self.n_features_in_[0]
        n_features_2 = self.n_features_in_[1]
        k = self.latent_dimensions

        W1 = numpyro.param("W_1", jnp.ones(shape=(n_features_1, k)))
        W2 = numpyro.param("W_2", jnp.ones(shape=(n_features_2, k)))

        # The noise covariances are parameterised by their lower-Cholesky
        # factors (psi = L @ L.T), which keeps them positive definite.
        L1 = numpyro.param(
            "L_1",
            jnp.eye(n_features_1),
            constraint=dist.constraints.lower_cholesky,
        )
        L2 = numpyro.param(
            "L_2",
            jnp.eye(n_features_2),
            constraint=dist.constraints.lower_cholesky,
        )

        # Row vectors so they broadcast against the (n_samples, n_features)
        # per-sample means below.
        mu1 = numpyro.param("mu_1", jnp.zeros(shape=(1, n_features_1)))
        mu2 = numpyro.param("mu_2", jnp.zeros(shape=(1, n_features_2)))

        with numpyro.plate("n", self.n_samples_):
            # Inside the plate z has shape (n_samples, k).
            z = numpyro.sample(
                "z",
                dist.MultivariateNormal(jnp.zeros(k), jnp.eye(k)),
            )
            # z @ W.T projects each sample's latent vector to its
            # (n_features,) mean. Passing the Cholesky factor as scale_tril
            # gives the same distribution as covariance_matrix=L @ L.T without
            # re-factorising it.
            numpyro.sample(
                "X1",
                dist.MultivariateNormal(z @ W1.T + mu1, scale_tril=L1),
                obs=X1,
            )
            numpyro.sample(
                "X2",
                dist.MultivariateNormal(z @ W2.T + mu2, scale_tril=L2),
                obs=X2,
            )

    def _guide(self, views):
        """
        Define the variational distribution for Probabilistic CCA.

        Each sample's latent vector gets an independent (mean-field) Gaussian
        with its own location and positive scale per latent dimension.

        Parameters
        ----------
        views : tuple of np.ndarray
            A tuple ``(X1, X2)`` holding the two observed views. It is
            unused, but SVI passes the guide the same arguments as the model.
        """
        z_loc = numpyro.param(
            "z_loc", jnp.zeros((self.n_samples_, self.latent_dimensions))
        )
        z_scale = numpyro.param(
            "z_scale",
            jnp.ones((self.n_samples_, self.latent_dimensions)),
            constraint=dist.constraints.positive,
        )
        with numpyro.plate("n", self.n_samples_):
            # to_event(1) folds the latent dimension into the event shape, so
            # the site matches the model's MultivariateNormal "z".
            numpyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))

    def render(self, views):
        """
        Render the model graph with ``numpyro.render_model``.

        The rendering is written to ``model.pdf`` and stored on
        ``self.rendering``.

        Parameters
        ----------
        views : tuple of np.ndarray
            A tuple ``(X1, X2)`` passed to the model when tracing it.
        """
        check_graphviz_support("ProbabilisticCCA")
        self.rendering = numpyro.render_model(
            self._model, model_args=(views,), filename="model.pdf"
        )

    def _more_tags(self):
        """Return estimator tags marking this model as probabilistic."""
        return {"probabilistic": True}

    def joint(self):
        """
        Return the model-implied covariance of the concatenated views.

        Built from the fitted parameters as::

            [[W1 W1^T + psi1, W1 W2^T       ],
             [W2 W1^T,        W2 W2^T + psi2]]

        with ``psi_i = L_i L_i^T``. Requires ``fit`` to have been called.

        Returns
        -------
        np.ndarray
            Square matrix of size ``n_features_1 + n_features_2``.
        """
        W1 = self.params["W_1"]
        W2 = self.params["W_2"]
        psi1 = self.params["L_1"] @ self.params["L_1"].T
        psi2 = self.params["L_2"] @ self.params["L_2"].T

        # Calculate the individual matrix blocks.
        top_left = W1 @ W1.T + psi1
        bottom_right = W2 @ W2.T + psi2
        top_right = W1 @ W2.T
        bottom_left = W2 @ W1.T

        # Construct the matrix from the blocks.
        matrix = np.block([[top_left, top_right], [bottom_left, bottom_right]])
        return matrix