| @@ -4,6 +4,7 @@ import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
| from multiprocessing import Pool | |||
| @@ -56,10 +57,8 @@ class KBBase(ABC): | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| if type(Y[0]) in (int, float): | |||
| sorted_XY = sorted(list(zip(Y, X))) | |||
| X = [x for y, x in sorted_XY] | |||
| Y = [y for y, x in sorted_XY] | |||
| if Y and isinstance(Y[0], (int, float)): | |||
| X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||
| return X, Y | |||
| @abstractmethod | |||
| @@ -100,21 +99,19 @@ class KBBase(ABC): | |||
| return all_candidates | |||
| def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): | |||
| if self.base == {}: | |||
| if self.base == {} or len(pred_res) not in self.len_list: | |||
| return [] | |||
| if len(pred_res) not in self.len_list: | |||
| return [] | |||
| all_candidates = self._find_candidate_GKB(pred_res, y) | |||
| if len(all_candidates) == 0: | |||
| return [] | |||
| else: | |||
| cost_list = hamming_dist(pred_res, all_candidates) | |||
| min_revision_num = np.min(cost_list) | |||
| revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||
| idxs = np.where(cost_list <= revision_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates | |||
| cost_list = hamming_dist(pred_res, all_candidates) | |||
| min_revision_num = np.min(cost_list) | |||
| revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||
| idxs = np.where(cost_list <= revision_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates | |||
| def revise_by_idx(self, pred_res, y, revision_idx): | |||
| candidates = [] | |||
| @@ -129,22 +126,20 @@ class KBBase(ABC): | |||
| def _revision(self, revision_num, pred_res, y): | |||
| new_candidates = [] | |||
| revision_idx_list = combinations(list(range(len(pred_res))), revision_num) | |||
| revision_idx_list = combinations(range(len(pred_res)), revision_num) | |||
| for revision_idx in revision_idx_list: | |||
| candidates = self.revise_by_idx(pred_res, y, revision_idx) | |||
| new_candidates += candidates | |||
| new_candidates.extend(candidates) | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): | |||
| candidates = [] | |||
| for revision_num in range(len(pred_res) + 1): | |||
| if revision_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), y, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._revision(revision_num, pred_res, y) | |||
| candidates += new_candidates | |||
| if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): | |||
| candidates.append(pred_res) | |||
| elif revision_num > 0: | |||
| candidates += self._revision(revision_num, pred_res, y) | |||
| if len(candidates) > 0: | |||
| min_revision_num = revision_num | |||
| break | |||
| @@ -154,8 +149,7 @@ class KBBase(ABC): | |||
| for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): | |||
| if revision_num > max_revision_num: | |||
| return candidates | |||
| new_candidates = self._revision(revision_num, pred_res, y) | |||
| candidates += new_candidates | |||
| candidates += self._revision(revision_num, pred_res, y) | |||
| return candidates | |||
| @lru_cache(maxsize=None) | |||