| @@ -44,12 +44,14 @@ class KBBase(ABC): | |||
| def __len__(self): | |||
| pass | |||
| class add_KB(KBBase): | |||
| def __init__(self, kb_max_len = -1): | |||
| class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| super().__init__() | |||
| self.pseudo_label_list = list(range(10)) | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.base = {} | |||
| self.kb_max_len = kb_max_len | |||
| if(self.kb_max_len > 0): | |||
| X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
| Y = self.get_Y(X, self.logic_forward) | |||
| @@ -57,9 +59,6 @@ class add_KB(KBBase): | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def get_X(self, pseudo_label_list, max_len): | |||
| res = [] | |||
| assert(max_len >= 2) | |||
| @@ -69,6 +68,9 @@ class add_KB(KBBase): | |||
| def get_Y(self, X, logic_forward): | |||
| return [logic_forward(nums) for nums in X] | |||
| def logic_forward(self): | |||
| return None | |||
| def get_candidates(self, key, length = None): | |||
| if(self.base == {}): | |||
| @@ -76,7 +78,7 @@ class add_KB(KBBase): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| length = self._length(length) | |||
| if(self.kb_max_len < min(length)): | |||
| return [] | |||
| @@ -84,163 +86,72 @@ class add_KB(KBBase): | |||
| def get_all_candidates(self): | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def _dict_len(self, dic): | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class hwf_KB(KBBase): | |||
| class add_KB(ClsKB): | |||
| def __init__(self, kb_max_len = -1): | |||
| super().__init__() | |||
| self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] | |||
| self.base = {} | |||
| self.kb_max_len = kb_max_len | |||
| if(self.kb_max_len > 0): | |||
| X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
| Y = self.get_Y(X, self.logic_forward) | |||
| self.pseudo_label_list = list(range(10)) | |||
| super().__init__(self.pseudo_label_list, kb_max_len) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| def get_candidates(self, key, length = None): | |||
| return super().get_candidates(key, length) | |||
| def calculate(self, formula): | |||
| stack = [] | |||
| postfix = [] | |||
| priority = {'+': 0, '-': 0, | |||
| '*': 1, '/': 1} | |||
| skip_flag = 0 | |||
| for i in range(len(formula)): | |||
| if formula[i] == '-': | |||
| if i == 0: | |||
| formula.insert(0, 0) | |||
| for i in range(len(formula)): | |||
| if skip_flag: | |||
| skip_flag -= 1 | |||
| continue | |||
| char = formula[i] | |||
| if char in priority.keys(): | |||
| while stack and (priority[char] <= priority[stack[-1]]): | |||
| postfix.append(stack.pop()) | |||
| stack.append(char) | |||
| else: | |||
| num = int(char) | |||
| while (i + 1) < len(formula): | |||
| if formula[i + 1] not in priority.keys(): | |||
| skip_flag += 1 | |||
| num = num * 10 + int(formula[i + 1]) | |||
| i += 1 | |||
| else: | |||
| break | |||
| postfix.append(num) | |||
| while stack: | |||
| postfix.append(stack.pop()) | |||
| def get_all_candidates(self): | |||
| return super().get_all_candidates() | |||
| def _dict_len(self, dic): | |||
| return super()._dict_len(dic) | |||
| for i in postfix: | |||
| if i in priority.keys(): | |||
| num2 = stack.pop() | |||
| num1 = stack.pop() | |||
| if i == '+': | |||
| res = num1 + num2 | |||
| elif i == '-': | |||
| res = num1 - num2 | |||
| elif i == '*': | |||
| res = num1 * num2 | |||
| elif i == '/': | |||
| if(num2 == 0): | |||
| return np.inf | |||
| res = num1 / num2 | |||
| stack.append(res) | |||
| else: | |||
| stack.append(i) | |||
| return round(stack[0], 2) | |||
| def __len__(self): | |||
| return super().__len__() | |||
| class hwf_KB(ClsKB): | |||
| def __init__(self, kb_max_len = -1): | |||
| self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] | |||
| super().__init__(self.pseudo_label_list, kb_max_len) | |||
| def valid_formula(self, formula): | |||
| symbol_idx_list = [] | |||
| for idx, c in enumerate(formula): | |||
| if(idx == 0 and c == '-'): | |||
| if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']): | |||
| return False | |||
| continue | |||
| if(c in ['+', '-', '*', '/']): | |||
| if(idx - 1 in symbol_idx_list): | |||
| return False | |||
| symbol_idx_list.append(idx) | |||
| if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): | |||
| if(len(formula) % 2 == 0): | |||
| return False | |||
| for i in range(len(formula)): | |||
| if(i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']): | |||
| return False | |||
| if(i % 2 != 0 and formula[i] not in ['+', '-', '*', '/']): | |||
| return False | |||
| return True | |||
| def logic_forward(self, formula): | |||
| if(self.valid_formula(formula) == False): | |||
| return np.inf | |||
| return self.calculate(list(formula)) | |||
| try: | |||
| return eval(''.join(formula)) | |||
| except ZeroDivisionError: | |||
| return np.inf | |||
| def get_X(self, pseudo_label_list, max_len): | |||
| res = [] | |||
| assert(max_len >= 2) | |||
| for len in range(2, max_len + 1): | |||
| res += list(product(pseudo_label_list, repeat = len)) | |||
| return res | |||
| def get_Y(self, X, logic_forward): | |||
| return [logic_forward(formula) for formula in X] | |||
| def get_candidates(self, key, length = None): | |||
| if(self.base == {}): | |||
| return [] | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| length = self._length(length) | |||
| if(self.kb_max_len < min(length)): | |||
| return [] | |||
| return sum([self.base[l][key] for l in length], []) | |||
| return super().get_candidates(key, length) | |||
| def get_all_candidates(self): | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def _dict_len(self, dic): | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class cls_KB(KBBase): | |||
| def __init__(self, X, Y = None): | |||
| super().__init__() | |||
| self.base = {} | |||
| if X is None: | |||
| return | |||
| if Y is None: | |||
| Y = [None] * len(X) | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| return super().get_all_candidates() | |||
| def logic_forward(self): | |||
| return None | |||
| def get_candidates(self, key, length = None): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| length = self._length(length) | |||
| return sum([self.base[l][key] for l in length], []) | |||
| def get_all_candidates(self): | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def _dict_len(self, dic): | |||
| return sum(len(c) for c in dic.values()) | |||
| return super()._dict_len(dic) | |||
| def __len__(self): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| return super().__len__() | |||
| class reg_KB(KBBase): | |||
| class RegKB(KBBase): | |||
| def __init__(self, X, Y = None): | |||
| super().__init__() | |||
| tmp_dict = {} | |||
| @@ -323,26 +234,26 @@ if __name__ == "__main__": | |||
| print() | |||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
| Y = [2, 1, 1, 2, 2] | |||
| kb = cls_KB(X, Y) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(2, 5) | |||
| print(res) | |||
| res = kb.get_candidates(2, 3) | |||
| print(res) | |||
| res = kb.get_candidates(None) | |||
| print(res) | |||
| print() | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
| # Y = [2, 1, 1, 2, 2] | |||
| # kb = ClsKB(X, Y) | |||
| # print('len(kb):', len(kb)) | |||
| # res = kb.get_candidates(2, 5) | |||
| # print(res) | |||
| # res = kb.get_candidates(2, 3) | |||
| # print(res) | |||
| # res = kb.get_candidates(None) | |||
| # print(res) | |||
| # print() | |||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
| Y = [2, 1, 1, 2, 1.5, 1.5] | |||
| kb = reg_KB(X, Y) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(1.6) | |||
| print(res) | |||
| res = kb.get_candidates(1.6, length = 9) | |||
| print(res) | |||
| res = kb.get_candidates(None) | |||
| print(res) | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
| # Y = [2, 1, 1, 2, 1.5, 1.5] | |||
| # kb = RegKB(X, Y) | |||
| # print('len(kb):', len(kb)) | |||
| # res = kb.get_candidates(1.6) | |||
| # print(res) | |||
| # res = kb.get_candidates(1.6, length = 9) | |||
| # print(res) | |||
| # res = kb.get_candidates(None) | |||
| # print(res) | |||