| @@ -19,7 +19,7 @@ from collections import defaultdict | |||
| from itertools import product | |||
| class KBBase(ABC): | |||
| def __init__(self): | |||
| def __init__(self, GKB_flag = False): | |||
| pass | |||
| @abstractmethod | |||
| @@ -46,13 +46,13 @@ class KBBase(ABC): | |||
| class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list = None): | |||
| def __init__(self, GKB_flag = False, pseudo_label_list = None, len_list = None): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.base = {} | |||
| self.len_list = len_list | |||
| if(self.len_list != None): | |||
| if GKB_flag: | |||
| X = self.get_X(self.pseudo_label_list, self.len_list) | |||
| Y = self.get_Y(X, self.logic_forward) | |||
| for x, y in zip(X, Y): | |||
| @@ -94,18 +94,18 @@ class ClsKB(KBBase): | |||
| class add_KB(ClsKB): | |||
| def __init__(self, len_list = [2]): | |||
| def __init__(self, GKB_flag = False, len_list = [2]): | |||
| self.pseudo_label_list = list(range(10)) | |||
| super().__init__(self.pseudo_label_list, len_list) | |||
| super().__init__(GKB_flag, self.pseudo_label_list, len_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class hwf_KB(ClsKB): | |||
| def __init__(self, len_list = [1, 3, 5, 7]): | |||
| def __init__(self, GKB_flag = False, len_list = [1, 3, 5, 7]): | |||
| self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] | |||
| super().__init__(self.pseudo_label_list, len_list) | |||
| super().__init__(GKB_flag, self.pseudo_label_list, len_list) | |||
| def valid_formula(self, formula): | |||
| if(len(formula) % 2 == 0): | |||
| @@ -127,7 +127,7 @@ class hwf_KB(ClsKB): | |||
| class RegKB(KBBase): | |||
| def __init__(self, X, Y = None): | |||
| def __init__(self, GKB_flag = False, X = None, Y = None): | |||
| super().__init__() | |||
| tmp_dict = {} | |||
| for x, y in zip(X, Y): | |||
| @@ -176,7 +176,7 @@ class RegKB(KBBase): | |||
| import time | |||
| if __name__ == "__main__": | |||
| # With ground KB | |||
| kb = add_KB(len_list = [2]) | |||
| kb = add_KB(GKB_flag = True) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(0) | |||
| print(res) | |||
| @@ -202,7 +202,7 @@ if __name__ == "__main__": | |||
| print() | |||
| start = time.time() | |||
| kb = hwf_KB(len_list = [1, 3, 5, 7]) | |||
| kb = hwf_KB(GKB_flag = True) | |||
| print(time.time() - start) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(2, length = 1) | |||