Browse Source

update TODO in kb.py

pull/3/head
Tony-HYX 3 years ago
parent
commit
bfd6dc8a5c
1 changed files with 13 additions and 7 deletions
  1. +13
    -7
      abl/abducer/kb.py

+ 13
- 7
abl/abducer/kb.py View File

@@ -80,6 +80,7 @@ class KBBase(ABC):
res = [self.logic_forward(x) for x in xs]
return res

# TODO:这里max_address_num默认值-1,后面运行会有问题吗
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
@@ -134,8 +135,10 @@ class KBBase(ABC):
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num
# TODO:应该也是内部使用的方法?
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
# TODO:product combinations本身就是迭代器,如果没有其他用途,不用转list,直接放到循环那即可,省去一些时间,下面的同理
abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))

if multiple_predictions:
@@ -209,7 +212,8 @@ class ClsKB(KBBase):
def __init__(self, pseudo_label_list, len_list, GKB_flag):
super().__init__(pseudo_label_list, len_list, GKB_flag)

def logic_forward(self):
# TODO:这里以及RegKB可以不实现logic_forward吗,这样用户继承后不实现logic_forward就会报错
def logic_forward(self, pseudo_labels):
pass

def _find_candidate_GKB(self, pred_res, key):
@@ -243,13 +247,15 @@ class prolog_KB(KBBase):
address_pred_res = pred_res.copy()
if multiple_predictions:
address_pred_res = flatten(address_pred_res)
# TODO:可以直接对address_idx循环?
for idx in range(len(address_pred_res)):
if idx in address_idx:
address_pred_res[idx] = 'P' + str(idx)
if multiple_predictions:
address_pred_res = reform_idx(address_pred_res, pred_res)
# TODO:不知道有没有更简洁的方法
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
@@ -269,6 +275,7 @@ class prolog_KB(KBBase):
if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)
# TODO:这里后面的那个list应该也不需要
abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))]
for c in abduce_c:
candidate = pred_res.copy()
@@ -289,17 +296,15 @@ class prolog_KB(KBBase):
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]['X']
rules = []
for rule in prolog_rules:
rules.append(rule.value)
rules = [rule.value for rule in prolog_rules]
return rules

# TODO:和ClsKB的参数顺序不统一
class RegKB(KBBase):
def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)

def logic_forward(self):
def logic_forward(self, pseudo_labels):
pass

def _find_candidate_GKB(self, pred_res, key):
@@ -333,6 +338,7 @@ class HWF_KB(RegKB):
):
super().__init__(GKB_flag, pseudo_label_list, len_list, max_err)

# TODO:应该是静态方法
def valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False


Loading…
Cancel
Save