From 778ae172dfe1972a8d5fa75e149aff0ea2d530e0 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:27:15 +0800 Subject: [PATCH] Update get_mnist_add.py --- datasets/mnist_add/get_mnist_add.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 1af834a..871e503 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -23,9 +23,9 @@ def get_data(file, img_dataset, get_pseudo_label): def get_mnist_add(train = True, get_pseudo_label = False): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform) + img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) - if(train): + if train: file = './datasets/mnist_add/train_data.txt' else: file = './datasets/mnist_add/test_data.txt'