| @@ -25,8 +25,46 @@ import pyswip | |||
| class KBBase(ABC): | |||
| def __init__(self, pseudo_label_list=None): | |||
| pass | |||
| def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.len_list = len_list | |||
| self.GKB_flag = GKB_flag | |||
| self.max_err = max_err | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| XY_list = [] | |||
| for post_x in post_x_it: | |||
| x = (pre_x,) + post_x | |||
| y = self.logic_forward(x) | |||
| if y != np.inf: | |||
| XY_list.append((x, y)) | |||
| return XY_list | |||
| # Parallel _get_GKB | |||
| def _get_GKB(self): | |||
| X, Y = [], [] | |||
| for length in self.len_list: | |||
| arg_list = [] | |||
| for pre_x in self.pseudo_label_list: | |||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
| arg_list.append((pre_x, post_x_it)) | |||
| with Pool(processes=len(arg_list)) as pool: | |||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||
| for XY_list in ret_list: | |||
| if len(XY_list) == 0: | |||
| continue | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| return X, Y | |||
| @abstractmethod | |||
| def logic_forward(self): | |||
| @@ -87,53 +125,22 @@ class KBBase(ABC): | |||
| return candidates, min_address_num, address_num | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| pass | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class ClsKB(KBBase): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None): | |||
| super().__init__() | |||
| self.GKB_flag = GKB_flag | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.len_list = len_list | |||
| self.max_err = 0 | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| XY_list = [] | |||
| for post_x in post_x_it: | |||
| x = (pre_x,) + post_x | |||
| y = self.logic_forward(x) | |||
| if y != np.inf: | |||
| XY_list.append((x, y)) | |||
| return XY_list | |||
| # Parallel _get_GKB | |||
| def _get_GKB(self): | |||
| X, Y = [], [] | |||
| for length in self.len_list: | |||
| arg_list = [] | |||
| for pre_x in self.pseudo_label_list: | |||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
| arg_list.append((pre_x, post_x_it)) | |||
| with Pool(processes=len(arg_list)) as pool: | |||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||
| for XY_list in ret_list: | |||
| if len(XY_list) == 0: | |||
| continue | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| return X, Y | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def logic_forward(self): | |||
| pass | |||
| @@ -208,22 +215,10 @@ class ClsKB(KBBase): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class add_KB(ClsKB): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]): | |||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| @@ -231,10 +226,8 @@ class add_KB(ClsKB): | |||
| class prolog_KB(KBBase): | |||
| def __init__(self, pseudo_label_list): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog = pyswip.Prolog() | |||
| self.max_err = 0 | |||
| def logic_forward(self): | |||
| pass | |||
| @@ -326,46 +319,7 @@ class HED_prolog_KB(prolog_KB): | |||
| class RegKB(KBBase): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): | |||
| super().__init__() | |||
| self.GKB_flag = GKB_flag | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.len_list = len_list | |||
| self.max_err = max_err | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| XY_list = [] | |||
| for post_x in post_x_it: | |||
| x = (pre_x,) + post_x | |||
| y = self.logic_forward(x) | |||
| if y != np.inf: | |||
| XY_list.append((x, y)) | |||
| return XY_list | |||
| # Parallel _get_GKB | |||
| def _get_GKB(self): | |||
| X, Y = [], [] | |||
| for length in self.len_list: | |||
| arg_list = [] | |||
| for pre_x in self.pseudo_label_list: | |||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
| arg_list.append((pre_x, post_x_it)) | |||
| with Pool(processes=len(arg_list)) as pool: | |||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||
| for XY_list in ret_list: | |||
| if len(XY_list) == 0: | |||
| continue | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| return X, Y | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||
| def logic_forward(self): | |||
| pass | |||
| @@ -461,18 +415,6 @@ class RegKB(KBBase): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class HWF_KB(RegKB): | |||
| def __init__( | |||