|
|
|
@@ -17,7 +17,7 @@ import numpy as np |
|
|
|
|
|
|
|
from collections import defaultdict |
|
|
|
from itertools import product, combinations |
|
|
|
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal |
|
|
|
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list |
|
|
|
|
|
|
|
from multiprocessing import Pool |
|
|
|
|
|
|
|
@@ -25,22 +25,17 @@ from functools import lru_cache |
|
|
|
import pyswip |
|
|
|
|
|
|
|
class KBBase(ABC): |
|
|
|
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):#, abduce_cache=True): |
|
|
|
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): |
|
|
|
self.pseudo_label_list = pseudo_label_list |
|
|
|
self.len_list = len_list |
|
|
|
self.GKB_flag = GKB_flag |
|
|
|
self.max_err = max_err |
|
|
|
# self.abduce_cache = abduce_cache |
|
|
|
|
|
|
|
if GKB_flag: |
|
|
|
self.base = {} |
|
|
|
X, Y = self._get_GKB() |
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.base.setdefault(len(x), defaultdict(list))[y].append(x) |
|
|
|
|
|
|
|
# if abduce_cache: |
|
|
|
# self.cache_min_address_num = {} |
|
|
|
# self.cache_candidates = {} |
|
|
|
|
|
|
|
# For parallel version of _get_GKB |
|
|
|
def _get_XY_list(self, args): |
|
|
|
@@ -90,7 +85,7 @@ class KBBase(ABC): |
|
|
|
if self.GKB_flag: |
|
|
|
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) |
|
|
|
else: |
|
|
|
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) |
|
|
|
return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address, multiple_predictions) |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def _find_candidate_GKB(self, pred_res, key): |
|
|
|
@@ -137,27 +132,7 @@ class KBBase(ABC): |
|
|
|
idxs = np.where(multiple_cost_list <= address_num)[0] |
|
|
|
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] |
|
|
|
return candidates |
|
|
|
|
|
|
|
# TODO:python也有自带的用装饰器实现的缓存方法,比如functools.lru_cache、cachetools等,后面稍微调研一下和手动缓存的优劣,看看用哪个好 |
|
|
|
# def _get_abduce_cache(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): |
|
|
|
# if multiple_predictions: |
|
|
|
# pred_res = flatten(pred_res) |
|
|
|
# key = tuple(key) |
|
|
|
# if (tuple(pred_res), key) in self.cache_min_address_num: |
|
|
|
# address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) |
|
|
|
# if (tuple(pred_res), key, address_num) in self.cache_candidates: |
|
|
|
# candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] |
|
|
|
# return candidates |
|
|
|
# return None |
|
|
|
|
|
|
|
# def _set_abduce_cache(self, pred_res, key, min_address_num, address_num, candidates, multiple_predictions): |
|
|
|
# if multiple_predictions: |
|
|
|
# pred_res = flatten(pred_res) |
|
|
|
# key = tuple(key) |
|
|
|
# self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num |
|
|
|
# self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): |
|
|
|
candidates = [] |
|
|
|
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) |
|
|
|
@@ -190,13 +165,15 @@ class KBBase(ABC): |
|
|
|
new_candidates += candidates |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
# @lru_cache(maxsize=100) |
|
|
|
@lru_cache(maxsize=100) |
|
|
|
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): |
|
|
|
# if self.abduce_cache: |
|
|
|
# candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions) |
|
|
|
# if candidates is not None: |
|
|
|
# return candidates |
|
|
|
|
|
|
|
pred_res = hashable_to_list(pred_res) |
|
|
|
key = hashable_to_list(key) |
|
|
|
|
|
|
|
candidates = [] |
|
|
|
|
|
|
|
for address_num in range(len(flatten(pred_res)) + 1): |
|
|
|
|