Browse Source

[MNT] Enhance code

pull/3/head
Tony-HYX 2 years ago
parent
commit
e2f57e7698
1 changed files with 18 additions and 24 deletions
  1. +18
    -24
      abl/reasoning/kb.py

+ 18
- 24
abl/reasoning/kb.py View File

@@ -4,6 +4,7 @@ import numpy as np

from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool
@@ -56,10 +57,8 @@ class KBBase(ABC):
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
if type(Y[0]) in (int, float):
sorted_XY = sorted(list(zip(Y, X)))
X = [x for y, x in sorted_XY]
Y = [y for y, x in sorted_XY]
if Y and isinstance(Y[0], (int, float)):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
return X, Y

@abstractmethod
@@ -100,21 +99,19 @@ class KBBase(ABC):
return all_candidates
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision):
if self.base == {}:
if self.base == {} or len(pred_res) not in self.len_list:
return []
if len(pred_res) not in self.len_list:
return []
all_candidates = self._find_candidate_GKB(pred_res, y)
if len(all_candidates) == 0:
return []
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates
cost_list = hamming_dist(pred_res, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def revise_by_idx(self, pred_res, y, revision_idx):
candidates = []
@@ -129,22 +126,20 @@ class KBBase(ABC):

def _revision(self, revision_num, pred_res, y):
new_candidates = []
revision_idx_list = combinations(list(range(len(pred_res))), revision_num)
revision_idx_list = combinations(range(len(pred_res)), revision_num)

for revision_idx in revision_idx_list:
candidates = self.revise_by_idx(pred_res, y, revision_idx)
new_candidates += candidates
new_candidates.extend(candidates)
return new_candidates

def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
candidates = []
for revision_num in range(len(pred_res) + 1):
if revision_num == 0:
if check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._revision(revision_num, pred_res, y)
candidates += new_candidates
if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res)
elif revision_num > 0:
candidates += self._revision(revision_num, pred_res, y)
if len(candidates) > 0:
min_revision_num = revision_num
break
@@ -154,8 +149,7 @@ class KBBase(ABC):
for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1):
if revision_num > max_revision_num:
return candidates
new_candidates = self._revision(revision_num, pred_res, y)
candidates += new_candidates
candidates += self._revision(revision_num, pred_res, y)
return candidates
@lru_cache(maxsize=None)


Loading…
Cancel
Save