| @@ -11,6 +11,7 @@ | |||
| #================================================================# | |||
| import abc | |||
| # from kb import add_KB, hwf_KB | |||
| from abducer.kb import add_KB, hwf_KB | |||
| import numpy as np | |||
| @@ -34,10 +34,6 @@ class KBBase(ABC): | |||
| def logic_forward(self): | |||
| pass | |||
| @abstractmethod | |||
| def valid_candidate(self): | |||
| pass | |||
| def _length(self, length): | |||
| if length is None: | |||
| length = list(self.base.keys()) | |||
| @@ -57,6 +53,7 @@ class ClsKB(KBBase): | |||
| self.len_list = len_list | |||
| if GKB_flag: | |||
| # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() | |||
| self.base = {} | |||
| X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) | |||
| for x, y in zip(X, Y): | |||
| @@ -70,14 +67,12 @@ class ClsKB(KBBase): | |||
| X = [] | |||
| Y = [] | |||
| for x in all_X: | |||
| if self.valid_candidate(x): | |||
| y = self.logic_forward(x) | |||
| if y != np.inf: | |||
| X.append(x) | |||
| Y.append(self.logic_forward(x)) | |||
| Y.append(y) | |||
| return X, Y | |||
| def valid_candidate(self): | |||
| pass | |||
| def logic_forward(self): | |||
| pass | |||
| @@ -88,7 +83,7 @@ class ClsKB(KBBase): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| if (type(length) is int and length not in self.len_list): | |||
| if type(length) is int and length not in self.len_list: | |||
| return [] | |||
| length = self._length(length) | |||
| return sum([self.base[l][key] for l in length], []) | |||
| @@ -117,9 +112,6 @@ class add_KB(ClsKB): | |||
| pseudo_label_list = list(range(10)), \ | |||
| len_list = [2]): | |||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||
| def valid_candidate(self, x): | |||
| return True | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| @@ -9,7 +9,7 @@ img_transform = transforms.Compose([ | |||
| def get_data(file, get_pseudo_label, precision_num = 2): | |||
| X = [] | |||
| if(get_pseudo_label): | |||
| if get_pseudo_label: | |||
| Z = [] | |||
| Y = [] | |||
| img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/' | |||
| @@ -22,15 +22,14 @@ def get_data(file, get_pseudo_label, precision_num = 2): | |||
| img = Image.open(img_dir + img_path).convert('L') | |||
| img = img_transform(img) | |||
| imgs.append(img) | |||
| if(get_pseudo_label): | |||
| if get_pseudo_label: | |||
| imgs_pseudo_label.append(img_path.split('/')[0]) | |||
| if(len(imgs) == 3): | |||
| X.append(imgs) | |||
| if(get_pseudo_label): | |||
| Z.append(imgs_pseudo_label) | |||
| Y.append(round(data[idx]['res'], precision_num)) | |||
| X.append(imgs) | |||
| if get_pseudo_label: | |||
| Z.append(imgs_pseudo_label) | |||
| Y.append(round(data[idx]['res'], precision_num)) | |||
| if(get_pseudo_label): | |||
| if get_pseudo_label: | |||
| return X, Z, Y | |||
| else: | |||
| return X, None, Y | |||
| @@ -5,27 +5,27 @@ from torchvision.transforms import transforms | |||
| def get_data(file, img_dataset, get_pseudo_label): | |||
| X = [] | |||
| if(get_pseudo_label): | |||
| if get_pseudo_label: | |||
| Z = [] | |||
| Y = [] | |||
| with open(file) as f: | |||
| for line in f: | |||
| line = line.strip().split(' ') | |||
| X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) | |||
| if(get_pseudo_label): | |||
| if get_pseudo_label: | |||
| Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) | |||
| Y.append(int(line[2])) | |||
| if(get_pseudo_label): | |||
| if get_pseudo_label: | |||
| return X, Z, Y | |||
| else: | |||
| return X, None, Y | |||
| def get_mnist_add(train = True, get_pseudo_label = False): | |||
| transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) | |||
| img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform) | |||
| img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) | |||
| if(train): | |||
| if train: | |||
| file = './datasets/mnist_add/train_data.txt' | |||
| else: | |||
| file = './datasets/mnist_add/test_data.txt' | |||
| @@ -52,48 +52,32 @@ class LeNet5(nn.Module): | |||
| return x | |||
| def num_flat_features(self, x): | |||
| #x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_size | |||
| size = x.size()[1:] #x.size返回的是一个元组,size表示截取元组中第二个开始的数字 | |||
| size = x.size()[1:] | |||
| num_features = 1 | |||
| for s in size: | |||
| num_features *= s | |||
| return num_features | |||
| class Params: | |||
| imgH = 28 | |||
| imgW = 28 | |||
| keep_ratio = True | |||
| saveInterval = 10 | |||
| batchSize = 16 | |||
| num_workers = 16 | |||
| def get_data(): #数据预处理 | |||
| transform = transforms.Compose([transforms.ToTensor(), | |||
| transforms.Normalize((0.5), (0.5))]) | |||
| #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | |||
| #训练集 | |||
| train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) | |||
| train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16) | |||
| #测试集 | |||
| test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True) | |||
| test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16) | |||
| classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck') | |||
| return train_loader, test_loader, classes | |||
| if __name__ == "__main__": | |||
| recorder = plog.ResultRecorder() | |||
| cls = LeNet5() | |||
| criterion = nn.CrossEntropyLoss(size_average=True) | |||
| optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder) | |||
| train_loader, test_loader, classes = get_data() | |||
| #model.val(test_loader, print_prefix = "before training") | |||
| model.fit(train_loader, n_epoch = 100) | |||
| model.val(test_loader, print_prefix = "after trained") | |||
| res = model.predict(test_loader, print_prefix = "predict") | |||
| print(res.argmax(axis=1)[:10]) | |||
| class SymbolNet(nn.Module): | |||
| def __init__(self, num_classes=14): | |||
| super(SymbolNet, self).__init__() | |||
| self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) | |||
| self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) | |||
| self.dropout1 = nn.Dropout2d(0.25) | |||
| self.dropout2 = nn.Dropout2d(0.5) | |||
| self.fc1 = nn.Linear(30976, 128) | |||
| self.fc2 = nn.Linear(128, num_classes) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = F.relu(x) | |||
| x = self.conv2(x) | |||
| x = F.max_pool2d(x, 2) | |||
| x = self.dropout1(x) | |||
| x = torch.flatten(x, 1) | |||
| x = self.fc1(x) | |||
| x = F.relu(x) | |||
| x = self.dropout2(x) | |||
| x = self.fc2(x) | |||
| return x | |||