Browse Source

Modify cache option

pull/3/head
troyyyyy 3 years ago
parent
commit
97edd1a634
1 changed files with 17 additions and 18 deletions
  1. +17
    -18
      abl/abducer/kb.py

+ 17
- 18
abl/abducer/kb.py View File

@@ -17,7 +17,7 @@ import numpy as np

from collections import defaultdict
from itertools import product, combinations
from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -83,7 +83,10 @@ class KBBase(ABC):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
else:
return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address)
if not self.use_cache:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address)
else:
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address)
def _find_candidate_GKB(self, pred_res, key):
if self.max_err == 0:
@@ -148,17 +151,8 @@ class KBBase(ABC):

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address):
if self.use_cache:
return self._abduce_by_search_cache(pred_res, key, max_address_num, require_more_address)
else:
return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address)
@lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address):
return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address)
def _abduce_by_search_no_cache(self, pred_res, key, max_address_num, require_more_address):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []
for address_num in range(len(pred_res) + 1):
@@ -180,7 +174,11 @@ class KBBase(ABC):
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
return candidates
@lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address):
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address)
def _dict_len(self, dic):
if not self.GKB_flag:
return 0
@@ -194,8 +192,8 @@ class KBBase(ABC):
return sum(self._dict_len(v) for v in self.base.values())

class add_KB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
super().__init__(pseudo_label_list, len_list, GKB_flag)
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, use_cache=True):
super().__init__(pseudo_label_list, len_list, GKB_flag, use_cache)

def logic_forward(self, nums):
return sum(nums)
@@ -273,9 +271,10 @@ class HWF_KB(KBBase):
pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],
len_list=[1, 3, 5, 7],
GKB_flag=False,
max_err=1e-3
max_err=1e-3,
use_cache=True
):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:


Loading…
Cancel
Save