| @@ -48,15 +48,19 @@ class KBBase(ABC): | |||||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | ||||
| for address_idx in address_idx_list: | for address_idx in address_idx_list: | ||||
| # TODO: 要么address_by_idx放这个类,要么定义abstractmethod | |||||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | ||||
| new_candidates += candidates | new_candidates += candidates | ||||
| return new_candidates | return new_candidates | ||||
| def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions=False): | |||||
| def abduction(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||||
| candidates = [] | candidates = [] | ||||
| # TODO: 这里的len(pred_res)考虑了multiple_predictions吗? | |||||
| for address_num in range(len(pred_res) + 1): | for address_num in range(len(pred_res) + 1): | ||||
| if address_num == 0: | if address_num == 0: | ||||
| # TODO: 不是所有的key都是数字,也可以是字符串,甚至列表? | |||||
| # TODO: check type (str int float multiple_pred ...) | |||||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | if abs(self.logic_forward(pred_res) - key) <= 1e-3: | ||||
| candidates.append(pred_res) | candidates.append(pred_res) | ||||
| else: | else: | ||||
| @@ -88,7 +92,8 @@ class ClsKB(KBBase): | |||||
| self.GKB_flag = GKB_flag | self.GKB_flag = GKB_flag | ||||
| self.pseudo_label_list = pseudo_label_list | self.pseudo_label_list = pseudo_label_list | ||||
| self.len_list = len_list | self.len_list = len_list | ||||
| self.prolog_flag = False | |||||
| self.prolog_flag = False # TODO:没用? | |||||
| # TODO: 既然pseudo_label_list len_list prolog_flag都存为self了,那为啥又传到后面的方法里,或者说self就没用上 | |||||
| if GKB_flag: | if GKB_flag: | ||||
| self.base = {} | self.base = {} | ||||
| @@ -143,6 +148,7 @@ class ClsKB(KBBase): | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| # TODO: abduction和abduce_candidates命名上太相近了,即使这样,也有其中一个是私有方法加"_"吧 | |||||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | ||||
| if self.GKB_flag: | if self.GKB_flag: | ||||
| return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) | return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) | ||||