From 8b5a2ce2a53cbb2bda24817266ded05434927f7d Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 16 Nov 2022 13:07:03 +0800 Subject: [PATCH] Update get_mnist_add.py --- datasets/mnist_add/get_mnist_add.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 167ee57..0c9a273 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -5,20 +5,28 @@ from torchvision.transforms import transforms def get_mnist_add(): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - train_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform) - test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./', train=False, transform=transform), batch_size=1000, shuffle=True) + img_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform) - X = [] - Y = [] + train_X = [] + train_Y = [] with open('./train_data.txt') as f: for line in f: line = line.strip().split(' ') - X.append((train_dataset[int(line[0])][0], train_dataset[int(line[1])][0])) - Y.append(int(line[2])) + train_X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0])) + train_Y.append(int(line[2])) - return X, Y, test_loader + test_X = [] + test_Y = [] + with open('./test_data.txt') as f: + for line in f: + line = line.strip().split(' ') + test_X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0])) + test_Y.append(int(line[2])) + + return train_X, train_Y, test_X, test_Y if __name__ == "__main__": - X, Y, test_loader = get_mnist_add() - print(len(X), len(Y)) - print(X[0][0].shape, X[0][1].shape, Y[0]) + train_X, train_Y, test_X, test_Y = get_mnist_add() + print(len(train_X), len(test_X)) + print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) +