| @@ -45,9 +45,9 @@ class KBBase(ABC): | |||
| pass | |||
| class add_KB(KBBase): | |||
| def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| def __init__(self, kb_max_len = -1): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.pseudo_label_list = list(range(10)) | |||
| self.base = {} | |||
| self.kb_max_len = kb_max_len | |||
| if(self.kb_max_len > 0): | |||
| @@ -92,9 +92,9 @@ class add_KB(KBBase): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class hwf_KB(KBBase): | |||
| def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| def __init__(self, kb_max_len = -1): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] | |||
| self.base = {} | |||
| self.kb_max_len = kb_max_len | |||
| if(self.kb_max_len > 0): | |||
| @@ -157,13 +157,13 @@ class hwf_KB(KBBase): | |||
| def valid_formula(self, formula): | |||
| symbol_idx_list = [] | |||
| first_minus_flag = 0 | |||
| for idx, c in enumerate(formula): | |||
| if(idx == 0 and c == '-'): | |||
| first_minus_flag = 1 | |||
| if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']): | |||
| return False | |||
| continue | |||
| if(c in ['+', '-', '*', '/']): | |||
| if(idx - 1 in symbol_idx_list or (idx == 1 and first_minus_flag == 1)): | |||
| if(idx - 1 in symbol_idx_list): | |||
| return False | |||
| symbol_idx_list.append(idx) | |||
| if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): | |||
| @@ -289,8 +289,7 @@ class reg_KB(KBBase): | |||
| if __name__ == "__main__": | |||
| # With ground KB | |||
| pseudo_label_list = list(range(10)) | |||
| kb = add_KB(pseudo_label_list, kb_max_len = 5) | |||
| kb = add_KB(kb_max_len = 5) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(0) | |||
| print(res) | |||
| @@ -303,8 +302,7 @@ if __name__ == "__main__": | |||
| print() | |||
| # Without ground KB | |||
| pseudo_label_list = list(range(10)) | |||
| kb = add_KB(pseudo_label_list) | |||
| kb = add_KB() | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(0) | |||
| print(res) | |||
| @@ -316,8 +314,7 @@ if __name__ == "__main__": | |||
| print(res) | |||
| print() | |||
| pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '+', '-', '*', '/'] | |||
| kb = hwf_KB(pseudo_label_list, kb_max_len = 5) | |||
| kb = hwf_KB(kb_max_len = 5) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(1, length = 3) | |||
| print(res) | |||