|
|
|
@@ -34,10 +34,6 @@ class KBBase(ABC): |
|
|
|
def logic_forward(self): |
|
|
|
pass |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def valid_candidate(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def _length(self, length): |
|
|
|
if length is None: |
|
|
|
length = list(self.base.keys()) |
|
|
|
@@ -57,6 +53,7 @@ class ClsKB(KBBase): |
|
|
|
self.len_list = len_list |
|
|
|
|
|
|
|
if GKB_flag: |
|
|
|
# self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() |
|
|
|
self.base = {} |
|
|
|
X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) |
|
|
|
for x, y in zip(X, Y): |
|
|
|
@@ -70,14 +67,12 @@ class ClsKB(KBBase): |
|
|
|
X = [] |
|
|
|
Y = [] |
|
|
|
for x in all_X: |
|
|
|
if self.valid_candidate(x): |
|
|
|
y = self.logic_forward(x) |
|
|
|
if y != np.inf: |
|
|
|
X.append(x) |
|
|
|
Y.append(self.logic_forward(x)) |
|
|
|
Y.append(y) |
|
|
|
return X, Y |
|
|
|
|
|
|
|
def valid_candidate(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def logic_forward(self): |
|
|
|
pass |
|
|
|
|
|
|
|
@@ -88,7 +83,7 @@ class ClsKB(KBBase): |
|
|
|
if key is None: |
|
|
|
return self.get_all_candidates() |
|
|
|
|
|
|
|
if (type(length) is int and length not in self.len_list): |
|
|
|
if type(length) is int and length not in self.len_list: |
|
|
|
return [] |
|
|
|
length = self._length(length) |
|
|
|
return sum([self.base[l][key] for l in length], []) |
|
|
|
@@ -117,9 +112,6 @@ class add_KB(ClsKB): |
|
|
|
pseudo_label_list = list(range(10)), \ |
|
|
|
len_list = [2]): |
|
|
|
super().__init__(GKB_flag, pseudo_label_list, len_list) |
|
|
|
|
|
|
|
def valid_candidate(self, x): |
|
|
|
return True |
|
|
|
|
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|