Browse Source

Add cache in abduce_by_search

pull/3/head
troyyyyy 2 years ago
parent
commit
8e8aa76735
2 changed files with 22 additions and 33 deletions
  1. +8
    -31
      abl/abducer/kb.py
  2. +14
    -2
      abl/utils/utils.py

+ 8
- 31
abl/abducer/kb.py View File

@@ -17,7 +17,7 @@ import numpy as np

from collections import defaultdict
from itertools import product, combinations
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -25,22 +25,17 @@ from functools import lru_cache
import pyswip

class KBBase(ABC):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):#, abduce_cache=True):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.GKB_flag = GKB_flag
self.max_err = max_err
# self.abduce_cache = abduce_cache

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
# if abduce_cache:
# self.cache_min_address_num = {}
# self.cache_candidates = {}

# For parallel version of _get_GKB
def _get_XY_list(self, args):
@@ -90,7 +85,7 @@ class KBBase(ABC):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
else:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address, multiple_predictions)
@abstractmethod
def _find_candidate_GKB(self, pred_res, key):
@@ -137,27 +132,7 @@ class KBBase(ABC):
idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates
# TODO: Python also offers decorator-based caching (e.g. functools.lru_cache, cachetools); investigate how these compare with the manual cache below and decide which one to use
# def _get_abduce_cache(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
# if multiple_predictions:
# pred_res = flatten(pred_res)
# key = tuple(key)
# if (tuple(pred_res), key) in self.cache_min_address_num:
# address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
# if (tuple(pred_res), key, address_num) in self.cache_candidates:
# candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
# return candidates
# return None

# def _set_abduce_cache(self, pred_res, key, min_address_num, address_num, candidates, multiple_predictions):
# if multiple_predictions:
# pred_res = flatten(pred_res)
# key = tuple(key)
# self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
# self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates

def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
@@ -190,13 +165,15 @@ class KBBase(ABC):
new_candidates += candidates
return new_candidates

# @lru_cache(maxsize=100)
@lru_cache(maxsize=100)
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
# if self.abduce_cache:
# candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions)
# if candidates is not None:
# return candidates

pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):


+ 14
- 2
abl/utils/utils.py View File

@@ -35,7 +35,6 @@ def confidence_dist(A, B):
cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
return 1 - np.prod(A[rows, cols, B], axis = 1)


def block_sample(X, Z, Y, sample_num, epoch_idx):
part_num = len(X) // sample_num
if part_num == 0:
@@ -48,7 +47,6 @@ def block_sample(X, Z, Y, sample_num, epoch_idx):

return X, Z, Y


def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
@@ -86,3 +84,17 @@ def check_equal(a, b, max_err=0):
else:
return a == b

def to_hashable(l):
    """Convert a list, or a list of lists, into the equivalent tuple form.

    The result can be used where a hashable argument is required, for
    example as an argument to an ``lru_cache``-decorated function.

    Args:
        l: The value to convert. Anything that is not a list is returned
            unchanged.

    Returns:
        ``()`` for an empty list, a flat tuple when the first element is not
        a list, or a tuple of tuples when the first element is a list.
        Non-list inputs are returned as-is.
    """
    if not isinstance(l, list):
        return l
    # An empty list has no first element to inspect; its tuple form is ().
    if not l:
        return ()
    # Only the first element is inspected to decide whether the list is nested.
    if not isinstance(l[0], list):
        return tuple(l)
    return tuple(tuple(sublist) for sublist in l)

def hashable_to_list(t):
    """Convert a tuple, or a tuple of tuples, into the equivalent list form.

    Args:
        t: The value to convert. Anything whose type is not exactly
            ``tuple`` is returned unchanged.

    Returns:
        ``[]`` for an empty tuple, a flat list when the first element is not
        a tuple, or a list of lists when the first element is a tuple.
        Other inputs are returned as-is.
    """
    # Exact type check on purpose: tuple subclasses pass through untouched.
    if type(t) is not tuple:
        return t
    # An empty tuple has no first element to inspect; its list form is [].
    if not t:
        return []
    # Only the first element is inspected to decide whether the tuple is nested.
    if type(t[0]) is not tuple:
        return list(t)
    return [list(subtuple) for subtuple in t]

Loading…
Cancel
Save