diff --git a/examples/hwf/datasets/get_hwf.py b/examples/hwf/datasets/get_hwf.py index ef344bd..da63f70 100644 --- a/examples/hwf/datasets/get_hwf.py +++ b/examples/hwf/datasets/get_hwf.py @@ -1,10 +1,10 @@ -import os import json +import os.path as osp from PIL import Image from torchvision.transforms import transforms -CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) +CURRENT_DIR = osp.abspath(osp.dirname(__file__)) img_transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] @@ -15,7 +15,7 @@ def get_data(file, get_pseudo_label): X, Y = [], [] if get_pseudo_label: Z = [] - img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") + img_dir = osp.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") with open(file) as f: data = json.load(f) for idx in range(len(data)): @@ -40,8 +40,8 @@ def get_data(file, get_pseudo_label): def get_hwf(train=True, get_gt_pseudo_label=False): if train: - file = os.path.join(CURRENT_DIR, "data/expr_train.json") + file = osp.join(CURRENT_DIR, "data/expr_train.json") else: - file = os.path.join(CURRENT_DIR, "data/expr_test.json") + file = osp.join(CURRENT_DIR, "data/expr_test.json") return get_data(file, get_gt_pseudo_label) diff --git a/examples/mnist_add/datasets/get_mnist_add.py b/examples/mnist_add/datasets/get_mnist_add.py index 21b9101..132c4b5 100644 --- a/examples/mnist_add/datasets/get_mnist_add.py +++ b/examples/mnist_add/datasets/get_mnist_add.py @@ -1,40 +1,49 @@ +import os.path as osp + import torchvision from torchvision.transforms import transforms +CURRENT_DIR = osp.abspath(osp.dirname(__file__)) + def get_data(file, img_dataset, get_pseudo_label): - X = [] + X, Y = [], [] if get_pseudo_label: Z = [] - Y = [] with open(file) as f: for line in f: - line = line.strip().split(' ') + # if len(X) == 1000: + # break + line = line.strip().split(" ") X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) if get_pseudo_label: Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) Y.append(int(line[2])) - + if get_pseudo_label: return X, Z, Y else: return X, None, Y -def get_mnist_add(train = True, get_pseudo_label = False): - transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform) - + +def get_mnist_add(train=True, get_pseudo_label=False): + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + img_dataset = torchvision.datasets.MNIST( + root=CURRENT_DIR, train=train, download=True, transform=transform + ) + if train: - file = './datasets/train_data.txt' + file = osp.join(CURRENT_DIR, "train_data.txt") else: - file = './datasets/test_data.txt' - + file = osp.join(CURRENT_DIR, "test_data.txt") + return get_data(file, img_dataset, get_pseudo_label) - + if __name__ == "__main__": - train_X, train_Y = get_mnist_add(train = True) - test_X, test_Y = get_mnist_add(train = False) + train_X, train_Z, train_Y = get_mnist_add(train=True) + test_X, test_Z, test_Y = get_mnist_add(train=False) print(len(train_X), len(test_X)) print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) -