"""Helped by https://github.com/bcdutton/AdversarialCanonicalCorrelationAnalysis (hopefully I will have my own
implementation of their work soon). Check out their paper at https://arxiv.org/abs/2005.10349 """
import numpy as np
import torch
import torch.utils.data
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class Split_MNIST_Dataset(Dataset):
    """
    Class to generate paired views by splitting each MNIST-style image in two.

    Every image is flattened and cut at pixel 392: the first 392 pixels form
    the first view and the remaining pixels form the second view. Both views
    are divided by 255.
    """

    def __init__(
        self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
    ):
        """
        :param mnist_type: "MNIST", "FashionMNIST" or "KMNIST"
        :param train: whether this is train or test
        :param flatten: stored on the instance as ``self.flatten``. NOTE: the
            samples returned by this class are always flattened, whatever the
            value of this flag.
        :raises ValueError: if ``mnist_type`` is not one of the supported names
        """
        # Validate first so an unsupported name fails loudly here instead of
        # surfacing later as a missing ``self.dataset`` attribute.
        if mnist_type == "MNIST":
            dataset_cls = datasets.MNIST
        elif mnist_type == "FashionMNIST":
            dataset_cls = datasets.FashionMNIST
        elif mnist_type == "KMNIST":
            dataset_cls = datasets.KMNIST
        else:
            raise ValueError(
                f"Unsupported mnist_type {mnist_type!r}; expected 'MNIST', "
                "'FashionMNIST' or 'KMNIST'"
            )
        self.dataset = dataset_cls("../../data", train=train, download=True)
        self.data = self.dataset.data
        self.targets = self.dataset.targets
        self.flatten = flatten

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the two half-image views and the label of sample ``idx``.

        :param idx: index of the sample
        :return: ``((x_a, x_b), label)`` where ``x_a`` holds the first 392
            pixels and ``x_b`` the rest, both divided by 255
        """
        x = self.data[idx].flatten()
        x_a = x[:392] / 255
        x_b = x[392:] / 255
        label = self.targets[idx]
        return (x_a, x_b), label

    def to_numpy(self, indices=None):
        """
        Converts dataset to numpy array form

        :param indices: indices of the samples to extract into numpy arrays;
            all samples are used when ``None``
        :return: ``((view_1, view_2), labels)`` with one row per requested
            index; the views are float64 arrays and the labels integers
        """
        if indices is None:
            indices = np.arange(len(self))
        idx = torch.as_tensor(np.asarray(indices), dtype=torch.long)
        # One vectorised gather instead of a Python loop over samples.
        # ``flatten(start_dim=1)`` (rather than ``reshape(n, -1)``) also works
        # when no indices are requested.
        pixels = self.data[idx].flatten(start_dim=1)
        # Divide in torch first, then widen to float64, so the values match
        # what ``__getitem__`` produces for the same samples.
        view_1 = (pixels[:, :392] / 255).numpy().astype(np.float64)
        view_2 = (pixels[:, 392:] / 255).numpy().astype(np.float64)
        labels = self.targets[idx].numpy().astype(int)
        return (view_1, view_2), labels
class Noisy_MNIST_Dataset(Dataset):
    """
    Class to generate paired noisy mnist data

    Each sample pairs a randomly rotated version of the indexed image with a
    noise-corrupted image drawn at random from the same class.
    """

    def __init__(
        self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
    ):
        """
        :param mnist_type: "MNIST", "FashionMNIST" or "KMNIST"
        :param train: whether this is train or test
        :param flatten: whether to flatten the data into array or use 2d images
        :raises ValueError: if ``mnist_type`` is not one of the supported names
        """
        # Validate first so an unsupported name fails loudly here instead of
        # surfacing later as a missing ``self.dataset`` attribute.
        if mnist_type == "MNIST":
            dataset_cls = datasets.MNIST
        elif mnist_type == "FashionMNIST":
            dataset_cls = datasets.FashionMNIST
        elif mnist_type == "KMNIST":
            dataset_cls = datasets.KMNIST
        else:
            raise ValueError(
                f"Unsupported mnist_type {mnist_type!r}; expected 'MNIST', "
                "'FashionMNIST' or 'KMNIST'"
            )
        loading_steps = [torchvision.transforms.ToTensor()]
        if mnist_type == "KMNIST":
            # Only the KMNIST variant is normalised on load; kept exactly as
            # before so the data produced for existing callers is unchanged.
            loading_steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        self.dataset = dataset_cls(
            "../../data",
            train=train,
            download=True,
            transform=torchvision.transforms.Compose(loading_steps),
        )
        self.base_transform = transforms.ToTensor()
        # View "a": random rotation of the indexed image.
        self.a_transform = transforms.Compose(
            [torchvision.transforms.RandomRotation((-45, 45))]
        )
        # View "b": additive noise followed by capping values at 1.
        self.b_transform = transforms.Compose(
            [
                transforms.Lambda(_add_mnist_noise),
                transforms.Lambda(self.__threshold_func__),
            ]
        )
        self.targets = self.dataset.targets
        self.filtered_classes = []
        # filtered_nums[c] holds the indices of every sample labelled c.
        self.filtered_nums = [np.where(self.targets == i)[0] for i in range(10)]
        self.flatten = flatten

    def __threshold_func__(self, x):
        """
        Cap every value above 1 at 1.

        :param x: tensor to cap; it is modified in place
        :return: the same tensor ``x``
        """
        x[x > 1] = 1
        return x

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return a (noisy, rotated) pair of views and the label of sample ``idx``.

        :param idx: index of the sample
        :return: ``((x_b, x_a), label)`` where ``x_a`` is the rotated indexed
            image and ``x_b`` is a noisy image of a random same-class sample;
            both are flattened when ``self.flatten`` is true
        """
        x_a, label = self.dataset[idx]
        x_a = self.a_transform(x_a)
        # get random index of image with same class
        random_index = np.random.choice(self.filtered_nums[label])
        x_b = self.b_transform(self.dataset[random_index][0])
        if self.flatten:
            x_a = torch.flatten(x_a)
            x_b = torch.flatten(x_b)
        return (x_b, x_a), label
