| @@ -119,7 +119,7 @@ class add_KB(ClsKB): | |||||
| class hwf_KB(ClsKB): | class hwf_KB(ClsKB): | ||||
| def __init__(self, GKB_flag = False, \ | def __init__(self, GKB_flag = False, \ | ||||
| pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ | |||||
| pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ | |||||
| len_list = [1, 3, 5, 7]): | len_list = [1, 3, 5, 7]): | ||||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | super().__init__(GKB_flag, pseudo_label_list, len_list) | ||||
| @@ -127,7 +127,7 @@ class hwf_KB(ClsKB): | |||||
| if len(formula) % 2 == 0: | if len(formula) % 2 == 0: | ||||
| return False | return False | ||||
| for i in range(len(formula)): | for i in range(len(formula)): | ||||
| if i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||||
| if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||||
| return False | return False | ||||
| if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | ||||
| return False | return False | ||||
| @@ -136,12 +136,9 @@ class hwf_KB(ClsKB): | |||||
| def logic_forward(self, formula): | def logic_forward(self, formula): | ||||
| if not self.valid_candidate(formula): | if not self.valid_candidate(formula): | ||||
| return np.inf | return np.inf | ||||
| try: | |||||
| mapping = {'0':'0', '1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} | |||||
| formula = [mapping[f] for f in formula] | |||||
| return round(eval(''.join(formula)), 2) | |||||
| except ZeroDivisionError: | |||||
| return np.inf | |||||
| mapping = {'1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} | |||||
| formula = [mapping[f] for f in formula] | |||||
| return round(eval(''.join(formula)), 2) | |||||
| class RegKB(KBBase): | class RegKB(KBBase): | ||||