Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
4a1de111f9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 7 deletions
  1. +9
    -7
      abducer/kb.py

+ 9
- 7
abducer/kb.py View File

@@ -60,7 +60,7 @@ class ClsKB(KBBase):
if GKB_flag:
X, Y = self.get_GKB(self.pseudo_label_list, self.len_list)
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
def get_GKB(self, pseudo_label_list, len_list):
all_X = []
@@ -104,9 +104,10 @@ class ClsKB(KBBase):


class add_KB(ClsKB):
def __init__(self, GKB_flag = False, len_list = [2]):
self.pseudo_label_list = list(range(10))
super().__init__(GKB_flag, self.pseudo_label_list, len_list)
def __init__(self, GKB_flag = False, \
pseudo_label_list = list(range(10)), \
len_list = [2]):
super().__init__(GKB_flag, pseudo_label_list, len_list)
def valid_candidate(self, x):
return True
@@ -116,9 +117,10 @@ class add_KB(ClsKB):
class hwf_KB(ClsKB):
def __init__(self, GKB_flag = False, len_list = [1, 3, 5, 7]):
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/']
super().__init__(GKB_flag, self.pseudo_label_list, len_list)
def __init__(self, GKB_flag = False, \
pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'], \
len_list = [1, 3, 5, 7]):
super().__init__(GKB_flag, pseudo_label_list, len_list)
def valid_candidate(self, formula):
if(len(formula) % 2 == 0):


Loading…
Cancel
Save