class Tangled_MNIST_Dataset(Dataset):
    """
    Class to generate paired tangled MNIST dataset

    Each sample pairs a randomly rotated version of the indexed image with a
    randomly rotated image drawn at random from the same class.
    """

    def __init__(
        self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
    ):
        """
        :param mnist_type: "MNIST", "FashionMNIST" or "KMNIST"
        :param train: whether this is train or test
        :param flatten: whether to flatten the data into array or use 2d images
        :raises ValueError: if ``mnist_type`` is not one of the supported names
        """
        # Validate first so an unsupported name fails loudly here instead of
        # surfacing later as a missing ``self.dataset`` attribute.
        if mnist_type == "MNIST":
            dataset_cls = datasets.MNIST
        elif mnist_type == "FashionMNIST":
            dataset_cls = datasets.FashionMNIST
        elif mnist_type == "KMNIST":
            dataset_cls = datasets.KMNIST
        else:
            raise ValueError(
                f"Unsupported mnist_type {mnist_type!r}; expected 'MNIST', "
                "'FashionMNIST' or 'KMNIST'"
            )
        loading_steps = [torchvision.transforms.ToTensor()]
        if mnist_type == "KMNIST":
            # Only the KMNIST variant is normalised on load; kept exactly as
            # before so the data produced for existing callers is unchanged.
            loading_steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        self.dataset = dataset_cls(
            "../../data",
            train=train,
            download=True,
            transform=torchvision.transforms.Compose(loading_steps),
        )
        # The same random rotation is applied (independently) to both views.
        self.transform = transforms.Compose(
            [torchvision.transforms.RandomRotation((-45, 45))]
        )
        self.targets = self.dataset.targets
        self.filtered_classes = []
        # filtered_nums[c] holds the indices of every sample labelled c.
        self.filtered_nums = [np.where(self.targets == i)[0] for i in range(10)]
        self.flatten = flatten

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return a pair of rotated same-class views and the label of sample ``idx``.

        :param idx: index of the sample
        :return: ``((x_b, x_a), label)`` where ``x_a`` is the rotated indexed
            image and ``x_b`` is a rotated image of a random same-class sample;
            both are flattened when ``self.flatten`` is true
        """
        x_a, label = self.dataset[idx]
        x_a = self.transform(x_a)
        # get random index of image with same class
        random_index = np.random.choice(self.filtered_nums[label])
        x_b = self.transform(self.dataset[random_index][0])
        if self.flatten:
            x_a = torch.flatten(x_a)
            x_b = torch.flatten(x_b)
        return (x_b, x_a), label
def _add_mnist_noise(x):
    """
    Add uniform random noise in [0, 0.1) to an image.

    A single noise map is drawn with the size of the last two dimensions of
    ``x`` and broadcast over any leading dimensions, so images that are not
    28x28 are supported as well.

    :param x: image tensor to corrupt; it is not modified
    :return: a new tensor equal to ``x`` plus the noise
    """
    # Size the noise from the input instead of hard-coding 28x28; for 28x28
    # inputs this draws exactly the same random values as before.
    noise = torch.rand(x.shape[-2:]) / 10
    return x + noise