From bfd6dc8a5c37c680c683a3c3d9c4efe756853248 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Tue, 7 Mar 2023 00:01:08 +0800 Subject: [PATCH] update TODO in kb.py --- abl/abducer/kb.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index d937dcf..9ae919a 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -80,6 +80,7 @@ class KBBase(ABC): res = [self.logic_forward(x) for x in xs] return res + # TODO:这里max_address_num默认值-1,后面运行会有问题吗 def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) @@ -134,8 +135,10 @@ class KBBase(ABC): candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] return candidates, min_address_num, address_num + # TODO:应该也是内部使用的方法? def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] + # TODO:product combinations本身就是迭代器,如果没有其他用途,不用转list,直接放到循环那即可,省去一些时间,下面的同理 abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) if multiple_predictions: @@ -209,7 +212,8 @@ class ClsKB(KBBase): def __init__(self, pseudo_label_list, len_list, GKB_flag): super().__init__(pseudo_label_list, len_list, GKB_flag) - def logic_forward(self): + # TODO:这里以及RegKB可以不实现logic_forward吗,这样用户继承后不实现logic_forward就会报错 + def logic_forward(self, pseudo_labels): pass def _find_candidate_GKB(self, pred_res, key): @@ -243,13 +247,15 @@ class prolog_KB(KBBase): address_pred_res = pred_res.copy() if multiple_predictions: address_pred_res = flatten(address_pred_res) - + + # TODO:可以直接对address_idx循环? for idx in range(len(address_pred_res)): if idx in address_idx: address_pred_res[idx] = 'P' + str(idx) if multiple_predictions: address_pred_res = reform_idx(address_pred_res, pred_res) + # TODO:不知道有没有更简洁的方法 regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) @@ -269,6 +275,7 @@ class prolog_KB(KBBase): if multiple_predictions: save_pred_res = pred_res pred_res = flatten(pred_res) + # TODO:这里后面的那个list应该也不需要 abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] for c in abduce_c: candidate = pred_res.copy() @@ -289,17 +296,15 @@ class prolog_KB(KBBase): if len(prolog_result) == 0: return None prolog_rules = prolog_result[0]['X'] - rules = [] - for rule in prolog_rules: - rules.append(rule.value) + rules = [rule.value for rule in prolog_rules] return rules - +# TODO:和ClsKB的参数顺序不统一 class RegKB(KBBase): def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) - def logic_forward(self): + def logic_forward(self, pseudo_labels): pass def _find_candidate_GKB(self, pred_res, key): @@ -333,6 +338,7 @@ class HWF_KB(RegKB): ): super().__init__(GKB_flag, pseudo_label_list, len_list, max_err) + # TODO:应该是静态方法 def valid_candidate(self, formula): if len(formula) % 2 == 0: return False