From f9ea562a3f86fa39493f4be31c50debdac7e8c21 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:26:39 +0800 Subject: [PATCH 1/8] Update abducer_base.py --- abducer/abducer_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 54380ec..284ebe4 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -11,6 +11,7 @@ #================================================================# import abc +# from kb import add_KB, hwf_KB from abducer.kb import add_KB, hwf_KB import numpy as np From 778ae172dfe1972a8d5fa75e149aff0ea2d530e0 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:27:15 +0800 Subject: [PATCH 2/8] Update get_mnist_add.py --- datasets/mnist_add/get_mnist_add.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 1af834a..871e503 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -23,9 +23,9 @@ def get_data(file, img_dataset, get_pseudo_label): def get_mnist_add(train = True, get_pseudo_label = False): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform) + img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) - if(train): + if train: file = './datasets/mnist_add/train_data.txt' else: file = './datasets/mnist_add/test_data.txt' From af00136661a327c031af7080e23615dceaa533a1 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:27:49 +0800 Subject: [PATCH 3/8] Update get_hwf.py --- datasets/hwf/get_hwf.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datasets/hwf/get_hwf.py b/datasets/hwf/get_hwf.py index f94f1a5..543e0d5 100644 --- a/datasets/hwf/get_hwf.py +++ b/datasets/hwf/get_hwf.py @@ -24,11 +24,10 @@ def get_data(file, get_pseudo_label, precision_num = 2): imgs.append(img) if(get_pseudo_label): imgs_pseudo_label.append(img_path.split('/')[0]) - if(len(imgs) == 3): - X.append(imgs) - if(get_pseudo_label): - Z.append(imgs_pseudo_label) - Y.append(round(data[idx]['res'], precision_num)) + X.append(imgs) + if(get_pseudo_label): + Z.append(imgs_pseudo_label) + Y.append(round(data[idx]['res'], precision_num)) if(get_pseudo_label): return X, Z, Y From 740df944c1c09aca25aa50997c453601e6e63070 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:28:30 +0800 Subject: [PATCH 4/8] Update get_hwf.py --- datasets/hwf/get_hwf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datasets/hwf/get_hwf.py b/datasets/hwf/get_hwf.py index 543e0d5..084ed16 100644 --- a/datasets/hwf/get_hwf.py +++ b/datasets/hwf/get_hwf.py @@ -22,14 +22,14 @@ def get_data(file, get_pseudo_label, precision_num = 2): img = Image.open(img_dir + img_path).convert('L') img = img_transform(img) imgs.append(img) - if(get_pseudo_label): + if get_pseudo_label: imgs_pseudo_label.append(img_path.split('/')[0]) X.append(imgs) - if(get_pseudo_label): + if get_pseudo_label: Z.append(imgs_pseudo_label) Y.append(round(data[idx]['res'], precision_num)) - if(get_pseudo_label): + if get_pseudo_label: return X, Z, Y else: return X, None, Y From 1258d83edc3538376474b31781284cf585ee3c14 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:29:05 +0800 Subject: [PATCH 5/8] Update get_hwf.py --- datasets/hwf/get_hwf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/hwf/get_hwf.py b/datasets/hwf/get_hwf.py index 084ed16..b478a15 100644 --- a/datasets/hwf/get_hwf.py +++ b/datasets/hwf/get_hwf.py @@ -9,7 +9,7 @@ img_transform = transforms.Compose([ def get_data(file, get_pseudo_label, precision_num = 2): X = [] - if(get_pseudo_label): + if get_pseudo_label: Z = [] Y = [] img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/' From 44b140c1fd7c5d158c0d818f5943083c688369e9 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:29:29 +0800 Subject: [PATCH 6/8] Update get_mnist_add.py --- datasets/mnist_add/get_mnist_add.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 871e503..d00c6ce 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -5,18 +5,18 @@ from torchvision.transforms import transforms def get_data(file, img_dataset, get_pseudo_label): X = [] - if(get_pseudo_label): + if get_pseudo_label: Z = [] Y = [] with open(file) as f: for line in f: line = line.strip().split(' ') X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) - if(get_pseudo_label): + if get_pseudo_label: Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) Y.append(int(line[2])) - if(get_pseudo_label): + if get_pseudo_label: return X, Z, Y else: return X, None, Y From ef0298e91a44c1a9502eb50407b2c9d8987f805f Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 12:34:37 +0800 Subject: [PATCH 7/8] Update kb.py --- abducer/kb.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 8c3f49a..e987d15 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -34,10 +34,6 @@ class KBBase(ABC): def logic_forward(self): pass - @abstractmethod - def valid_candidate(self): - pass - def _length(self, length): if length is None: length = list(self.base.keys()) @@ -57,6 +53,7 @@ class ClsKB(KBBase): self.len_list = len_list if GKB_flag: + # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() self.base = {} X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) for x, y in zip(X, Y): @@ -70,14 +67,12 @@ class ClsKB(KBBase): X = [] Y = [] for x in all_X: - if self.valid_candidate(x): + y = self.logic_forward(x) + if y != np.inf: X.append(x) - Y.append(self.logic_forward(x)) + Y.append(y) return X, Y - def valid_candidate(self): - pass - def logic_forward(self): pass @@ -88,7 +83,7 @@ class ClsKB(KBBase): if key is None: return self.get_all_candidates() - if (type(length) is int and length not in self.len_list): + if type(length) is int and length not in self.len_list: return [] length = self._length(length) return sum([self.base[l][key] for l in length], []) @@ -117,9 +112,6 @@ class add_KB(ClsKB): pseudo_label_list = list(range(10)), \ len_list = [2]): super().__init__(GKB_flag, pseudo_label_list, len_list) - - def valid_candidate(self, x): - return True def logic_forward(self, nums): return sum(nums) From f78c69ae105bf92fda477c13b88de5e76249f0d0 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 13:48:50 +0800 Subject: [PATCH 8/8] Create lenet5.py --- models/lenet5.py | 60 ++++++++++++++++++------------------------------ 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/models/lenet5.py b/models/lenet5.py index 56c1ca6..1016d2c 100644 --- a/models/lenet5.py +++ b/models/lenet5.py @@ -52,48 +52,32 @@ class LeNet5(nn.Module): return x def num_flat_features(self, x): - #x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_size - size = x.size()[1:] #x.size返回的是一个元组,size表示截取元组中第二个开始的数字 + size = x.size()[1:] num_features = 1 for s in size: num_features *= s return num_features -class Params: - imgH = 28 - imgW = 28 - keep_ratio = True - saveInterval = 10 - batchSize = 16 - num_workers = 16 -def get_data(): #数据预处理 - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5), (0.5))]) - #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - #训练集 - train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) - train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16) - #测试集 - test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16) - classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck') - - return train_loader, test_loader, classes - -if __name__ == "__main__": - recorder = plog.ResultRecorder() - cls = LeNet5() - criterion = nn.CrossEntropyLoss(size_average=True) - optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder) - - train_loader, test_loader, classes = get_data() - - #model.val(test_loader, print_prefix = "before training") - model.fit(train_loader, n_epoch = 100) - model.val(test_loader, print_prefix = "after trained") - res = model.predict(test_loader, print_prefix = "predict") - print(res.argmax(axis=1)[:10]) +class SymbolNet(nn.Module): + def __init__(self, num_classes=14): + super(SymbolNet, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(30976, 128) + self.fc2 = nn.Linear(128, num_classes) + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + return x