| @@ -80,6 +80,7 @@ class KBBase(ABC): | |||
| res = [self.logic_forward(x) for x in xs] | |||
| return res | |||
| # TODO:这里max_address_num默认值-1,后面运行会有问题吗 | |||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||
| if self.GKB_flag: | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| @@ -134,8 +135,10 @@ class KBBase(ABC): | |||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| # TODO:应该也是内部使用的方法? | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| # TODO:product combinations本身就是迭代器,如果没有其他用途,不用转list,直接放到循环那即可,省去一些时间,下面的同理 | |||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||
| if multiple_predictions: | |||
| @@ -209,7 +212,8 @@ class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def logic_forward(self): | |||
| # TODO:这里以及RegKB可以不实现logic_forward吗,这样用户继承后不实现logic_forward就会报错 | |||
| def logic_forward(self, pseudo_labels): | |||
| pass | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| @@ -243,13 +247,15 @@ class prolog_KB(KBBase): | |||
| address_pred_res = pred_res.copy() | |||
| if multiple_predictions: | |||
| address_pred_res = flatten(address_pred_res) | |||
| # TODO:可以直接对address_idx循环? | |||
| for idx in range(len(address_pred_res)): | |||
| if idx in address_idx: | |||
| address_pred_res[idx] = 'P' + str(idx) | |||
| if multiple_predictions: | |||
| address_pred_res = reform_idx(address_pred_res, pred_res) | |||
| # TODO:不知道有没有更简洁的方法 | |||
| regex = r"'P\d+'" | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) | |||
| @@ -269,6 +275,7 @@ class prolog_KB(KBBase): | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| # TODO:这里后面的那个list应该也不需要 | |||
| abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| @@ -289,17 +296,15 @@ class prolog_KB(KBBase): | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| prolog_rules = prolog_result[0]['X'] | |||
| rules = [] | |||
| for rule in prolog_rules: | |||
| rules.append(rule.value) | |||
| rules = [rule.value for rule in prolog_rules] | |||
| return rules | |||
| # TODO:和ClsKB的参数顺序不统一 | |||
| class RegKB(KBBase): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||
| def logic_forward(self): | |||
| def logic_forward(self, pseudo_labels): | |||
| pass | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| @@ -333,6 +338,7 @@ class HWF_KB(RegKB): | |||
| ): | |||
| super().__init__(GKB_flag, pseudo_label_list, len_list, max_err) | |||
| # TODO:应该是静态方法 | |||
| def valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||