Browse Source

Update get_mnist_add.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
3616608aaf
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      datasets/mnist_add/get_mnist_add.py

+ 3
- 3
datasets/mnist_add/get_mnist_add.py View File

@@ -15,10 +15,10 @@ def get_data(file, img_dataset):

def get_mnist_add():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
img_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform)
train_X, train_Y = get_data('./train_data.txt', img_dataset)
test_X, test_Y = get_data('./test_data.txt', img_dataset)
train_X, train_Y = get_data('./datasets/mnist_add/train_data.txt', img_dataset)
test_X, test_Y = get_data('./datasets/mnist_add/test_data.txt', img_dataset)
return train_X, train_Y, test_X, test_Y



Loading…
Cancel
Save