# Source code for cca_zoo.probabilistic._rcca

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

from cca_zoo.probabilistic._cca import ProbabilisticCCA


class ProbabilisticRCCA(ProbabilisticCCA):
    """
    Probabilistic Ridge Canonical Correlation Analysis (Probabilistic Ridge CCA).

    Probabilistic Ridge CCA extends the Probabilistic Canonical Correlation
    Analysis model by introducing regularization terms in the linear
    relationships between multiple representations of data. This
    regularization improves the conditioning of the problem and provides a
    way to incorporate prior knowledge. It combines features of both CCA and
    Ridge Regression.

    Parameters
    ----------
    c: float, default=1.0
        Regularization strength; must be a positive float. Regularization
        improves the conditioning of the problem and reduces the variance of
        the estimates. Larger values specify stronger regularization.

    References
    ----------
    [1] De Bie, T. and De Moor, B., 2003. On the regularization of canonical
    correlation analysis. Int. Sympos. ICA and BSS, pp.785-790.
    """

    # NOTE(review): no method in this class reads ``self.c`` and the
    # generative model below contains no ridge/regularization term. Confirm
    # whether regularization is applied elsewhere (e.g. the base class) or
    # still needs to be added here.

    def _model(self, views):
        """
        Defines the generative model for Probabilistic CCA.

        For every sample a latent vector ``z ~ N(0, I_k)`` is drawn
        (``k = self.latent_dimensions``), and each view is Gaussian around a
        linear map of it: ``X_i | z ~ N(z @ W_i.T + mu_i, sigma_i * I)``.

        Parameters
        ----------
        views: tuple of np.ndarray
            A tuple containing the first and second representations, X1 and
            X2, each as a numpy array with ``self.n_samples_`` rows and
            ``self.n_features_in_[i]`` columns.
        """
        X1, X2 = views
        p1, p2 = self.n_features_in_[0], self.n_features_in_[1]
        k = self.latent_dimensions

        # Loadings mapping the k-dimensional latent space into each view.
        W1 = numpyro.param("W_1", jnp.ones(shape=(p1, k)))
        W2 = numpyro.param("W_2", jnp.ones(shape=(p2, k)))

        # One positive noise variance per view -> isotropic covariance.
        sigma1 = numpyro.param(
            "sigma_1", jnp.ones(1), constraint=dist.constraints.positive
        )
        sigma2 = numpyro.param(
            "sigma_2", jnp.ones(1), constraint=dist.constraints.positive
        )
        psi1 = jnp.eye(p1) * sigma1
        psi2 = jnp.eye(p2) * sigma2

        # Per-view means, shaped (1, p) so they broadcast over samples.
        mu1 = numpyro.param("mu_1", jnp.zeros(shape=(1, p1)))
        mu2 = numpyro.param("mu_2", jnp.zeros(shape=(1, p2)))

        with numpyro.plate("n", self.n_samples_):
            z = numpyro.sample(
                "z",
                dist.MultivariateNormal(jnp.zeros(k), jnp.eye(k)),
            )
            # z has shape (..., n_samples_, k); a matrix product yields the
            # (..., n_samples_, p) mean. ``jnp.outer`` flattens its inputs
            # and only produced the right shape when k == 1.
            numpyro.sample(
                "X1",
                dist.MultivariateNormal(z @ W1.T + mu1, covariance_matrix=psi1),
                obs=X1,
            )
            numpyro.sample(
                "X2",
                dist.MultivariateNormal(z @ W2.T + mu2, covariance_matrix=psi2),
                obs=X2,
            )

    def _guide(self, views):
        """
        Defines the variational distribution for Probabilistic CCA.

        A mean-field Gaussian over the latent variables: each sample's ``z``
        has its own location and per-dimension positive scale.

        Parameters
        ----------
        views: tuple of np.ndarray
            A tuple containing the first and second representations, X1 and
            X2, each as a numpy array. Unused here; kept so the guide has the
            same signature as the model.
        """
        k = self.latent_dimensions
        # Variational parameters for the approximate posterior of z.
        z_loc = numpyro.param("z_loc", jnp.zeros((self.n_samples_, k)))
        z_scale = numpyro.param(
            "z_scale",
            jnp.ones((self.n_samples_, k)),
            constraint=dist.constraints.positive,
        )
        with numpyro.plate("n", self.n_samples_):
            # ``jnp.diag`` on the 2-D ``z_scale`` extracts its diagonal rather
            # than building per-sample diagonal covariances, so use an
            # independent Normal over the k latent dimensions instead
            # (batch shape (n_samples_,), event shape (k,)).
            numpyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))

    def joint(self):
        """
        Covariance of the concatenated views ``[X1, X2]`` implied by the model.

        Reads ``W_1``, ``W_2``, ``sigma_1`` and ``sigma_2`` from
        ``self.params`` (presumably populated by fitting in the base class —
        confirm).

        Returns
        -------
        np.ndarray
            The block matrix
            ``[[W1 W1^T + psi1, W1 W2^T], [W2 W1^T, W2 W2^T + psi2]]``
            where ``psi_i = sigma_i * I``.
        """
        W1, W2 = self.params["W_1"], self.params["W_2"]
        psi1 = jnp.eye(self.n_features_in_[0]) * self.params["sigma_1"]
        psi2 = jnp.eye(self.n_features_in_[1]) * self.params["sigma_2"]
        # Within-view blocks: shared latent variance plus isotropic noise.
        top_left = W1 @ W1.T + psi1
        bottom_right = W2 @ W2.T + psi2
        # Cross-view blocks: only the shared latent contributes.
        top_right = W1 @ W2.T
        bottom_left = W2 @ W1.T
        return np.block([[top_left, top_right], [bottom_left, bottom_right]])