Data

Contents

Simulated Data

cca_zoo.data.simulated.simple_simulated_data(n: int, view_features: List[int], view_sparsity: Optional[List[Union[int, float]]] = None, eps: float = 0, transform=False, random_state=None)[source]

Generate a simple simulated dataset with a single latent dimension

Parameters
  • n (int) – Number of samples

  • view_features (List[int]) – Number of features in each view

  • view_sparsity (List[Union[int, float]]) – Sparsity of each view. If int, then number of non-zero features. If float, then proportion of non-zero features.

  • eps (float) – Noise level

  • transform (bool) – Whether to transform the data to be non-negative

  • random_state (int, RandomState instance, default=None) – Controls the random seed in generating the data.
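The parameters above can be illustrated with a minimal sketch of the generative model: one latent variable shared across views, projected through per-view loading vectors (optionally sparsified), plus Gaussian noise scaled by eps. This is a hypothetical simplification written for illustration, not cca_zoo's actual implementation; the function name and return format are assumptions.

```python
import numpy as np

def sketch_simulated_data(n, view_features, view_sparsity=None, eps=0.0, random_state=None):
    """Hypothetical sketch: a single shared latent dimension, sparse
    per-view loadings, and additive Gaussian noise with scale eps."""
    rng = np.random.default_rng(random_state)
    z = rng.standard_normal((n, 1))  # single latent dimension shared across views
    views = []
    for i, p in enumerate(view_features):
        w = rng.standard_normal((1, p))  # loading vector for this view
        if view_sparsity is not None:
            s = view_sparsity[i]
            # int -> number of non-zero features; float -> proportion of non-zero features
            k = s if isinstance(s, int) else int(round(s * p))
            mask = np.zeros(p)
            mask[rng.choice(p, size=k, replace=False)] = 1.0
            w = w * mask
        views.append(z @ w + eps * rng.standard_normal((n, p)))
    return views

views = sketch_simulated_data(100, [10, 12], view_sparsity=[5, 0.5], eps=0.1, random_state=0)
```

With view_sparsity=[5, 0.5], the first view keeps 5 non-zero loadings and the second keeps half of its 12 features, matching the int/float convention described above.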

Deep

class cca_zoo.data.deep.NumpyDataset(views, labels=None)[source]

Dataset class that wraps a list of numpy arrays as a torch dataset

Parameters

views – list/tuple of numpy arrays or array-likes with the same number of rows (samples)
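A minimal sketch of what such a wrapper might look like, written in plain Python for illustration: each item is a dictionary with a "views" key, which is the contract that check_dataset verifies. The class name is hypothetical, and cca_zoo's actual NumpyDataset additionally converts items to torch tensors.

```python
import numpy as np

class SketchNumpyDataset:
    """Hypothetical sketch of a numpy-backed dataset; the real NumpyDataset
    returns torch tensors rather than raw numpy rows."""

    def __init__(self, views, labels=None):
        if len({len(v) for v in views}) != 1:
            raise ValueError("All views must have the same number of rows (samples)")
        self.views = [np.asarray(v) for v in views]
        self.labels = labels

    def __len__(self):
        return len(self.views[0])

    def __getitem__(self, index):
        # Each item is a dictionary with a "views" key holding one row per view
        item = {"views": [v[index] for v in self.views]}
        if self.labels is not None:
            item["label"] = self.labels[index]
        return item

dataset = SketchNumpyDataset([np.ones((5, 3)), np.zeros((5, 2))])
```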

cca_zoo.data.deep.check_dataset(dataset)[source]

Checks that a custom dataset returns a dictionary with a “views” key containing a list of tensors

Parameters

dataset (torch.utils.data.Dataset) –

cca_zoo.data.deep.get_dataloaders(dataset, val_dataset=None, batch_size=None, val_batch_size=None, drop_last=True, val_drop_last=False, shuffle_train=False, pin_memory=True, num_workers=0, persistent_workers=True)[source]

A utility function for quickly constructing the train (and optional validation) dataloaders required by PyTorch Lightning

Parameters
  • dataset – A CCA dataset used for training

  • val_dataset – An optional CCA dataset used for validation

  • batch_size – batch size of train loader

  • val_batch_size – batch size of val loader

  • num_workers – number of workers used

  • pin_memory – whether pytorch pins memory for faster host-to-device transfer; True tends to speed up training on GPU

  • shuffle_train – whether to shuffle training data

  • val_drop_last – whether to drop the last incomplete batch from the validation data

  • drop_last – whether to drop the last incomplete batch from the train data

  • persistent_workers – whether to keep workers alive after dataloader is destroyed
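The drop_last and shuffle options map directly onto standard torch.utils.data.DataLoader behaviour. The torch-free sketch below (a hypothetical helper, not part of cca_zoo) illustrates what those two flags do to the batch layout: shuffling permutes sample indices before batching, and drop_last discards a final batch smaller than batch_size.

```python
import random

def sketch_batches(n_samples, batch_size, drop_last=True, shuffle=False, seed=None):
    """Hypothetical sketch of DataLoader batching: return lists of sample
    indices, optionally shuffled, optionally dropping the last short batch."""
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # incomplete final batch is discarded, as with drop_last=True
    return batches

# 10 samples with batch_size=4: drop_last=True yields 2 full batches,
# drop_last=False keeps the short third batch of 2 samples.
```

This is why val_drop_last defaults to False above: validation metrics should usually be computed over every sample, while dropping the last training batch keeps batch statistics consistent.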