From e2f57e769834a1cc0b89c9ecec778c0001e39879 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Tue, 5 Sep 2023 20:35:08 +0800 Subject: [PATCH] [MNT] Enhance code --- abl/reasoning/kb.py | 42 ++++++++++++++++++------------------------ 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 4fd60f9..9f99a60 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -4,6 +4,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations + from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list from multiprocessing import Pool @@ -56,10 +57,8 @@ class KBBase(ABC): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - if type(Y[0]) in (int, float): - sorted_XY = sorted(list(zip(Y, X))) - X = [x for y, x in sorted_XY] - Y = [y for y, x in sorted_XY] + if Y and isinstance(Y[0], (int, float)): + X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) return X, Y @abstractmethod @@ -100,21 +99,19 @@ class KBBase(ABC): return all_candidates def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): - if self.base == {}: + if self.base == {} or len(pred_res) not in self.len_list: return [] - if len(pred_res) not in self.len_list: - return [] all_candidates = self._find_candidate_GKB(pred_res, y) if len(all_candidates) == 0: return [] - else: - cost_list = hamming_dist(pred_res, all_candidates) - min_revision_num = np.min(cost_list) - revision_num = min(max_revision_num, min_revision_num + require_more_revision) - idxs = np.where(cost_list <= revision_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - return candidates + + cost_list = hamming_dist(pred_res, all_candidates) + min_revision_num = np.min(cost_list) + revision_num = min(max_revision_num, min_revision_num + require_more_revision) + idxs = np.where(cost_list <= revision_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates def revise_by_idx(self, pred_res, y, revision_idx): candidates = [] @@ -129,22 +126,20 @@ class KBBase(ABC): def _revision(self, revision_num, pred_res, y): new_candidates = [] - revision_idx_list = combinations(list(range(len(pred_res))), revision_num) + revision_idx_list = combinations(range(len(pred_res)), revision_num) for revision_idx in revision_idx_list: candidates = self.revise_by_idx(pred_res, y, revision_idx) - new_candidates += candidates + new_candidates.extend(candidates) return new_candidates def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): candidates = [] for revision_num in range(len(pred_res) + 1): - if revision_num == 0: - if check_equal(self.logic_forward(pred_res), y, self.max_err): - candidates.append(pred_res) - else: - new_candidates = self._revision(revision_num, pred_res, y) - candidates += new_candidates + if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): + candidates.append(pred_res) + elif revision_num > 0: + candidates += self._revision(revision_num, pred_res, y) if len(candidates) > 0: min_revision_num = revision_num break @@ -154,8 +149,7 @@ class KBBase(ABC): for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): if revision_num > max_revision_num: return candidates - new_candidates = self._revision(revision_num, pred_res, y) - candidates += new_candidates + candidates += self._revision(revision_num, pred_res, y) return candidates @lru_cache(maxsize=None)