|
|
|
@@ -48,15 +48,19 @@ class KBBase(ABC): |
|
|
|
address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) |
|
|
|
|
|
|
|
for address_idx in address_idx_list: |
|
|
|
# TODO: 要么address_by_idx放这个类,要么定义abstractmethod |
|
|
|
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) |
|
|
|
new_candidates += candidates |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions=False): |
|
|
|
def abduction(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): |
|
|
|
candidates = [] |
|
|
|
|
|
|
|
# TODO: 这里的len(pred_res)考虑了multiple_predictions吗? |
|
|
|
for address_num in range(len(pred_res) + 1): |
|
|
|
if address_num == 0: |
|
|
|
# TODO: 不是所有的key都是数字,也可以是字符串,甚至列表? |
|
|
|
# TODO: check type (str int float multiple_pred ...) |
|
|
|
if abs(self.logic_forward(pred_res) - key) <= 1e-3: |
|
|
|
candidates.append(pred_res) |
|
|
|
else: |
|
|
|
@@ -88,7 +92,8 @@ class ClsKB(KBBase): |
|
|
|
self.GKB_flag = GKB_flag |
|
|
|
self.pseudo_label_list = pseudo_label_list |
|
|
|
self.len_list = len_list |
|
|
|
self.prolog_flag = False |
|
|
|
self.prolog_flag = False # TODO:没用? |
|
|
|
# TODO: 既然pseudo_label_list len_list prolog_flag都存为self了,那为啥又传到后面的方法里,或者说self就没用上 |
|
|
|
|
|
|
|
if GKB_flag: |
|
|
|
self.base = {} |
|
|
|
@@ -143,6 +148,7 @@ class ClsKB(KBBase): |
|
|
|
def logic_forward(self): |
|
|
|
pass |
|
|
|
|
|
|
|
# TODO: abduction和abduce_candidates命名上太相近了,即使这样,也有其中一个是私有方法加"_"吧 |
|
|
|
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): |
|
|
|
if self.GKB_flag: |
|
|
|
return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) |
|
|
|
|