|
|
|
@@ -1,6 +1,4 @@ |
|
|
|
import torch |
|
|
|
import torchvision |
|
|
|
from torch.utils.data import Dataset |
|
|
|
from torchvision.transforms import transforms |
|
|
|
|
|
|
|
def get_data(file, img_dataset, get_pseudo_label): |
|
|
|
@@ -23,12 +21,12 @@ def get_data(file, img_dataset, get_pseudo_label): |
|
|
|
|
|
|
|
def get_mnist_add(train = True, get_pseudo_label = False): |
|
|
|
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) |
|
|
|
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) |
|
|
|
img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform) |
|
|
|
|
|
|
|
if train: |
|
|
|
file = './datasets/mnist_add/train_data.txt' |
|
|
|
file = './datasets/train_data.txt' |
|
|
|
else: |
|
|
|
file = './datasets/mnist_add/test_data.txt' |
|
|
|
file = './datasets/test_data.txt' |
|
|
|
|
|
|
|
return get_data(file, img_dataset, get_pseudo_label) |
|
|
|
|
|
|
|
|