Browse Source

update parallel versio of get GKB

pull/3/head
Tony-HYX 3 years ago
parent
commit
6a88da1034
1 changed files with 42 additions and 11 deletions
  1. +42
    -11
      abducer/kb.py

+ 42
- 11
abducer/kb.py View File

@@ -22,6 +22,8 @@ from collections import defaultdict
from itertools import product, combinations
from utils.utils import _flatten, _reform_ids, _hamming_dist

from multiprocessing import Pool

import pyswip


@@ -102,18 +104,44 @@ class ClsKB(KBBase):
for address_num in range(max(self.len_list) + 1):
self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat = address_num))
def _get_GKB(self, pseudo_label_list, len_list):
all_X = []
for len in len_list:
all_X += list(product(pseudo_label_list, repeat = len))
X = []
Y = []
for x in all_X:
# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y != np.inf:
X.append(x)
Y.append(y)
XY_list.append((x,y))
return XY_list

# Parallel get GKB
def _get_GKB(self, pseudo_label_list, len_list):
# all_X = []
# for length in len_list:
# all_X += list(product(pseudo_label_list, repeat = length))
# X, Y = [], []
# for x in all_X:
# y = self.logic_forward(x)
# if y != np.inf:
# X.append(x)
# Y.append(y)

X, Y = [], []
for length in len_list:
arg_list = []
for pre_x in pseudo_label_list:
post_x_it = product(pseudo_label_list, repeat = length-1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
return X, Y
def logic_forward(self):
@@ -366,7 +394,10 @@ class RegKB(KBBase):

import time
if __name__ == "__main__":
pass
t1 = time.time()
kb = HWF_KB(True)
t2 = time.time()
print(t2-t1)
# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]


Loading…
Cancel
Save