From 8e8aa76735eeee73cfc2a026ccddf008f99ea531 Mon Sep 17 00:00:00 2001
From: troyyyyy
Date: Thu, 9 Mar 2023 10:54:18 +0800
Subject: [PATCH] Add cache in abduce_by_search

---
 abl/abducer/kb.py  | 39 ++++++++-------------------------------
 abl/utils/utils.py | 16 ++++++++++++++--
 2 files changed, 22 insertions(+), 33 deletions(-)

diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py
index 370b790..6cecba4 100644
--- a/abl/abducer/kb.py
+++ b/abl/abducer/kb.py
@@ -17,7 +17,7 @@ import numpy as np
 from collections import defaultdict
 from itertools import product, combinations
 
-from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal
+from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
 from multiprocessing import Pool
 
 
@@ -25,22 +25,17 @@ from functools import lru_cache
 import pyswip
 
 class KBBase(ABC):
-    def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):#, abduce_cache=True):
+    def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):
         self.pseudo_label_list = pseudo_label_list
         self.len_list = len_list
         self.GKB_flag = GKB_flag
         self.max_err = max_err
-        # self.abduce_cache = abduce_cache
 
         if GKB_flag:
             self.base = {}
             X, Y = self._get_GKB()
             for x, y in zip(X, Y):
                 self.base.setdefault(len(x), defaultdict(list))[y].append(x)
-
-        # if abduce_cache:
-        #     self.cache_min_address_num = {}
-        #     self.cache_candidates = {}
 
     # For parallel version of _get_GKB
     def _get_XY_list(self, args):
@@ -90,7 +85,7 @@ class KBBase(ABC):
         if self.GKB_flag:
             return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
         else:
-            return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
+            return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address, multiple_predictions)
 
     @abstractmethod
     def _find_candidate_GKB(self, pred_res, key):
@@ -137,27 +132,7 @@ class KBBase(ABC):
         idxs = np.where(multiple_cost_list <= address_num)[0]
         candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
         return candidates
-
-    # TODO: Python also provides decorator-based caching (e.g. functools.lru_cache, cachetools); look into how they compare with manual caching later and decide which to use
-    # def _get_abduce_cache(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
-    #     if multiple_predictions:
-    #         pred_res = flatten(pred_res)
-    #         key = tuple(key)
-    #     if (tuple(pred_res), key) in self.cache_min_address_num:
-    #         address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
-    #         if (tuple(pred_res), key, address_num) in self.cache_candidates:
-    #             candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
-    #             return candidates
-    #     return None
-
-    # def _set_abduce_cache(self, pred_res, key, min_address_num, address_num, candidates, multiple_predictions):
-    #     if multiple_predictions:
-    #         pred_res = flatten(pred_res)
-    #         key = tuple(key)
-    #     self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
-    #     self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
-
-
+
     def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
         candidates = []
         abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
@@ -190,13 +165,15 @@ class KBBase(ABC):
             new_candidates += candidates
         return new_candidates
 
-    # @lru_cache(maxsize=100)
+    @lru_cache(maxsize=100)
    def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address,
                           multiple_predictions):
         # if self.abduce_cache:
         #     candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions)
         #     if candidates is not None:
         #         return candidates
-
+        pred_res = hashable_to_list(pred_res)
+        key = hashable_to_list(key)
+
         candidates = []
         for address_num in range(len(flatten(pred_res)) + 1):
diff --git a/abl/utils/utils.py b/abl/utils/utils.py
index d5209a6..90376aa 100644
--- a/abl/utils/utils.py
+++ b/abl/utils/utils.py
@@ -35,7 +35,6 @@ def confidence_dist(A, B):
     cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
     return 1 - np.prod(A[rows, cols, B], axis = 1)
 
-
 def block_sample(X, Z, Y, sample_num, epoch_idx):
     part_num = len(X) // sample_num
     if part_num == 0:
@@ -48,7 +47,6 @@ def block_sample(X, Z, Y, sample_num, epoch_idx):
 
     return X, Z, Y
 
-
 def gen_mappings(chars, symbs):
     n_char = len(chars)
     n_symbs = len(symbs)
@@ -86,3 +84,17 @@ def check_equal(a, b, max_err=0):
     else:
         return a == b
 
+
+def to_hashable(l):
+    if type(l) is not list:
+        return l
+    if type(l[0]) is not list:
+        return tuple(l)
+    return tuple(tuple(sublist) for sublist in l)
+
+def hashable_to_list(t):
+    if type(t) is not tuple:
+        return t
+    if type(t[0]) is not tuple:
+        return list(t)
+    return [list(subtuple) for subtuple in t]
\ No newline at end of file
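
Why the to_hashable / hashable_to_list round-trip is needed: functools.lru_cache builds its cache key by hashing the call arguments, and the raw pred_res / key values are (possibly nested) lists, which are unhashable. Below is a minimal, self-contained sketch of the pattern this patch applies; ToySearcher, its trimmed-down signatures, and the abduce entry point are illustrative stand-ins, not code from this repository — only to_hashable, hashable_to_list, and the lru_cache usage mirror the diff above.

# Illustrative sketch only (ToySearcher is not repository code): lru_cache needs
# hashable arguments, so list inputs are converted to (nested) tuples before the
# cached call and converted back to lists inside it.
from functools import lru_cache


def to_hashable(l):
    # Mirrors the helper added to abl/utils/utils.py: lists become tuples,
    # with one level of nested lists supported; non-lists pass through.
    if type(l) is not list:
        return l
    if type(l[0]) is not list:
        return tuple(l)
    return tuple(tuple(sublist) for sublist in l)


def hashable_to_list(t):
    # Inverse conversion: tuples back to lists for the search code.
    if type(t) is not tuple:
        return t
    if type(t[0]) is not tuple:
        return list(t)
    return [list(subtuple) for subtuple in t]


class ToySearcher:
    @lru_cache(maxsize=100)
    def _abduce_by_search(self, pred_res, key, max_address_num):
        # Arguments arrive as tuples (hashable), so the whole call is memoized;
        # convert back to lists before running the (here trivial) "search".
        pred_res = hashable_to_list(pred_res)
        key = hashable_to_list(key)  # unused here, converted only to mirror the patch
        print("search executed for", pred_res)
        return pred_res[:max_address_num]

    def abduce(self, pred_res, key, max_address_num):
        # Entry point converts list arguments up front, matching the change
        # made to the _abduce_by_search call site in the patch.
        return self._abduce_by_search(to_hashable(pred_res), to_hashable(key),
                                      max_address_num)


if __name__ == "__main__":
    searcher = ToySearcher()
    searcher.abduce([1, 0, 2], [3], 2)  # executes the search, prints once
    searcher.abduce([1, 0, 2], [3], 2)  # identical arguments: served from cache

One caveat worth keeping in mind with this approach: lru_cache on a method includes self in the cache key and keeps a reference to the instance, and repeated hits return the same cached object, so callers should avoid mutating the returned candidate list in place.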