Browse Source

Update get_mnist_add.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
778ae172df
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      datasets/mnist_add/get_mnist_add.py

+ 2
- 2
datasets/mnist_add/get_mnist_add.py View File

@@ -23,9 +23,9 @@ def get_data(file, img_dataset, get_pseudo_label):

def get_mnist_add(train = True, get_pseudo_label = False):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform)
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform)
if(train):
if train:
file = './datasets/mnist_add/train_data.txt'
else:
file = './datasets/mnist_add/test_data.txt'


Loading…
Cancel
Save