Browse Source

Merge branch 'Dev' of https://github.com/AbductiveLearning/ABL-Package into Dev

pull/3/head
Gao Enhao 3 years ago
parent
commit
6ea5d8f9ec
5 changed files with 40 additions and 64 deletions
  1. +1
    -0
      abducer/abducer_base.py
  2. +5
    -13
      abducer/kb.py
  3. +7
    -8
      datasets/hwf/get_hwf.py
  4. +5
    -5
      datasets/mnist_add/get_mnist_add.py
  5. +22
    -38
      models/lenet5.py

+ 1
- 0
abducer/abducer_base.py View File

@@ -11,6 +11,7 @@
#================================================================#

import abc
# from kb import add_KB, hwf_KB
from abducer.kb import add_KB, hwf_KB
import numpy as np



+ 5
- 13
abducer/kb.py View File

@@ -34,10 +34,6 @@ class KBBase(ABC):
def logic_forward(self):
pass
@abstractmethod
def valid_candidate(self):
pass
def _length(self, length):
if length is None:
length = list(self.base.keys())
@@ -57,6 +53,7 @@ class ClsKB(KBBase):
self.len_list = len_list
if GKB_flag:
# self.base = np.load('abducer/hwf.npy', allow_pickle=True).item()
self.base = {}
X, Y = self.get_GKB(self.pseudo_label_list, self.len_list)
for x, y in zip(X, Y):
@@ -70,14 +67,12 @@ class ClsKB(KBBase):
X = []
Y = []
for x in all_X:
if self.valid_candidate(x):
y = self.logic_forward(x)
if y != np.inf:
X.append(x)
Y.append(self.logic_forward(x))
Y.append(y)
return X, Y
def valid_candidate(self):
pass
def logic_forward(self):
pass

@@ -88,7 +83,7 @@ class ClsKB(KBBase):
if key is None:
return self.get_all_candidates()
if (type(length) is int and length not in self.len_list):
if type(length) is int and length not in self.len_list:
return []
length = self._length(length)
return sum([self.base[l][key] for l in length], [])
@@ -117,9 +112,6 @@ class add_KB(ClsKB):
pseudo_label_list = list(range(10)), \
len_list = [2]):
super().__init__(GKB_flag, pseudo_label_list, len_list)
def valid_candidate(self, x):
return True
def logic_forward(self, nums):
return sum(nums)


+ 7
- 8
datasets/hwf/get_hwf.py View File

