| @@ -10,20 +10,6 @@ | |||||
| # | # | ||||
| # ================================================================# | # ================================================================# | ||||
| from abc import ABC, abstractmethod | |||||
| import bisect | |||||
| import copy | |||||
| import numpy as np | |||||
| from collections import defaultdict | |||||
| from itertools import product, combinations | |||||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal | |||||
| from multiprocessing import Pool | |||||
| import pyswip | |||||
| class KBBase(ABC): | class KBBase(ABC): | ||||
| def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): | def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): | ||||
| self.pseudo_label_list = pseudo_label_list | self.pseudo_label_list = pseudo_label_list | ||||
| @@ -79,78 +65,16 @@ class KBBase(ABC): | |||||
| res.append(self.logic_forward(x)) | res.append(self.logic_forward(x)) | ||||
| return res | return res | ||||
| @abstractmethod | |||||
| def abduce_candidates(self): | |||||
| pass | |||||
| @abstractmethod | |||||
| def address_by_idx(self): | |||||
| pass | |||||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||||
| new_candidates = [] | |||||
| if not multiple_predictions: | |||||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||||
| else: | |||||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | |||||
| for address_idx in address_idx_list: | |||||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||||
| new_candidates += candidates | |||||
| return new_candidates | |||||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||||
| candidates = [] | |||||
| for address_num in range(len(flatten(pred_res)) + 1): | |||||
| if address_num == 0: | |||||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||||
| candidates.append(pred_res) | |||||
| else: | |||||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||||
| candidates += new_candidates | |||||
| if len(candidates) > 0: | |||||
| min_address_num = address_num | |||||
| break | |||||
| if address_num >= max_address_num: | |||||
| return [], 0, 0 | |||||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||||
| if address_num > max_address_num: | |||||
| return candidates, min_address_num, address_num - 1 | |||||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||||
| candidates += new_candidates | |||||
| return candidates, min_address_num, address_num | |||||
| def _dict_len(self, dic): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(len(c) for c in dic.values()) | |||||
| def __len__(self): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(self._dict_len(v) for v in self.base.values()) | |||||
| class ClsKB(KBBase): | |||||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||||
| def logic_forward(self): | |||||
| pass | |||||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | ||||
| if self.GKB_flag: | if self.GKB_flag: | ||||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | ||||
| else: | else: | ||||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | ||||
| @abstractmethod | |||||
| def _find_candidate_GKB(self): | |||||
| pass | |||||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | ||||
| if self.base == {}: | if self.base == {}: | ||||
| return [], 0, 0 | return [], 0, 0 | ||||
| @@ -158,7 +82,7 @@ class ClsKB(KBBase): | |||||
| if not multiple_predictions: | if not multiple_predictions: | ||||
| if len(pred_res) not in self.len_list: | if len(pred_res) not in self.len_list: | ||||
| return [], 0, 0 | return [], 0, 0 | ||||
| all_candidates = self.base[len(pred_res)][key] | |||||
| all_candidates = self._find_candidate_GKB(pred_res, key) | |||||
| if len(all_candidates) == 0: | if len(all_candidates) == 0: | ||||
| return [], 0, 0 | return [], 0, 0 | ||||
| else: | else: | ||||
| @@ -168,7 +92,7 @@ class ClsKB(KBBase): | |||||
| idxs = np.where(cost_list <= address_num)[0] | idxs = np.where(cost_list <= address_num)[0] | ||||
| candidates = [all_candidates[idx] for idx in idxs] | candidates = [all_candidates[idx] for idx in idxs] | ||||
| return candidates, min_address_num, address_num | return candidates, min_address_num, address_num | ||||
| else: | else: | ||||
| min_address_num = 0 | min_address_num = 0 | ||||
| all_candidates_save = [] | all_candidates_save = [] | ||||
| @@ -177,7 +101,7 @@ class ClsKB(KBBase): | |||||
| for p_res, k in zip(pred_res, key): | for p_res, k in zip(pred_res, key): | ||||
| if len(p_res) not in self.len_list: | if len(p_res) not in self.len_list: | ||||
| return [], 0, 0 | return [], 0, 0 | ||||
| all_candidates = self.base[len(p_res)][k] | |||||
| all_candidates = self._regression_find_candidate_GKB(p_res, k) | |||||
| if len(all_candidates) == 0: | if len(all_candidates) == 0: | ||||
| return [], 0, 0 | return [], 0, 0 | ||||
| else: | else: | ||||
| @@ -194,7 +118,7 @@ class ClsKB(KBBase): | |||||
| idxs = np.where(multiple_cost_list <= address_num)[0] | idxs = np.where(multiple_cost_list <= address_num)[0] | ||||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | ||||
| return candidates, min_address_num, address_num | return candidates, min_address_num, address_num | ||||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | ||||
| candidates = [] | candidates = [] | ||||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | ||||
| @@ -211,10 +135,71 @@ class ClsKB(KBBase): | |||||
| if multiple_predictions: | if multiple_predictions: | ||||
| candidate = reform_idx(candidate, save_pred_res) | candidate = reform_idx(candidate, save_pred_res) | ||||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key): | |||||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||||
| candidates.append(candidate) | candidates.append(candidate) | ||||
| return candidates | return candidates | ||||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||||
| new_candidates = [] | |||||
| if not multiple_predictions: | |||||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||||
| else: | |||||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | |||||
| for address_idx in address_idx_list: | |||||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||||
| new_candidates += candidates | |||||
| return new_candidates | |||||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||||
| candidates = [] | |||||
| for address_num in range(len(flatten(pred_res)) + 1): | |||||
| if address_num == 0: | |||||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||||
| candidates.append(pred_res) | |||||
| else: | |||||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||||
| candidates += new_candidates | |||||
| if len(candidates) > 0: | |||||
| min_address_num = address_num | |||||
| break | |||||
| if address_num >= max_address_num: | |||||
| return [], 0, 0 | |||||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||||
| if address_num > max_address_num: | |||||
| return candidates, min_address_num, address_num - 1 | |||||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||||
| candidates += new_candidates | |||||
| return candidates, min_address_num, address_num | |||||
| def _dict_len(self, dic): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(len(c) for c in dic.values()) | |||||
| def __len__(self): | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(self._dict_len(v) for v in self.base.values()) | |||||
| class ClsKB(KBBase): | |||||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||||
| def logic_forward(self): | |||||
| pass | |||||
| def _find_candidate_GKB(self, pred_res, key): | |||||
| return self.base[len(pred_res)][key] | |||||
| class add_KB(ClsKB): | class add_KB(ClsKB): | ||||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | ||||
| @@ -232,16 +217,13 @@ class prolog_KB(KBBase): | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||||
| def _find_candidate_GKB(self): | |||||
| pass | |||||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | ||||
| candidates = [] | candidates = [] | ||||
| # print(address_idx) | # print(address_idx) | ||||
| if not multiple_predictions: | |||||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||||
| else: | |||||
| query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) | |||||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||||
| if multiple_predictions: | if multiple_predictions: | ||||
| save_pred_res = pred_res | save_pred_res = pred_res | ||||
| @@ -284,21 +266,18 @@ class HED_prolog_KB(prolog_KB): | |||||
| super().__init__(pseudo_label_list) | super().__init__(pseudo_label_list) | ||||
| self.prolog.consult('./datasets/hed/learn_add.pl') | self.prolog.consult('./datasets/hed/learn_add.pl') | ||||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||||
| def logic_forward(self, exs): | def logic_forward(self, exs): | ||||
| return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 | return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 | ||||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | |||||
| # flatten | |||||
| def get_query_string(self, pred_res, key, address_idx): | |||||
| flatten_pred_res = flatten(pred_res) | flatten_pred_res = flatten(pred_res) | ||||
| # add variables for prolog | # add variables for prolog | ||||
| for idx in range(len(flatten_pred_res)): | for idx in range(len(flatten_pred_res)): | ||||
| if idx in address_idx: | if idx in address_idx: | ||||
| flatten_pred_res[idx] = 'X' + str(idx) | flatten_pred_res[idx] = 'X' + str(idx) | ||||
| # unflatten | |||||
| new_pred_res = reform_idx(flatten_pred_res, pred_res) | |||||
| pred_res = reform_idx(flatten_pred_res, pred_res) | |||||
| query_string = "abduce_consistent_insts(%s)." % new_pred_res | |||||
| query_string = "abduce_consistent_insts(%s)." % pred_res | |||||
| return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | ||||
| def consist_rule(self, exs, rules): | def consist_rule(self, exs, rules): | ||||
| @@ -324,13 +303,7 @@ class RegKB(KBBase): | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||||
| if self.GKB_flag: | |||||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||||
| else: | |||||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||||
| def _regression_find_candidate_GKB(self, pred_res, key): | |||||
| def _find_candidate_GKB(self, pred_res, key): | |||||
| potential_candidates = self.base[len(pred_res)] | potential_candidates = self.base[len(pred_res)] | ||||
| key_list = sorted(potential_candidates) | key_list = sorted(potential_candidates) | ||||
| key_idx = bisect.bisect_left(key_list, key) | key_idx = bisect.bisect_left(key_list, key) | ||||
| @@ -351,70 +324,6 @@ class RegKB(KBBase): | |||||
| break | break | ||||
| return all_candidates | return all_candidates | ||||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||||
| if self.base == {}: | |||||
| return [], 0, 0 | |||||
| if not multiple_predictions: | |||||
| if len(pred_res) not in self.len_list: | |||||
| return [], 0, 0 | |||||
| all_candidates = self._regression_find_candidate_GKB(pred_res, key) | |||||
| if len(all_candidates) == 0: | |||||
| return [], 0, 0 | |||||
| else: | |||||
| cost_list = hamming_dist(pred_res, all_candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||||
| idxs = np.where(cost_list <= address_num)[0] | |||||
| candidates = [all_candidates[idx] for idx in idxs] | |||||
| return candidates, min_address_num, address_num | |||||
| else: | |||||
| min_address_num = 0 | |||||
| all_candidates_save = [] | |||||
| cost_list_save = [] | |||||
| for p_res, k in zip(pred_res, key): | |||||
| if len(p_res) not in self.len_list: | |||||
| return [], 0, 0 | |||||
| all_candidates = self._regression_find_candidate_GKB(p_res, k) | |||||
| if len(all_candidates) == 0: | |||||
| return [], 0, 0 | |||||
| else: | |||||
| all_candidates_save.append(all_candidates) | |||||
| cost_list = hamming_dist(p_res, all_candidates) | |||||
| min_address_num += np.min(cost_list) | |||||
| cost_list_save.append(cost_list) | |||||
| multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] | |||||
| assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) | |||||
| multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) | |||||
| assert len(multiple_all_candidates) == len(multiple_cost_list) | |||||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||||
| idxs = np.where(multiple_cost_list <= address_num)[0] | |||||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | |||||
| return candidates, min_address_num, address_num | |||||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||||
| candidates = [] | |||||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||||
| if multiple_predictions: | |||||
| save_pred_res = pred_res | |||||
| pred_res = flatten(pred_res) | |||||
| for c in abduce_c: | |||||
| candidate = pred_res.copy() | |||||
| for i, idx in enumerate(address_idx): | |||||
| candidate[idx] = c[i] | |||||
| if multiple_predictions: | |||||
| candidate = reform_idx(candidate, save_pred_res) | |||||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||||
| candidates.append(candidate) | |||||
| return candidates | |||||
| class HWF_KB(RegKB): | class HWF_KB(RegKB): | ||||
| def __init__( | def __init__( | ||||
| @@ -456,13 +365,6 @@ class HWF_KB(RegKB): | |||||
| formula = [mapping[f] for f in formula] | formula = [mapping[f] for f in formula] | ||||
| return round(eval(''.join(formula)), 2) | return round(eval(''.join(formula)), 2) | ||||
| import time | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| t1 = time.time() | |||||
| kb = add_KB(GKB_flag=True) | |||||
| t2 = time.time() | |||||
| print(t2 - t1) | |||||
| pass | |||||