| @@ -10,20 +10,6 @@ | |||
| # | |||
| # ================================================================# | |||
| from abc import ABC, abstractmethod | |||
| import bisect | |||
| import copy | |||
| import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal | |||
| from multiprocessing import Pool | |||
| import pyswip | |||
| class KBBase(ABC): | |||
| def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): | |||
| self.pseudo_label_list = pseudo_label_list | |||
| @@ -79,78 +65,16 @@ class KBBase(ABC): | |||
| res.append(self.logic_forward(x)) | |||
| return res | |||
| @abstractmethod | |||
| def abduce_candidates(self): | |||
| pass | |||
| @abstractmethod | |||
| def address_by_idx(self): | |||
| pass | |||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||
| new_candidates = [] | |||
| if not multiple_predictions: | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| else: | |||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||
| if self.GKB_flag: | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| else: | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| @abstractmethod | |||
| def _find_candidate_GKB(self): | |||
| pass | |||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| if self.base == {}: | |||
| return [], 0, 0 | |||
| @@ -158,7 +82,7 @@ class ClsKB(KBBase): | |||
| if not multiple_predictions: | |||
| if len(pred_res) not in self.len_list: | |||
| return [], 0, 0 | |||
| all_candidates = self.base[len(pred_res)][key] | |||
| all_candidates = self._find_candidate_GKB(pred_res, key) | |||
| if len(all_candidates) == 0: | |||
| return [], 0, 0 | |||
| else: | |||
| @@ -168,7 +92,7 @@ class ClsKB(KBBase): | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| else: | |||
| min_address_num = 0 | |||
| all_candidates_save = [] | |||
| @@ -177,7 +101,7 @@ class ClsKB(KBBase): | |||
| for p_res, k in zip(pred_res, key): | |||
| if len(p_res) not in self.len_list: | |||
| return [], 0, 0 | |||
| all_candidates = self.base[len(p_res)][k] | |||
| all_candidates = self._regression_find_candidate_GKB(p_res, k) | |||
| if len(all_candidates) == 0: | |||
| return [], 0, 0 | |||
| else: | |||
| @@ -194,7 +118,7 @@ class ClsKB(KBBase): | |||
| idxs = np.where(multiple_cost_list <= address_num)[0] | |||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||
| @@ -211,10 +135,71 @@ class ClsKB(KBBase): | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key): | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||
| new_candidates = [] | |||
| if not multiple_predictions: | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| else: | |||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(len(c) for c in dic.values()) | |||
| def __len__(self): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def logic_forward(self): | |||
| pass | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| return self.base[len(pred_res)][key] | |||
| class add_KB(ClsKB): | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | |||
| @@ -232,16 +217,13 @@ class prolog_KB(KBBase): | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def _find_candidate_GKB(self): | |||
| pass | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| # print(address_idx) | |||
| if not multiple_predictions: | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| else: | |||
| query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| @@ -284,21 +266,18 @@ class HED_prolog_KB(prolog_KB): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog.consult('./datasets/hed/learn_add.pl') | |||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||
| def logic_forward(self, exs): | |||
| return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 | |||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | |||
| # flatten | |||
| def get_query_string(self, pred_res, key, address_idx): | |||
| flatten_pred_res = flatten(pred_res) | |||
| # add variables for prolog | |||
| for idx in range(len(flatten_pred_res)): | |||
| if idx in address_idx: | |||
| flatten_pred_res[idx] = 'X' + str(idx) | |||
| # unflatten | |||
| new_pred_res = reform_idx(flatten_pred_res, pred_res) | |||
| pred_res = reform_idx(flatten_pred_res, pred_res) | |||
| query_string = "abduce_consistent_insts(%s)." % new_pred_res | |||
| query_string = "abduce_consistent_insts(%s)." % pred_res | |||
| return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | |||
| def consist_rule(self, exs, rules): | |||
| @@ -324,13 +303,7 @@ class RegKB(KBBase): | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||
| if self.GKB_flag: | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| else: | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def _regression_find_candidate_GKB(self, pred_res, key): | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| potential_candidates = self.base[len(pred_res)] | |||
| key_list = sorted(potential_candidates) | |||
| key_idx = bisect.bisect_left(key_list, key) | |||
| @@ -351,70 +324,6 @@ class RegKB(KBBase): | |||
| break | |||
| return all_candidates | |||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| if self.base == {}: | |||
| return [], 0, 0 | |||
| if not multiple_predictions: | |||
| if len(pred_res) not in self.len_list: | |||
| return [], 0, 0 | |||
| all_candidates = self._regression_find_candidate_GKB(pred_res, key) | |||
| if len(all_candidates) == 0: | |||
| return [], 0, 0 | |||
| else: | |||
| cost_list = hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| else: | |||
| min_address_num = 0 | |||
| all_candidates_save = [] | |||
| cost_list_save = [] | |||
| for p_res, k in zip(pred_res, key): | |||
| if len(p_res) not in self.len_list: | |||
| return [], 0, 0 | |||
| all_candidates = self._regression_find_candidate_GKB(p_res, k) | |||
| if len(all_candidates) == 0: | |||
| return [], 0, 0 | |||
| else: | |||
| all_candidates_save.append(all_candidates) | |||
| cost_list = hamming_dist(p_res, all_candidates) | |||
| min_address_num += np.min(cost_list) | |||
| cost_list_save.append(cost_list) | |||
| multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] | |||
| assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) | |||
| multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) | |||
| assert len(multiple_all_candidates) == len(multiple_cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(multiple_cost_list <= address_num)[0] | |||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| class HWF_KB(RegKB): | |||
| def __init__( | |||
| @@ -456,13 +365,6 @@ class HWF_KB(RegKB): | |||
| formula = [mapping[f] for f in formula] | |||
| return round(eval(''.join(formula)), 2) | |||
| import time | |||
| if __name__ == "__main__": | |||
| t1 = time.time() | |||
| kb = add_KB(GKB_flag=True) | |||
| t2 = time.time() | |||
| print(t2 - t1) | |||
| pass | |||