@@ -9,7 +9,7 @@ img_transform = transforms.Compose([

def get_data(file, get_pseudo_label, precision_num = 2):
X = []
if(get_pseudo_label):
if get_pseudo_label:
Z = []
Y = []
img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/'
@@ -22,15 +22,14 @@ def get_data(file, get_pseudo_label, precision_num = 2):
img = Image.open(img_dir + img_path).convert('L')
img = img_transform(img)
imgs.append(img)
if(get_pseudo_label):
if get_pseudo_label:
imgs_pseudo_label.append(img_path.split('/')[0])
if(len(imgs) == 3):
X.append(imgs)
if(get_pseudo_label):
Z.append(imgs_pseudo_label)
Y.append(round(data[idx]['res'], precision_num))
X.append(imgs)
if get_pseudo_label:
Z.append(imgs_pseudo_label)
Y.append(round(data[idx]['res'], precision_num))
if(get_pseudo_label):
if get_pseudo_label:
return X, Z, Y
else:
return X, None, Y


+ 5
- 5
datasets/mnist_add/get_mnist_add.py View File

@@ -5,27 +5,27 @@ from torchvision.transforms import transforms

def get_data(file, img_dataset, get_pseudo_label):
    """Build paired-image samples from an index file.

    Each non-blank line of ``file`` holds whitespace-separated fields: two
    integer indices into ``img_dataset`` followed by an integer target.
    ``img_dataset[i]`` is indexed as a pair whose element 0 is stored in
    ``X`` and whose element 1 is stored in ``Z``.

    Args:
        file: Path of the text file to read.
        img_dataset: Indexable collection of two-element items.
        get_pseudo_label: When truthy, also collect element 1 of every
            referenced item.

    Returns:
        A tuple ``(X, Z, Y)`` where ``X`` is a list of two-element lists,
        ``Y`` is the list of integer targets, and ``Z`` is the list of
        two-element label lists, or ``None`` when ``get_pseudo_label`` is
        falsy.
    """
    X = []
    Y = []
    Z = [] if get_pseudo_label else None
    with open(file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank (e.g. trailing) line carries no sample.
                continue
            # Index each item once: indexing may run a transform pipeline.
            first = img_dataset[int(fields[0])]
            second = img_dataset[int(fields[1])]
            X.append([first[0], second[0]])
            if get_pseudo_label:
                Z.append([first[1], second[1]])
            Y.append(int(fields[2]))
    return X, Z, Y

def get_mnist_add(train = True, get_pseudo_label = False):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform)
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform)
if(train):
if train:
file = './datasets/mnist_add/train_data.txt'
else:
file = './datasets/mnist_add/test_data.txt'


+ 22
- 38
models/lenet5.py View File

@@ -52,48 +52,32 @@ class LeNet5(nn.Module):
return x

def num_flat_features(self, x):
    """Return the number of features per sample once ``x`` is flattened.

    All dimensions of ``x`` except the first are multiplied together.

    Args:
        x: Object whose ``size()`` returns a sequence of dimension sizes.

    Returns:
        The product of every dimension size after the first.
    """
    # x.size() returns a tuple such as (256, 16, 5, 5); the leading 256 is
    # the batch size, so only the remaining entries (16, 5, 5) are kept.
    size = x.size()[1:]
    num_features = 1
    for s in size:
        num_features *= s
    return num_features

class Params:
    """Plain container of configuration constants.

    An instance is passed to ``BasicModel`` by the script section of this
    file; how each value is consumed is not visible here.
    """

    # Image height and width; 28 matches MNIST digits -- TODO confirm usage.
    imgH = 28
    imgW = 28
    # NOTE(review): presumably whether resizing keeps the aspect ratio.
    keep_ratio = True
    # NOTE(review): presumably a checkpoint interval -- confirm the unit.
    saveInterval = 10
    batchSize = 16
    num_workers = 16

def get_data(batch_size=1024, num_workers=16, root='data/'):
    """Download MNIST and build the train and test data loaders.

    Images are converted to tensors and normalised with mean 0.5 and
    standard deviation 0.5 on their single channel.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.
        root: Directory where the dataset is stored or downloaded to.

    Returns:
        A tuple ``(train_loader, test_loader, classes)``. The train loader
        is shuffled and the test loader is not. ``classes`` is a tuple of
        the ten digit names ``'0'`` to ``'9'``.
    """
    # Data preprocessing. MNIST is single-channel, so mean and std are
    # one-element tuples.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Training set.
    train_set = torchvision.datasets.MNIST(
        root=root, train=True, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)

    # Test set.
    test_set = torchvision.datasets.MNIST(
        root=root, train=False, transform=transform, download=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)

    # MNIST labels are the digits 0-9, not the CIFAR-10 category names the
    # original tuple listed.
    classes = tuple(str(digit) for digit in range(10))

    return train_loader, test_loader, classes

if __name__ == "__main__":
    # Script entry point: fit LeNet5 on the loaders from get_data(), then
    # validate and predict on the test loader.
    recorder = plog.ResultRecorder()
    cls = LeNet5()
    # reduction='mean' replaces the deprecated size_average=True argument
    # and averages the loss over the batch in the same way.
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder)
    train_loader, test_loader, classes = get_data()

    #model.val(test_loader, print_prefix = "before training")
    model.fit(train_loader, n_epoch = 100)
    model.val(test_loader, print_prefix = "after trained")
    res = model.predict(test_loader, print_prefix = "predict")
    # Show the predicted class index of the first ten test samples.
    print(res.argmax(axis=1)[:10])
class SymbolNet(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two 3x3 convolutions (1 -> 32 -> 64 channels, stride 1, padding 1) are
    followed by a 2x2 max-pool, channel-wise dropout, and two fully
    connected layers (flattened features -> 128 -> ``num_classes``).

    Args:
        num_classes: Number of output logits.
        image_size: ``(height, width)`` of the input images. The default
            ``(45, 45)`` gives the 30976 flattened features that were
            previously hard-coded.
    """

    def __init__(self, num_classes=14, image_size=(45, 45)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.dropout1 = nn.Dropout2d(0.25)
        # This layer runs on the 2-D (batch, features) output of fc1.
        # Dropout2d is meant for inputs with spatial dimensions and warns on
        # 2-D input, so plain element-wise Dropout is used instead.
        self.dropout2 = nn.Dropout(0.5)
        # The padded 3x3 convolutions keep height and width; the 2x2
        # max-pool halves each of them, rounding down.
        height, width = image_size
        flat_features = 64 * (height // 2) * (width // 2)
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the ``num_classes`` logits for each sample of ``x``.

        Args:
            x: Tensor of shape ``(batch, 1, height, width)`` matching the
                ``image_size`` given at construction.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
        """
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension and flatten everything else.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return self.fc2(x)

Loading…
Cancel
Save