diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 66a134f..cd15dd3 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -25,8 +25,46 @@ import pyswip class KBBase(ABC): - def __init__(self, pseudo_label_list=None): - pass + def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): + self.pseudo_label_list = pseudo_label_list + self.len_list = len_list + self.GKB_flag = GKB_flag + self.max_err = max_err + + if GKB_flag: + self.base = {} + X, Y = self._get_GKB() + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(x) + + # For parallel version of _get_GKB + def _get_XY_list(self, args): + pre_x, post_x_it = args[0], args[1] + XY_list = [] + for post_x in post_x_it: + x = (pre_x,) + post_x + y = self.logic_forward(x) + if y != np.inf: + XY_list.append((x, y)) + return XY_list + + # Parallel _get_GKB + def _get_GKB(self): + X, Y = [], [] + for length in self.len_list: + arg_list = [] + for pre_x in self.pseudo_label_list: + post_x_it = product(self.pseudo_label_list, repeat=length - 1) + arg_list.append((pre_x, post_x_it)) + with Pool(processes=len(arg_list)) as pool: + ret_list = pool.map(self._get_XY_list, arg_list) + for XY_list in ret_list: + if len(XY_list) == 0: + continue + part_X, part_Y = zip(*XY_list) + X.extend(part_X) + Y.extend(part_Y) + return X, Y @abstractmethod def logic_forward(self): @@ -87,53 +125,22 @@ class KBBase(ABC): return candidates, min_address_num, address_num + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) + def __len__(self): - pass + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) class ClsKB(KBBase): - def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None): - super().__init__() - self.GKB_flag = GKB_flag - self.pseudo_label_list = pseudo_label_list - self.len_list = len_list - self.max_err = 0 - - if GKB_flag: - self.base = {} - X, Y = self._get_GKB() - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(x) - - - # For parallel version of _get_GKB - def _get_XY_list(self, args): - pre_x, post_x_it = args[0], args[1] - XY_list = [] - for post_x in post_x_it: - x = (pre_x,) + post_x - y = self.logic_forward(x) - if y != np.inf: - XY_list.append((x, y)) - return XY_list - - # Parallel _get_GKB - def _get_GKB(self): - X, Y = [], [] - for length in self.len_list: - arg_list = [] - for pre_x in self.pseudo_label_list: - post_x_it = product(self.pseudo_label_list, repeat=length - 1) - arg_list.append((pre_x, post_x_it)) - with Pool(processes=len(arg_list)) as pool: - ret_list = pool.map(self._get_XY_list, arg_list) - for XY_list in ret_list: - if len(XY_list) == 0: - continue - part_X, part_Y = zip(*XY_list) - X.extend(part_X) - Y.extend(part_Y) - return X, Y + def __init__(self, pseudo_label_list, len_list, GKB_flag): + super().__init__(pseudo_label_list, len_list, GKB_flag) def logic_forward(self): pass @@ -208,22 +215,10 @@ class ClsKB(KBBase): candidates.append(candidate) return candidates - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - class add_KB(ClsKB): - def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]): - super().__init__(GKB_flag, pseudo_label_list, len_list) + def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): + super().__init__(pseudo_label_list, len_list, GKB_flag) def logic_forward(self, nums): return sum(nums) @@ -231,10 +226,8 @@ class add_KB(ClsKB): class prolog_KB(KBBase): def __init__(self, pseudo_label_list): - super().__init__() - self.pseudo_label_list = pseudo_label_list + super().__init__(pseudo_label_list) self.prolog = pyswip.Prolog() - self.max_err = 0 def logic_forward(self): pass @@ -326,46 +319,7 @@ class HED_prolog_KB(prolog_KB): class RegKB(KBBase): def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): - super().__init__() - self.GKB_flag = GKB_flag - self.pseudo_label_list = pseudo_label_list - self.len_list = len_list - self.max_err = max_err - - if GKB_flag: - self.base = {} - X, Y = self._get_GKB() - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(x) - - # For parallel version of _get_GKB - def _get_XY_list(self, args): - pre_x, post_x_it = args[0], args[1] - XY_list = [] - for post_x in post_x_it: - x = (pre_x,) + post_x - y = self.logic_forward(x) - if y != np.inf: - XY_list.append((x, y)) - return XY_list - - # Parallel _get_GKB - def _get_GKB(self): - X, Y = [], [] - for length in self.len_list: - arg_list = [] - for pre_x in self.pseudo_label_list: - post_x_it = product(self.pseudo_label_list, repeat=length - 1) - arg_list.append((pre_x, post_x_it)) - with Pool(processes=len(arg_list)) as pool: - ret_list = pool.map(self._get_XY_list, arg_list) - for XY_list in ret_list: - if len(XY_list) == 0: - continue - part_X, part_Y = zip(*XY_list) - X.extend(part_X) - Y.extend(part_Y) - return X, Y + super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) def logic_forward(self): pass @@ -461,18 +415,6 @@ class RegKB(KBBase): candidates.append(candidate) return candidates - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - class HWF_KB(RegKB): def __init__(