From 15eda317bcf53bf2ecb15ca6096baceed0de1593 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 21 Nov 2022 09:46:31 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index a934ad8..98ea9bc 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -19,7 +19,7 @@ from collections import defaultdict from itertools import product class KBBase(ABC): - def __init__(self): + def __init__(self, GKB_flag = False): pass @abstractmethod @@ -46,13 +46,13 @@ class KBBase(ABC): class ClsKB(KBBase): - def __init__(self, pseudo_label_list, len_list = None): + def __init__(self, GKB_flag = False, pseudo_label_list = None, len_list = None): super().__init__() self.pseudo_label_list = pseudo_label_list self.base = {} self.len_list = len_list - if(self.len_list != None): + if GKB_flag: X = self.get_X(self.pseudo_label_list, self.len_list) Y = self.get_Y(X, self.logic_forward) for x, y in zip(X, Y): @@ -94,18 +94,18 @@ class ClsKB(KBBase): class add_KB(ClsKB): - def __init__(self, len_list = [2]): + def __init__(self, GKB_flag = False, len_list = [2]): self.pseudo_label_list = list(range(10)) - super().__init__(self.pseudo_label_list, len_list) + super().__init__(GKB_flag, self.pseudo_label_list, len_list) def logic_forward(self, nums): return sum(nums) class hwf_KB(ClsKB): - def __init__(self, len_list = [1, 3, 5, 7]): + def __init__(self, GKB_flag = False, len_list = [1, 3, 5, 7]): self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] - super().__init__(self.pseudo_label_list, len_list) + super().__init__(GKB_flag, self.pseudo_label_list, len_list) def valid_formula(self, formula): if(len(formula) % 2 == 0): @@ -127,7 +127,7 @@ class hwf_KB(ClsKB): class RegKB(KBBase): - def __init__(self, X, Y = None): + def __init__(self, GKB_flag = False, X = None, Y = None): super().__init__() tmp_dict = {} for x, y in zip(X, Y): @@ -176,7 +176,7 @@ class RegKB(KBBase): import time if __name__ == "__main__": # With ground KB - kb = add_KB(len_list = [2]) + kb = add_KB(GKB_flag = True) print('len(kb):', len(kb)) res = kb.get_candidates(0) print(res) @@ -202,7 +202,7 @@ if __name__ == "__main__": print() start = time.time() - kb = hwf_KB(len_list = [1, 3, 5, 7]) + kb = hwf_KB(GKB_flag = True) print(time.time() - start) print('len(kb):', len(kb)) res = kb.get_candidates(2, length = 1)