| @@ -21,7 +21,7 @@ sys.path.append("..") | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from utils.utils import flatten, reform_idx, hamming_dist, check_is_equal | |||
| from utils.utils import flatten, reform_idx, hamming_dist, check_equal | |||
| from multiprocessing import Pool | |||
| @@ -56,12 +56,12 @@ class KBBase(ABC): | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def _abduce_by_abduction(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_is_equal(pred_res, key): | |||
| if check_equal(self.logic_forward(pred_res), key): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| @@ -114,19 +114,8 @@ class ClsKB(KBBase): | |||
| XY_list.append((x, y)) | |||
| return XY_list | |||
| # Parallel get GKB | |||
| # Parallel _get_GKB | |||
| def _get_GKB(self): | |||
| # all_X = [] | |||
| # for length in len_list: | |||
| # all_X += list(product(self.pseudo_label_list, repeat = length)) | |||
| # X, Y = [], [] | |||
| # for x in all_X: | |||
| # y = self.logic_forward(x) | |||
| # if y != np.inf: | |||
| # X.append(x) | |||
| # Y.append(y) | |||
| X, Y = [], [] | |||
| for length in self.len_list: | |||
| arg_list = [] | |||
| @@ -148,11 +137,11 @@ class ClsKB(KBBase): | |||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||
| if self.GKB_flag: | |||
| return self._abduce_from_GKB(pred_res, key, max_address_num, require_more_address) | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self._abduce_by_abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def _abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.base == {} or len(pred_res) not in self.len_list: | |||
| return [] | |||
| @@ -260,7 +249,7 @@ class prolog_KB(KBBase): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| return self._abduce_by_abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||