diff --git a/abducer/kb.py b/abducer/kb.py index 7fa4b3f..8c3f49a 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -127,7 +127,7 @@ class add_KB(ClsKB): class hwf_KB(ClsKB): def __init__(self, GKB_flag = False, \ - pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'], \ + pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ len_list = [1, 3, 5, 7]): super().__init__(GKB_flag, pseudo_label_list, len_list) @@ -137,7 +137,7 @@ class hwf_KB(ClsKB): for i in range(len(formula)): if i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: return False - if i % 2 != 0 and formula[i] not in ['+', '-', '*', '/']: + if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: return False return True @@ -145,6 +145,8 @@ class hwf_KB(ClsKB): if not self.valid_candidate(formula): return np.inf try: + mapping = {'0':'0', '1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} + formula = [mapping[f] for f in formula] return round(eval(''.join(formula)), 2) except ZeroDivisionError: return np.inf