| @@ -22,6 +22,8 @@ from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from utils.utils import _flatten, _reform_ids, _hamming_dist | |||
| from multiprocessing import Pool | |||
| import pyswip | |||
| @@ -102,18 +104,44 @@ class ClsKB(KBBase): | |||
| for address_num in range(max(self.len_list) + 1): | |||
| self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| def _get_GKB(self, pseudo_label_list, len_list): | |||
| all_X = [] | |||
| for len in len_list: | |||
| all_X += list(product(pseudo_label_list, repeat = len)) | |||
| X = [] | |||
| Y = [] | |||
| for x in all_X: | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| XY_list = [] | |||
| for post_x in post_x_it: | |||
| x = (pre_x,) + post_x | |||
| y = self.logic_forward(x) | |||
| if y != np.inf: | |||
| X.append(x) | |||
| Y.append(y) | |||
| XY_list.append((x,y)) | |||
| return XY_list | |||
| # Parallel get GKB | |||
| def _get_GKB(self, pseudo_label_list, len_list): | |||
| # all_X = [] | |||
| # for length in len_list: | |||
| # all_X += list(product(pseudo_label_list, repeat = length)) | |||
| # X, Y = [], [] | |||
| # for x in all_X: | |||
| # y = self.logic_forward(x) | |||
| # if y != np.inf: | |||
| # X.append(x) | |||
| # Y.append(y) | |||
| X, Y = [], [] | |||
| for length in len_list: | |||
| arg_list = [] | |||
| for pre_x in pseudo_label_list: | |||
| post_x_it = product(pseudo_label_list, repeat = length-1) | |||
| arg_list.append((pre_x, post_x_it)) | |||
| with Pool(processes=len(arg_list)) as pool: | |||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||
| for XY_list in ret_list: | |||
| if len(XY_list) == 0: | |||
| continue | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| return X, Y | |||
| def logic_forward(self): | |||
| @@ -366,7 +394,10 @@ class RegKB(KBBase): | |||
| import time | |||
| if __name__ == "__main__": | |||
| pass | |||
| t1 = time.time() | |||
| kb = HWF_KB(True) | |||
| t2 = time.time() | |||
| print(t2-t1) | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||