Browse Source

[MNT] add CURRENT_DIR

ab_data
Gao Enhao 2 years ago
parent
commit
7e79dccd6e
2 changed files with 29 additions and 20 deletions
  1. +5
    -5
      examples/hwf/datasets/get_hwf.py
  2. +24
    -15
      examples/mnist_add/datasets/get_mnist_add.py

+ 5
- 5
examples/hwf/datasets/get_hwf.py View File

@@ -1,10 +1,10 @@
import os
import json import json
import os.path as osp


from PIL import Image from PIL import Image
from torchvision.transforms import transforms from torchvision.transforms import transforms


CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
CURRENT_DIR = osp.abspath(osp.dirname(__file__))


img_transform = transforms.Compose( img_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] [transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]
@@ -15,7 +15,7 @@ def get_data(file, get_pseudo_label):
X, Y = [], [] X, Y = [], []
if get_pseudo_label: if get_pseudo_label:
Z = [] Z = []
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
img_dir = osp.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
with open(file) as f: with open(file) as f:
data = json.load(f) data = json.load(f)
for idx in range(len(data)): for idx in range(len(data)):
@@ -40,8 +40,8 @@ def get_data(file, get_pseudo_label):


def get_hwf(train=True, get_gt_pseudo_label=False): def get_hwf(train=True, get_gt_pseudo_label=False):
if train: if train:
file = os.path.join(CURRENT_DIR, "data/expr_train.json")
file = osp.join(CURRENT_DIR, "data/expr_train.json")
else: else:
file = os.path.join(CURRENT_DIR, "data/expr_test.json")
file = osp.join(CURRENT_DIR, "data/expr_test.json")


return get_data(file, get_gt_pseudo_label) return get_data(file, get_gt_pseudo_label)

+ 24
- 15
examples/mnist_add/datasets/get_mnist_add.py View File

@@ -1,40 +1,49 @@
import os.path as osp

import torchvision import torchvision
from torchvision.transforms import transforms from torchvision.transforms import transforms


CURRENT_DIR = osp.abspath(osp.dirname(__file__))



def get_data(file, img_dataset, get_pseudo_label): def get_data(file, img_dataset, get_pseudo_label):
X = []
X, Y = [], []
if get_pseudo_label: if get_pseudo_label:
Z = [] Z = []
Y = []
with open(file) as f: with open(file) as f:
for line in f: for line in f:
line = line.strip().split(' ')
# if len(X) == 1000:
# break
line = line.strip().split(" ")
X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]])
if get_pseudo_label: if get_pseudo_label:
Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]])
Y.append(int(line[2])) Y.append(int(line[2]))
if get_pseudo_label: if get_pseudo_label:
return X, Z, Y return X, Z, Y
else: else:
return X, None, Y return X, None, Y


def get_mnist_add(train = True, get_pseudo_label = False):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform)

def get_mnist_add(train=True, get_pseudo_label=False):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
img_dataset = torchvision.datasets.MNIST(
root=CURRENT_DIR, train=train, download=True, transform=transform
)

if train: if train:
file = './datasets/train_data.txt'
file = osp.join(CURRENT_DIR, "train_data.txt")
else: else:
file = './datasets/test_data.txt'
file = osp.join(CURRENT_DIR, "test_data.txt")
return get_data(file, img_dataset, get_pseudo_label) return get_data(file, img_dataset, get_pseudo_label)


if __name__ == "__main__": if __name__ == "__main__":
train_X, train_Y = get_mnist_add(train = True)
test_X, test_Y = get_mnist_add(train = False)
train_X, train_Z, train_Y = get_mnist_add(train=True)
test_X, test_Z, test_Y = get_mnist_add(train=False)
print(len(train_X), len(test_X)) print(len(train_X), len(test_X))
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0])

Loading…
Cancel
Save