Browse Source

Remove ClsKB and RegKB, add cache_size

pull/3/head
troyyyyy 3 years ago
parent
commit
3a7971dd87
1 changed files with 49 additions and 66 deletions
  1. +49
    -66
      abl/abducer/kb.py

+ 49
- 66
abl/abducer/kb.py View File

@@ -85,9 +85,29 @@ class KBBase(ABC):
else:
return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address)
@abstractmethod
def _find_candidate_GKB(self, pred_res, key):
pass
if self.max_err == 0:
return self.base[len(pred_res)][key]
else:
potential_candidates = self.base[len(pred_res)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, key)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {}:
@@ -126,33 +146,34 @@ class KBBase(ABC):
new_candidates += candidates
return new_candidates

# TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache)
@lru_cache(maxsize=None)
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
candidates.append(pred_res)
else:
@lru_cache(maxsize=self.cache_size)
def _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
break
if address_num >= max_address_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
break
if address_num >= max_address_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
return candidates

return candidates
return _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address)
def _dict_len(self, dic):
if not self.GKB_flag:
return 0
@@ -165,16 +186,7 @@ class KBBase(ABC):
else:
return sum(self._dict_len(v) for v in self.base.values())


class ClsKB(KBBase):
def __init__(self, pseudo_label_list, len_list, GKB_flag):
super().__init__(pseudo_label_list, len_list, GKB_flag)

def _find_candidate_GKB(self, pred_res, key):
return self.base[len(pred_res)][key]


class add_KB(ClsKB):
class add_KB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
super().__init__(pseudo_label_list, len_list, GKB_flag)

@@ -215,9 +227,6 @@ class prolog_KB(KBBase):
key_is_none_flag = key is None or (type(key) == list and key[0] is None)
query_string += ",%s)." % key if not key_is_none_flag else ")."
return query_string

def _find_candidate_GKB(self, pred_res, key):
pass
def address_by_idx(self, pred_res, key, address_idx):
candidates = []
@@ -251,33 +260,7 @@ class HED_prolog_KB(prolog_KB):
return rules


class RegKB(KBBase):
def __init__(self, pseudo_label_list, len_list, GKB_flag, max_err):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)

def _find_candidate_GKB(self, pred_res, key):
potential_candidates = self.base[len(pred_res)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, key)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
return all_candidates

class HWF_KB(RegKB):
class HWF_KB(KBBase):
def __init__(
self,
pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],


Loading…
Cancel
Save