From ef0298e91a44c1a9502eb50407b2c9d8987f805f Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 22 Nov 2022 12:34:37 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 8c3f49a..e987d15 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -34,10 +34,6 @@ class KBBase(ABC): def logic_forward(self): pass - @abstractmethod - def valid_candidate(self): - pass - def _length(self, length): if length is None: length = list(self.base.keys()) @@ -57,6 +53,7 @@ class ClsKB(KBBase): self.len_list = len_list if GKB_flag: + # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() self.base = {} X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) for x, y in zip(X, Y): @@ -70,14 +67,12 @@ class ClsKB(KBBase): X = [] Y = [] for x in all_X: - if self.valid_candidate(x): + y = self.logic_forward(x) + if y != np.inf: X.append(x) - Y.append(self.logic_forward(x)) + Y.append(y) return X, Y - def valid_candidate(self): - pass - def logic_forward(self): pass @@ -88,7 +83,7 @@ class ClsKB(KBBase): if key is None: return self.get_all_candidates() - if (type(length) is int and length not in self.len_list): + if type(length) is int and length not in self.len_list: return [] length = self._length(length) return sum([self.base[l][key] for l in length], []) @@ -117,9 +112,6 @@ class add_KB(ClsKB): pseudo_label_list = list(range(10)), \ len_list = [2]): super().__init__(GKB_flag, pseudo_label_list, len_list) - - def valid_candidate(self, x): - return True def logic_forward(self, nums): return sum(nums)