| @@ -115,7 +115,7 @@ class add_KB(KBBase): | |||||
| def __len__(self): | def __len__(self): | ||||
| return sum(self._dict_len(v) for v in self.base.values()) | return sum(self._dict_len(v) for v in self.base.values()) | ||||
| class ClsKB(KBBase): | |||||
| class cls_KB(KBBase): | |||||
| def __init__(self, X, Y = None): | def __init__(self, X, Y = None): | ||||
| super().__init__() | super().__init__() | ||||
| self.base = {} | self.base = {} | ||||
| @@ -128,6 +128,9 @@ class ClsKB(KBBase): | |||||
| for x, y in zip(X, Y): | for x, y in zip(X, Y): | ||||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | ||||
| def logic_forward(self): | |||||
| return None | |||||
| def get_candidates(self, key, length = None): | def get_candidates(self, key, length = None): | ||||
| if key is None: | if key is None: | ||||
| @@ -146,7 +149,7 @@ class ClsKB(KBBase): | |||||
| def __len__(self): | def __len__(self): | ||||
| return sum(self._dict_len(v) for v in self.base.values()) | return sum(self._dict_len(v) for v in self.base.values()) | ||||
| class RegKB(KBBase): | |||||
| class reg_KB(KBBase): | |||||
| def __init__(self, X, Y = None): | def __init__(self, X, Y = None): | ||||
| super().__init__() | super().__init__() | ||||
| tmp_dict = {} | tmp_dict = {} | ||||
| @@ -159,7 +162,10 @@ class RegKB(KBBase): | |||||
| X = [x for y, x in data] | X = [x for y, x in data] | ||||
| Y = [y for y, x in data] | Y = [y for y, x in data] | ||||
| self.base[l] = (X, Y) | self.base[l] = (X, Y) | ||||
| def logic_forward(self): | |||||
| return None | |||||
| def get_candidates(self, key, length = None): | def get_candidates(self, key, length = None): | ||||
| if key is None: | if key is None: | ||||
| return self.get_all_candidates() | return self.get_all_candidates() | ||||
| @@ -207,7 +213,7 @@ if __name__ == "__main__": | |||||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | ||||
| Y = [2, 1, 1, 2, 2] | Y = [2, 1, 1, 2, 2] | ||||
| kb = ClsKB(X, Y) | |||||
| kb = cls_KB(X, Y) | |||||
| print(len(kb)) | print(len(kb)) | ||||
| res = kb.get_candidates(2, 5) | res = kb.get_candidates(2, 5) | ||||
| print(res) | print(res) | ||||
| @@ -219,7 +225,7 @@ if __name__ == "__main__": | |||||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | ||||
| Y = [2, 1, 1, 2, 1.5, 1.5] | Y = [2, 1, 1, 2, 1.5, 1.5] | ||||
| kb = RegKB(X, Y) | |||||
| kb = reg_KB(X, Y) | |||||
| print(len(kb)) | print(len(kb)) | ||||
| res = kb.get_candidates(1.6) | res = kb.get_candidates(1.6) | ||||
| print(res) | print(res) | ||||