diff --git a/abducer/kb.py b/abducer/kb.py index 5d25a46..12772e5 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -45,9 +45,9 @@ class KBBase(ABC): pass class add_KB(KBBase): - def __init__(self, pseudo_label_list, kb_max_len = -1): + def __init__(self, kb_max_len = -1): super().__init__() - self.pseudo_label_list = pseudo_label_list + self.pseudo_label_list = list(range(10)) self.base = {} self.kb_max_len = kb_max_len if(self.kb_max_len > 0): @@ -92,9 +92,9 @@ class add_KB(KBBase): return sum(self._dict_len(v) for v in self.base.values()) class hwf_KB(KBBase): - def __init__(self, pseudo_label_list, kb_max_len = -1): + def __init__(self, kb_max_len = -1): super().__init__() - self.pseudo_label_list = pseudo_label_list + self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] self.base = {} self.kb_max_len = kb_max_len if(self.kb_max_len > 0): @@ -157,13 +157,13 @@ class hwf_KB(KBBase): def valid_formula(self, formula): symbol_idx_list = [] - first_minus_flag = 0 for idx, c in enumerate(formula): if(idx == 0 and c == '-'): - first_minus_flag = 1 + if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']): + return False continue if(c in ['+', '-', '*', '/']): - if(idx - 1 in symbol_idx_list or (idx == 1 and first_minus_flag == 1)): + if(idx - 1 in symbol_idx_list): return False symbol_idx_list.append(idx) if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): @@ -289,8 +289,7 @@ class reg_KB(KBBase): if __name__ == "__main__": # With ground KB - pseudo_label_list = list(range(10)) - kb = add_KB(pseudo_label_list, kb_max_len = 5) + kb = add_KB(kb_max_len = 5) print('len(kb):', len(kb)) res = kb.get_candidates(0) print(res) @@ -303,8 +302,7 @@ if __name__ == "__main__": print() # Without ground KB - pseudo_label_list = list(range(10)) - kb = add_KB(pseudo_label_list) + kb = add_KB() print('len(kb):', len(kb)) res = kb.get_candidates(0) print(res) @@ -316,8 +314,7 @@ if __name__ == "__main__": print(res) print() - pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '+', '-', '*', '/'] - kb = hwf_KB(pseudo_label_list, kb_max_len = 5) + kb = hwf_KB(kb_max_len = 5) print('len(kb):', len(kb)) res = kb.get_candidates(1, length = 3) print(res)