| @@ -91,52 +91,120 @@ class add_KB(KBBase): | |||
| def __len__(self): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| # class hwf_KB(KBBase): | |||
| # def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| # super().__init__() | |||
| # self.pseudo_label_list = pseudo_label_list | |||
| # self.base = {} | |||
| # self.kb_max_len = kb_max_len | |||
| # if(self.kb_max_len > 0): | |||
| # X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
| # Y = self.get_Y(X, self.logic_forward) | |||
| # for x, y in zip(X, Y): | |||
| # self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| class hwf_KB(KBBase): | |||
| def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.base = {} | |||
| self.kb_max_len = kb_max_len | |||
| if(self.kb_max_len > 0): | |||
| X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
| Y = self.get_Y(X, self.logic_forward) | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| # def logic_forward(self, nums): | |||
| # return sum(nums) | |||
| def calculate(self, formula): | |||
| stack = [] | |||
| postfix = [] | |||
| priority = {'+': 0, '-': 0, | |||
| '*': 1, '/': 1} | |||
| skip_flag = 0 | |||
| for i in range(len(formula)): | |||
| if formula[i] == '-': | |||
| if i == 0: | |||
| formula.insert(0, 0) | |||
| for i in range(len(formula)): | |||
| if skip_flag: | |||
| skip_flag -= 1 | |||
| continue | |||
| char = formula[i] | |||
| if char in priority.keys(): | |||
| while stack and (priority[char] <= priority[stack[-1]]): | |||
| postfix.append(stack.pop()) | |||
| stack.append(char) | |||
| else: | |||
| num = char | |||
| while (i + 1) < len(formula): | |||
| if formula[i + 1] not in priority.keys(): | |||
| skip_flag += 1 | |||
| num = num * 10 + formula[i + 1] | |||
| i += 1 | |||
| else: | |||
| break | |||
| postfix.append(num) | |||
| while stack: | |||
| postfix.append(stack.pop()) | |||
| for i in postfix: | |||
| if i in priority.keys(): | |||
| num2 = stack.pop() | |||
| num1 = stack.pop() | |||
| if i == '+': | |||
| res = num1 + num2 | |||
| elif i == '-': | |||
| res = num1 - num2 | |||
| elif i == '*': | |||
| res = num1 * num2 | |||
| elif i == '/': | |||
| if(num2 == 0): | |||
| return np.inf | |||
| res = num1 / num2 | |||
| stack.append(res) | |||
| else: | |||
| stack.append(i) | |||
| return round(stack[0], 2) | |||
| # def get_X(self, pseudo_label_list, max_len): | |||
| # res = [] | |||
| # assert(max_len >= 2) | |||
| # for len in range(2, max_len + 1): | |||
| # res += list(product(pseudo_label_list, repeat = len)) | |||
| # return res | |||
| # def get_Y(self, X, logic_forward): | |||
| # return [logic_forward(nums) for nums in X] | |||
| # def get_candidates(self, key, length = None): | |||
| # if(self.base == {}): | |||
| # return [] | |||
| def valid_formula(self, formula): | |||
| symbol_idx_list = [] | |||
| first_minus_flag = 0 | |||
| for idx, c in enumerate(formula): | |||
| if(idx == 0 and c == '-'): | |||
| first_minus_flag = 1 | |||
| continue | |||
| if(c in ['+', '-', '*', '/']): | |||
| if(idx - 1 in symbol_idx_list or (idx == 1 and first_minus_flag == 1)): | |||
| return False | |||
| symbol_idx_list.append(idx) | |||
| if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): | |||
| return False | |||
| return True | |||
| def logic_forward(self, formula): | |||
| if(self.valid_formula(formula) == False): | |||
| return np.inf | |||
| return self.calculate(list(formula)) | |||
| # if key is None: | |||
| # return self.get_all_candidates() | |||
| def get_X(self, pseudo_label_list, max_len): | |||
| res = [] | |||
| assert(max_len >= 2) | |||
| for len in range(2, max_len + 1): | |||
| res += list(product(pseudo_label_list, repeat = len)) | |||
| return res | |||
| # length = self._length(length) | |||
| # if(self.kb_max_len < min(length)): | |||
| # return [] | |||
| # return sum([self.base[l][key] for l in length], []) | |||
| def get_Y(self, X, logic_forward): | |||
| return [logic_forward(formula) for formula in X] | |||
| def get_candidates(self, key, length = None): | |||
| if(self.base == {}): | |||
| return [] | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| length = self._length(length) | |||
| if(self.kb_max_len < min(length)): | |||
| return [] | |||
| return sum([self.base[l][key] for l in length], []) | |||
| # def get_all_candidates(self): | |||
| # return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def get_all_candidates(self): | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| # def _dict_len(self, dic): | |||
| # return sum(len(c) for c in dic.values()) | |||
| def _dict_len(self, dic): | |||
| return sum(len(c) for c in dic.values()) | |||
| # def __len__(self): | |||
| # return sum(self._dict_len(v) for v in self.base.values()) | |||
| def __len__(self): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class cls_KB(KBBase): | |||
| def __init__(self, X, Y = None): | |||
| @@ -248,10 +316,14 @@ if __name__ == "__main__": | |||
| print(res) | |||
| print() | |||
| # pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] | |||
| # kb = hwf_KB(pseudo_label_list, max_len = 5) | |||
| # print('len(kb):', len(kb)) | |||
| # print() | |||
| pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] | |||
| kb = hwf_KB(pseudo_label_list, kb_max_len = 5) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(1, length = 3) | |||
| print(res) | |||
| res = kb.get_candidates(3.67, length = 5) | |||
| print(res) | |||
| print() | |||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||