Browse Source

Fix a minor bug

pull/3/head
troyyyyy 3 years ago
parent
commit
e735693610
1 changed files with 1 additions and 16 deletions
  1. +1
    -16
      abl/abducer/kb.py

+ 1
- 16
abl/abducer/kb.py View File

@@ -113,7 +113,6 @@ class KBBase(ABC):
min_address_num = 0
all_candidates_save = []
cost_list_save = []
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return []
@@ -136,7 +135,6 @@ class KBBase(ABC):
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))

if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)
@@ -145,10 +143,8 @@ class KBBase(ABC):
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]

if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)

if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
candidates.append(candidate)
return candidates
@@ -167,15 +163,10 @@ class KBBase(ABC):

@lru_cache(maxsize=100)
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
# if self.abduce_cache:
# candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions)
# if candidates is not None:
# return candidates
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
@@ -183,23 +174,17 @@ class KBBase(ABC):
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

if len(candidates) > 0:
min_address_num = address_num
break

if address_num >= max_address_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
return candidates
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates
# if self.abduce_cache:
# self._set_abduce_cache(pred_res, key, min_address_num, address_num, candidates, multiple_predictions)

return candidates

def _dict_len(self, dic):


Loading…
Cancel
Save