# Post-patch code for the regions of abducer/kb.py changed by this diff,
# restored to conventional formatting. The four functions below are methods
# of ``ClsKB`` (its ``class`` header lies outside this excerpt), which is why
# they take ``self``. The module's project/third-party imports (utils.utils,
# pyswip) are untouched by the patch and are not repeated here.
import os
import time
from itertools import product
from multiprocessing import Pool

import numpy as np


def _collect_valid(self, candidates):
    """Return ``(x, y)`` for every candidate that ``logic_forward`` accepts.

    Args:
        candidates: Iterable of pseudo-label tuples to evaluate.

    Returns:
        List of ``(x, y)`` pairs with ``y = self.logic_forward(x)``; candidates
        for which ``logic_forward`` returns ``np.inf`` are dropped.
    """
    XY_list = []
    for x in candidates:
        y = self.logic_forward(x)
        if y != np.inf:
            XY_list.append((x, y))
    return XY_list


def _get_XY_list(self, args):
    """Evaluate every candidate made of one leading label plus a given suffix.

    Args:
        args: Pair ``(pre_x, post_x_it)``; ``pre_x`` is the first pseudo-label
            and ``post_x_it`` is an iterable of suffix tuples.

    Returns:
        List of ``(x, y)`` pairs, where ``x = (pre_x,) + suffix``, for the
        candidates that ``logic_forward`` does not map to ``np.inf``.
    """
    pre_x, post_x_it = args
    return self._collect_valid((pre_x,) + post_x for post_x in post_x_it)


def _get_XY_for_prefix(self, args):
    """Pool worker: enumerate and evaluate all candidates sharing a prefix.

    The suffix product is built here, inside the worker, so that only plain
    picklable data crosses the process boundary. Pickling live ``itertools``
    iterators is deprecated since Python 3.12.

    Args:
        args: Triple ``(prefix, pseudo_label_list, suffix_len)``; ``prefix`` is
            a tuple and ``suffix_len`` the number of labels still to append.

    Returns:
        List of ``(x, y)`` pairs, in ``itertools.product`` order.
    """
    prefix, pseudo_label_list, suffix_len = args
    suffixes = product(pseudo_label_list, repeat=suffix_len)
    return self._collect_valid(prefix + suffix for suffix in suffixes)


def _get_GKB(self, pseudo_label_list, len_list, processes=None):
    """Build the ground knowledge base, optionally in parallel.

    Enumerates every tuple over ``pseudo_label_list`` for each length in
    ``len_list`` and keeps those that ``logic_forward`` does not map to
    ``np.inf``.

    Args:
        pseudo_label_list: Candidate pseudo-labels.
        len_list: Tuple lengths to enumerate. A length of 0 contributes the
            empty tuple, as ``itertools.product(..., repeat=0)`` does.
        processes: Upper bound on worker processes. ``None`` uses
            ``os.cpu_count()``. The work runs in the current process when only
            one worker would be used.

    Returns:
        ``(X, Y)``: two parallel lists with the accepted tuples and their
        ``logic_forward`` results, ordered by ``len_list`` and then by
        ``itertools.product`` order.

    Note:
        When a pool is used, ``pool.map`` pickles the bound worker method, so
        ``self`` and the labels must be picklable.
    """
    labels = list(pseudo_label_list)

    # One task per (length, first label); length 0 is a single task with an
    # empty prefix so that the empty tuple is still evaluated.
    tasks = []
    for length in len_list:
        if length == 0:
            tasks.append(((), labels, 0))
        else:
            tasks.extend(((pre_x,), labels, length - 1) for pre_x in labels)

    if not tasks:
        return [], []

    if processes is None:
        processes = os.cpu_count() or 1
    # Never start more workers than there are tasks, and never fewer than one.
    n_workers = max(1, min(processes, len(tasks)))

    if n_workers == 1:
        chunks = [self._get_XY_for_prefix(task) for task in tasks]
    else:
        # A single pool serves all lengths; map() returns results in task order.
        with Pool(processes=n_workers) as pool:
            chunks = pool.map(self._get_XY_for_prefix, tasks)

    X, Y = [], []
    for chunk in chunks:
        for x, y in chunk:
            X.append(x)
            Y.append(y)
    return X, Y


if __name__ == "__main__":
    # Rough wall-clock timing of constructing the knowledge base. Keeping this
    # behind the __main__ guard also keeps it out of pool worker imports.
    start = time.perf_counter()
    kb = HWF_KB(True)
    print(time.perf_counter() - start)

    # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]