Browse Source

Update get_mnist_add.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
8b5a2ce2a5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 10 deletions
  1. +18
    -10
      datasets/mnist_add/get_mnist_add.py

+ 18
- 10
datasets/mnist_add/get_mnist_add.py View File

@@ -5,20 +5,28 @@ from torchvision.transforms import transforms

def get_mnist_add():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
train_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./', train=False, transform=transform), batch_size=1000, shuffle=True)
img_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)
X = []
Y = []
train_X = []
train_Y = []
with open('./train_data.txt') as f:
for line in f:
line = line.strip().split(' ')
X.append((train_dataset[int(line[0])][0], train_dataset[int(line[1])][0]))
Y.append(int(line[2]))
train_X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]))
train_Y.append(int(line[2]))
return X, Y, test_loader
test_X = []
test_Y = []
with open('./test_data.txt') as f:
for line in f:
line = line.strip().split(' ')
test_X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]))
test_Y.append(int(line[2]))
return train_X, train_Y, test_X, test_Y

if __name__ == "__main__":
X, Y, test_loader = get_mnist_add()
print(len(X), len(Y))
print(X[0][0].shape, X[0][1].shape, Y[0])
train_X, train_Y, test_X, test_Y = get_mnist_add()
print(len(train_X), len(test_X))
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0])

Loading…
Cancel
Save