| @@ -17,7 +17,7 @@ import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
| from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
| from multiprocessing import Pool | |||
| @@ -25,7 +25,7 @@ from functools import lru_cache | |||
| import pyswip | |||
| class KBBase(ABC): | |||
| def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, cache_size=128): | |||
| def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, use_cache=True): | |||
| # TODO:添加一下类型检查,比如 | |||
| # if not isinstance(X, (np.ndarray, spmatrix)): | |||
| # raise TypeError("X should be numpy array or sparse matrix") | |||
| @@ -34,7 +34,7 @@ class KBBase(ABC): | |||
| self.len_list = len_list | |||
| self.GKB_flag = GKB_flag | |||
| self.max_err = max_err | |||
| self.cache_size = cache_size | |||
| self.use_cache = use_cache | |||
| if GKB_flag: | |||
| self.base = {} | |||
| @@ -147,32 +147,39 @@ class KBBase(ABC): | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): | |||
| @lru_cache(maxsize=self.cache_size) | |||
| def _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address): | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [] | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates | |||
| if self.use_cache: | |||
| return self._abduce_by_search_cache(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) | |||
| @lru_cache(maxsize=None) | |||
| def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address): | |||
| return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) | |||
| def _abduce_by_search_no_cache(self, pred_res, key, max_address_num, require_more_address): | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates | |||
| return _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address) | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [] | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||