Browse Source

Modify lru_cache

pull/3/head
troyyyyy 3 years ago
parent
commit
d60b4f7c9c
1 changed files with 34 additions and 27 deletions
  1. +34
    -27
      abl/abducer/kb.py

+ 34
- 27
abl/abducer/kb.py View File

@@ -17,7 +17,7 @@ import numpy as np

from collections import defaultdict
from itertools import product, combinations
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -25,7 +25,7 @@ from functools import lru_cache
import pyswip

class KBBase(ABC):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, cache_size=128):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, use_cache=True):
# TODO:添加一下类型检查,比如
# if not isinstance(X, (np.ndarray, spmatrix)):
# raise TypeError("X should be numpy array or sparse matrix")
@@ -34,7 +34,7 @@ class KBBase(ABC):
self.len_list = len_list
self.GKB_flag = GKB_flag
self.max_err = max_err
self.cache_size = cache_size
self.use_cache = use_cache

if GKB_flag:
self.base = {}
@@ -147,32 +147,39 @@ class KBBase(ABC):
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address):
@lru_cache(maxsize=self.cache_size)
def _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
break
if address_num >= max_address_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates
if self.use_cache:
return self._abduce_by_search_cache(pred_res, key, max_address_num, require_more_address)
else:
return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address)
@lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address):
return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address)
def _abduce_by_search_no_cache(self, pred_res, key, max_address_num, require_more_address):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
return candidates
return _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address)
if len(candidates) > 0:
min_address_num = address_num
break
if address_num >= max_address_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
return candidates
def _dict_len(self, dic):
if not self.GKB_flag:


Loading…
Cancel
Save