Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
ef0298e91a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 13 deletions
  1. +5
    -13
      abducer/kb.py

+ 5
- 13
abducer/kb.py View File

@@ -34,10 +34,6 @@ class KBBase(ABC):
def logic_forward(self):
pass
@abstractmethod
def valid_candidate(self):
pass
def _length(self, length):
if length is None:
length = list(self.base.keys())
@@ -57,6 +53,7 @@ class ClsKB(KBBase):
self.len_list = len_list
if GKB_flag:
# self.base = np.load('abducer/hwf.npy', allow_pickle=True).item()
self.base = {}
X, Y = self.get_GKB(self.pseudo_label_list, self.len_list)
for x, y in zip(X, Y):
@@ -70,14 +67,12 @@ class ClsKB(KBBase):
X = []
Y = []
for x in all_X:
if self.valid_candidate(x):
y = self.logic_forward(x)
if y != np.inf:
X.append(x)
Y.append(self.logic_forward(x))
Y.append(y)
return X, Y
def valid_candidate(self):
pass
def logic_forward(self):
pass

@@ -88,7 +83,7 @@ class ClsKB(KBBase):
if key is None:
return self.get_all_candidates()
if (type(length) is int and length not in self.len_list):
if type(length) is int and length not in self.len_list:
return []
length = self._length(length)
return sum([self.base[l][key] for l in length], [])
@@ -117,9 +112,6 @@ class add_KB(ClsKB):
pseudo_label_list = list(range(10)), \
len_list = [2]):
super().__init__(GKB_flag, pseudo_label_list, len_list)
def valid_candidate(self, x):
return True
def logic_forward(self, nums):
return sum(nums)


Loading…
Cancel
Save