diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index bd76831..cf14bb7 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -64,7 +64,7 @@ class AbducerBase(abc.ABC): score = 0 for idx in range(len(pred_res)): address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) + candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) if len(candidate) > 0: score += 1 return score @@ -72,7 +72,7 @@ class AbducerBase(abc.ABC): def _zoopt_address_score(self, pred_res, key, sol): if not self.multiple_predictions: address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.address_by_idx(pred_res, key, address_idx) return 1 if len(candidates) > 0 else 0 else: return self._zoopt_score_multiple(pred_res, key, sol.get_x()) @@ -115,6 +115,9 @@ class AbducerBase(abc.ABC): key = tuple(key) self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates + + def address_by_idx(self, pred_res, key, address_idx): + return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) def abduce(self, data, max_address_num=-1, require_more_address=0): pred_res, pred_res_prob, key = data @@ -129,7 +132,7 @@ class AbducerBase(abc.ABC): if self.zoopt: solution = self.zoopt_get_solution(pred_res, key, max_address_num) address_idx = [idx for idx, i in enumerate(solution) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.address_by_idx(pred_res, key, address_idx) address_num = int(solution.sum()) min_address_num = address_num else: @@ -156,7 +159,6 @@ class AbducerBase(abc.ABC): def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) - if __name__ == '__main__': from kb import add_KB, prolog_KB, HWF_KB diff --git a/abl/framework_hed.py b/abl/framework_hed.py index a823b13..3f09ff6 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -158,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): for idx in range(len(pred_res)): address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = abducer.kb.address_by_idx([pred_res[idx]], None, address_idx, True) + candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx) if len(candidate) > 0: consistent_idx_tmp.append(idx) consistent_pred_res_tmp.append(candidate[0][0])