From 5c6cbd45c74c2099ca5fbc7dc9c88a49bdf19077 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 21 Nov 2022 16:50:11 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 7fa4b3f..8c3f49a 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -127,7 +127,7 @@ class add_KB(ClsKB): class hwf_KB(ClsKB): def __init__(self, GKB_flag = False, \ - pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'], \ + pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ len_list = [1, 3, 5, 7]): super().__init__(GKB_flag, pseudo_label_list, len_list) @@ -137,7 +137,7 @@ class hwf_KB(ClsKB): for i in range(len(formula)): if i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: return False - if i % 2 != 0 and formula[i] not in ['+', '-', '*', '/']: + if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: return False return True @@ -145,6 +145,8 @@ class hwf_KB(ClsKB): if not self.valid_candidate(formula): return np.inf try: + mapping = {'0':'0', '1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} + formula = [mapping[f] for f in formula] return round(eval(''.join(formula)), 2) except ZeroDivisionError: return np.inf