Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
930b549d4e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 56 additions and 114 deletions
  1. +56
    -114
      abl/abducer/kb.py

+ 56
- 114
abl/abducer/kb.py View File

@@ -25,8 +25,46 @@ import pyswip


class KBBase(ABC):
def __init__(self, pseudo_label_list=None):
pass
def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0):
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.GKB_flag = GKB_flag
self.max_err = max_err

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)

# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y != np.inf:
XY_list.append((x, y))
return XY_list

# Parallel _get_GKB
def _get_GKB(self):
X, Y = [], []
for length in self.len_list:
arg_list = []
for pre_x in self.pseudo_label_list:
post_x_it = product(self.pseudo_label_list, repeat=length - 1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
return X, Y

@abstractmethod
def logic_forward(self):
@@ -87,53 +125,22 @@ class KBBase(ABC):

return candidates, min_address_num, address_num

def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
pass
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())


class ClsKB(KBBase):
def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
super().__init__()
self.GKB_flag = GKB_flag
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.max_err = 0

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)


# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y != np.inf:
XY_list.append((x, y))
return XY_list

# Parallel _get_GKB
def _get_GKB(self):
X, Y = [], []
for length in self.len_list:
arg_list = []
for pre_x in self.pseudo_label_list:
post_x_it = product(self.pseudo_label_list, repeat=length - 1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
return X, Y
def __init__(self, pseudo_label_list, len_list, GKB_flag):
super().__init__(pseudo_label_list, len_list, GKB_flag)

def logic_forward(self):
pass
@@ -208,22 +215,10 @@ class ClsKB(KBBase):
candidates.append(candidate)
return candidates

def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())


class add_KB(ClsKB):
def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]):
super().__init__(GKB_flag, pseudo_label_list, len_list)
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
super().__init__(pseudo_label_list, len_list, GKB_flag)

def logic_forward(self, nums):
return sum(nums)
@@ -231,10 +226,8 @@ class add_KB(ClsKB):

class prolog_KB(KBBase):
def __init__(self, pseudo_label_list):
super().__init__()
self.pseudo_label_list = pseudo_label_list
super().__init__(pseudo_label_list)
self.prolog = pyswip.Prolog()
self.max_err = 0

def logic_forward(self):
pass
@@ -326,46 +319,7 @@ class HED_prolog_KB(prolog_KB):

class RegKB(KBBase):
def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3):
super().__init__()
self.GKB_flag = GKB_flag
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.max_err = max_err

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)

# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y != np.inf:
XY_list.append((x, y))
return XY_list

# Parallel _get_GKB
def _get_GKB(self):
X, Y = [], []
for length in self.len_list:
arg_list = []
for pre_x in self.pseudo_label_list:
post_x_it = product(self.pseudo_label_list, repeat=length - 1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
return X, Y
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)

def logic_forward(self):
pass
@@ -461,18 +415,6 @@ class RegKB(KBBase):
candidates.append(candidate)
return candidates
def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())

class HWF_KB(RegKB):
def __init__(


Loading…
Cancel
Save