From 4a1de111f96d8876f6bd43a69a7658da14cb9221 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 21 Nov 2022 15:00:58 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index c206f62..499d9bb 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -60,7 +60,7 @@ class ClsKB(KBBase): if GKB_flag: X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) + self.base.setdefault(len(x), defaultdict(list))[y].append(x) def get_GKB(self, pseudo_label_list, len_list): all_X = [] @@ -104,9 +104,10 @@ class ClsKB(KBBase): class add_KB(ClsKB): - def __init__(self, GKB_flag = False, len_list = [2]): - self.pseudo_label_list = list(range(10)) - super().__init__(GKB_flag, self.pseudo_label_list, len_list) + def __init__(self, GKB_flag = False, \ + pseudo_label_list = list(range(10)), \ + len_list = [2]): + super().__init__(GKB_flag, pseudo_label_list, len_list) def valid_candidate(self, x): return True @@ -116,9 +117,10 @@ class add_KB(ClsKB): class hwf_KB(ClsKB): - def __init__(self, GKB_flag = False, len_list = [1, 3, 5, 7]): - self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] - super().__init__(GKB_flag, self.pseudo_label_list, len_list) + def __init__(self, GKB_flag = False, \ + pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'], \ + len_list = [1, 3, 5, 7]): + super().__init__(GKB_flag, pseudo_label_list, len_list) def valid_candidate(self, formula): if(len(formula) % 2 == 0):