| @@ -25,8 +25,46 @@ import pyswip | |||||
| class KBBase(ABC): | class KBBase(ABC): | ||||
| def __init__(self, pseudo_label_list=None): | |||||
| pass | |||||
| def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): | |||||
| self.pseudo_label_list = pseudo_label_list | |||||
| self.len_list = len_list | |||||
| self.GKB_flag = GKB_flag | |||||
| self.max_err = max_err | |||||
| if GKB_flag: | |||||
| self.base = {} | |||||
| X, Y = self._get_GKB() | |||||
| for x, y in zip(X, Y): | |||||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
| # For parallel version of _get_GKB | |||||
| def _get_XY_list(self, args): | |||||
| pre_x, post_x_it = args[0], args[1] | |||||
| XY_list = [] | |||||
| for post_x in post_x_it: | |||||
| x = (pre_x,) + post_x | |||||
| y = self.logic_forward(x) | |||||
| if y != np.inf: | |||||
| XY_list.append((x, y)) | |||||
| return XY_list | |||||
| # Parallel _get_GKB | |||||
| def _get_GKB(self): | |||||
| X, Y = [], [] | |||||
| for length in self.len_list: | |||||
| arg_list = [] | |||||
| for pre_x in self.pseudo_label_list: | |||||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||||
| arg_list.append((pre_x, post_x_it)) | |||||
| with Pool(processes=len(arg_list)) as pool: | |||||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||||
| for XY_list in ret_list: | |||||
| if len(XY_list) == 0: | |||||
| continue | |||||
| part_X, part_Y = zip(*XY_list) | |||||
| X.extend(part_X) | |||||
| Y.extend(part_Y) | |||||
| return X, Y | |||||
| @abstractmethod | @abstractmethod | ||||
| def logic_forward(self): | def logic_forward(self): | ||||
| @@ -87,53 +125,22 @@ class KBBase(ABC): | |||||
| return candidates, min_address_num, address_num | return candidates, min_address_num, address_num | ||||
| def _dict_len(self, dic): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(len(c) for c in dic.values()) | |||||
| def __len__(self): | def __len__(self): | ||||
| pass | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(self._dict_len(v) for v in self.base.values()) | |||||
| class ClsKB(KBBase): | class ClsKB(KBBase): | ||||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None): | |||||
| super().__init__() | |||||
| self.GKB_flag = GKB_flag | |||||
| self.pseudo_label_list = pseudo_label_list | |||||
| self.len_list = len_list | |||||
| self.max_err = 0 | |||||
| if GKB_flag: | |||||
| self.base = {} | |||||
| X, Y = self._get_GKB() | |||||
| for x, y in zip(X, Y): | |||||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
| # For parallel version of _get_GKB | |||||
| def _get_XY_list(self, args): | |||||
| pre_x, post_x_it = args[0], args[1] | |||||
| XY_list = [] | |||||
| for post_x in post_x_it: | |||||
| x = (pre_x,) + post_x | |||||
| y = self.logic_forward(x) | |||||
| if y != np.inf: | |||||
| XY_list.append((x, y)) | |||||
| return XY_list | |||||
| # Parallel _get_GKB | |||||
| def _get_GKB(self): | |||||
| X, Y = [], [] | |||||
| for length in self.len_list: | |||||
| arg_list = [] | |||||
| for pre_x in self.pseudo_label_list: | |||||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||||
| arg_list.append((pre_x, post_x_it)) | |||||
| with Pool(processes=len(arg_list)) as pool: | |||||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||||
| for XY_list in ret_list: | |||||
| if len(XY_list) == 0: | |||||
| continue | |||||
| part_X, part_Y = zip(*XY_list) | |||||
| X.extend(part_X) | |||||
| Y.extend(part_Y) | |||||
| return X, Y | |||||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| @@ -208,22 +215,10 @@ class ClsKB(KBBase): | |||||
| candidates.append(candidate) | candidates.append(candidate) | ||||
| return candidates | return candidates | ||||
| def _dict_len(self, dic): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(len(c) for c in dic.values()) | |||||
| def __len__(self): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(self._dict_len(v) for v in self.base.values()) | |||||
| class add_KB(ClsKB): | class add_KB(ClsKB): | ||||
| def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]): | |||||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | |||||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| @@ -231,10 +226,8 @@ class add_KB(ClsKB): | |||||
| class prolog_KB(KBBase): | class prolog_KB(KBBase): | ||||
| def __init__(self, pseudo_label_list): | def __init__(self, pseudo_label_list): | ||||
| super().__init__() | |||||
| self.pseudo_label_list = pseudo_label_list | |||||
| super().__init__(pseudo_label_list) | |||||
| self.prolog = pyswip.Prolog() | self.prolog = pyswip.Prolog() | ||||
| self.max_err = 0 | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| @@ -326,46 +319,7 @@ class HED_prolog_KB(prolog_KB): | |||||
| class RegKB(KBBase): | class RegKB(KBBase): | ||||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): | def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): | ||||
| super().__init__() | |||||
| self.GKB_flag = GKB_flag | |||||
| self.pseudo_label_list = pseudo_label_list | |||||
| self.len_list = len_list | |||||
| self.max_err = max_err | |||||
| if GKB_flag: | |||||
| self.base = {} | |||||
| X, Y = self._get_GKB() | |||||
| for x, y in zip(X, Y): | |||||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
| # For parallel version of _get_GKB | |||||
| def _get_XY_list(self, args): | |||||
| pre_x, post_x_it = args[0], args[1] | |||||
| XY_list = [] | |||||
| for post_x in post_x_it: | |||||
| x = (pre_x,) + post_x | |||||
| y = self.logic_forward(x) | |||||
| if y != np.inf: | |||||
| XY_list.append((x, y)) | |||||
| return XY_list | |||||
| # Parallel _get_GKB | |||||
| def _get_GKB(self): | |||||
| X, Y = [], [] | |||||
| for length in self.len_list: | |||||
| arg_list = [] | |||||
| for pre_x in self.pseudo_label_list: | |||||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||||
| arg_list.append((pre_x, post_x_it)) | |||||
| with Pool(processes=len(arg_list)) as pool: | |||||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||||
| for XY_list in ret_list: | |||||
| if len(XY_list) == 0: | |||||
| continue | |||||
| part_X, part_Y = zip(*XY_list) | |||||
| X.extend(part_X) | |||||
| Y.extend(part_Y) | |||||
| return X, Y | |||||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| @@ -461,18 +415,6 @@ class RegKB(KBBase): | |||||
| candidates.append(candidate) | candidates.append(candidate) | ||||
| return candidates | return candidates | ||||
| def _dict_len(self, dic): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(len(c) for c in dic.values()) | |||||
| def __len__(self): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(self._dict_len(v) for v in self.base.values()) | |||||
| class HWF_KB(RegKB): | class HWF_KB(RegKB): | ||||
| def __init__( | def __init__( | ||||