diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 54380ec..284ebe4 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -11,6 +11,7 @@ #================================================================# import abc +# from kb import add_KB, hwf_KB from abducer.kb import add_KB, hwf_KB import numpy as np diff --git a/abducer/kb.py b/abducer/kb.py index 8c3f49a..e987d15 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -34,10 +34,6 @@ class KBBase(ABC): def logic_forward(self): pass - @abstractmethod - def valid_candidate(self): - pass - def _length(self, length): if length is None: length = list(self.base.keys()) @@ -57,6 +53,7 @@ class ClsKB(KBBase): self.len_list = len_list if GKB_flag: + # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() self.base = {} X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) for x, y in zip(X, Y): @@ -70,14 +67,12 @@ class ClsKB(KBBase): X = [] Y = [] for x in all_X: - if self.valid_candidate(x): + y = self.logic_forward(x) + if y != np.inf: X.append(x) - Y.append(self.logic_forward(x)) + Y.append(y) return X, Y - def valid_candidate(self): - pass - def logic_forward(self): pass @@ -88,7 +83,7 @@ class ClsKB(KBBase): if key is None: return self.get_all_candidates() - if (type(length) is int and length not in self.len_list): + if type(length) is int and length not in self.len_list: return [] length = self._length(length) return sum([self.base[l][key] for l in length], []) @@ -117,9 +112,6 @@ class add_KB(ClsKB): pseudo_label_list = list(range(10)), \ len_list = [2]): super().__init__(GKB_flag, pseudo_label_list, len_list) - - def valid_candidate(self, x): - return True def logic_forward(self, nums): return sum(nums) diff --git a/datasets/hwf/get_hwf.py b/datasets/hwf/get_hwf.py index f94f1a5..b478a15 100644 --- a/datasets/hwf/get_hwf.py +++ b/datasets/hwf/get_hwf.py @@ -9,7 +9,7 @@ img_transform = transforms.Compose([ def get_data(file, get_pseudo_label, precision_num = 2): X = [] - if(get_pseudo_label): + if get_pseudo_label: Z = [] Y = [] img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/' @@ -22,15 +22,14 @@ def get_data(file, get_pseudo_label, precision_num = 2): img = Image.open(img_dir + img_path).convert('L') img = img_transform(img) imgs.append(img) - if(get_pseudo_label): + if get_pseudo_label: imgs_pseudo_label.append(img_path.split('/')[0]) - if(len(imgs) == 3): - X.append(imgs) - if(get_pseudo_label): - Z.append(imgs_pseudo_label) - Y.append(round(data[idx]['res'], precision_num)) + X.append(imgs) + if get_pseudo_label: + Z.append(imgs_pseudo_label) + Y.append(round(data[idx]['res'], precision_num)) - if(get_pseudo_label): + if get_pseudo_label: return X, Z, Y else: return X, None, Y diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 1af834a..d00c6ce 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -5,27 +5,27 @@ from torchvision.transforms import transforms def get_data(file, img_dataset, get_pseudo_label): X = [] - if(get_pseudo_label): + if get_pseudo_label: Z = [] Y = [] with open(file) as f: for line in f: line = line.strip().split(' ') X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) - if(get_pseudo_label): + if get_pseudo_label: Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) Y.append(int(line[2])) - if(get_pseudo_label): + if get_pseudo_label: return X, Z, Y else: return X, None, Y def get_mnist_add(train = True, get_pseudo_label = False): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform) + img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) - if(train): + if train: file = './datasets/mnist_add/train_data.txt' else: file = './datasets/mnist_add/test_data.txt' diff --git a/models/lenet5.py b/models/lenet5.py index 56c1ca6..1016d2c 100644 --- a/models/lenet5.py +++ b/models/lenet5.py @@ -52,48 +52,32 @@ class LeNet5(nn.Module): return x def num_flat_features(self, x): - #x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_size - size = x.size()[1:] #x.size返回的是一个元组,size表示截取元组中第二个开始的数字 + size = x.size()[1:] num_features = 1 for s in size: num_features *= s return num_features -class Params: - imgH = 28 - imgW = 28 - keep_ratio = True - saveInterval = 10 - batchSize = 16 - num_workers = 16 -def get_data(): #数据预处理 - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5), (0.5))]) - #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - #训练集 - train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) - train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16) - #测试集 - test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16) - classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck') - - return train_loader, test_loader, classes - -if __name__ == "__main__": - recorder = plog.ResultRecorder() - cls = LeNet5() - criterion = nn.CrossEntropyLoss(size_average=True) - optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder) - - train_loader, test_loader, classes = get_data() - - #model.val(test_loader, print_prefix = "before training") - model.fit(train_loader, n_epoch = 100) - model.val(test_loader, print_prefix = "after trained") - res = model.predict(test_loader, print_prefix = "predict") - print(res.argmax(axis=1)[:10]) +class SymbolNet(nn.Module): + def __init__(self, num_classes=14): + super(SymbolNet, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(30976, 128) + self.fc2 = nn.Linear(128, num_classes) + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + return x