From bc5fd2342f86ed24c3e944f66ea9c64ec358a168 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 15 Nov 2022 14:47:24 +0800 Subject: [PATCH] Update get_mnist_add.py --- datasets/mnist_add/get_mnist_add.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index d05f477..167ee57 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -3,35 +3,22 @@ import torchvision from torch.utils.data import Dataset from torchvision.transforms import transforms -class MNIST_Addition(Dataset): - def __init__(self, dataset, examples): - self.data = list() - self.dataset = dataset - with open(examples) as f: - for line in f: - line = line.strip().split(' ') - self.data.append(tuple([int(i) for i in line])) - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - i1, i2, l = self.data[index] - return self.dataset[i1][0], self.dataset[i2][0], l - def get_mnist_add(): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - train_dataset = MNIST_Addition(torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform), './train_data.txt') + train_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform) test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./', train=False, transform=transform), batch_size=1000, shuffle=True) + X = [] Y = [] - for i1, i2, l in train_dataset: - X.append([i1, i2]) - Y.append(l) + with open('./train_data.txt') as f: + for line in f: + line = line.strip().split(' ') + X.append((train_dataset[int(line[0])][0], train_dataset[int(line[1])][0])) + Y.append(int(line[2])) + return X, Y, test_loader if __name__ == "__main__": X, Y, test_loader = get_mnist_add() print(len(X), len(Y)) print(X[0][0].shape, X[0][1].shape, Y[0]) - \ No newline at end of file