Browse Source

Update get_mnist_add.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
bc5fd2342f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 21 deletions
  1. +8
    -21
      datasets/mnist_add/get_mnist_add.py

+ 8
- 21
datasets/mnist_add/get_mnist_add.py View File

@@ -3,35 +3,22 @@ import torchvision
from torch.utils.data import Dataset
from torchvision.transforms import transforms

class MNIST_Addition(Dataset):
def __init__(self, dataset, examples):
self.data = list()
self.dataset = dataset
with open(examples) as f:
for line in f:
line = line.strip().split(' ')
self.data.append(tuple([int(i) for i in line]))

def __len__(self):
return len(self.data)

def __getitem__(self, index):
i1, i2, l = self.data[index]
return self.dataset[i1][0], self.dataset[i2][0], l

def get_mnist_add():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
train_dataset = MNIST_Addition(torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform), './train_data.txt')
train_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./', train=False, transform=transform), batch_size=1000, shuffle=True)
X = []
Y = []
for i1, i2, l in train_dataset:
X.append([i1, i2])
Y.append(l)
with open('./train_data.txt') as f:
for line in f:
line = line.strip().split(' ')
X.append((train_dataset[int(line[0])][0], train_dataset[int(line[1])][0]))
Y.append(int(line[2]))
return X, Y, test_loader

if __name__ == "__main__":
X, Y, test_loader = get_mnist_add()
print(len(X), len(Y))
print(X[0][0].shape, X[0][1].shape, Y[0])

Loading…
Cancel
Save