Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
650c172d61
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 19 deletions
  1. +8
    -19
      abducer/kb.py

+ 8
- 19
abducer/kb.py View File

@@ -21,7 +21,7 @@ sys.path.append("..")

from collections import defaultdict
from itertools import product, combinations
from utils.utils import flatten, reform_idx, hamming_dist, check_is_equal
from utils.utils import flatten, reform_idx, hamming_dist, check_equal

from multiprocessing import Pool

@@ -56,12 +56,12 @@ class KBBase(ABC):
new_candidates += candidates
return new_candidates

def _abduce_by_abduction(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_is_equal(pred_res, key):
if check_equal(self.logic_forward(pred_res), key):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
@@ -114,19 +114,8 @@ class ClsKB(KBBase):
XY_list.append((x, y))
return XY_list

# Parallel get GKB
# Parallel _get_GKB
def _get_GKB(self):
# all_X = []
# for length in len_list:
# all_X += list(product(self.pseudo_label_list, repeat = length))

# X, Y = [], []
# for x in all_X:
# y = self.logic_forward(x)
# if y != np.inf:
# X.append(x)
# Y.append(y)

X, Y = [], []
for length in self.len_list:
arg_list = []
@@ -148,11 +137,11 @@ class ClsKB(KBBase):

def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
if self.GKB_flag:
return self._abduce_from_GKB(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
else:
return self._abduce_by_abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)

def _abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {} or len(pred_res) not in self.len_list:
return []

@@ -260,7 +249,7 @@ class prolog_KB(KBBase):
pass

def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
return self._abduce_by_abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)

def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []


Loading…
Cancel
Save