| @@ -31,6 +31,15 @@ class KBBase(ABC): | |||||
| @abstractmethod | @abstractmethod | ||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
def _logic_forward(self, xs, multiple_predictions=False):
    """Apply ``logic_forward`` to one example or to every example of a batch.

    Args:
        xs: A single example, or an iterable of examples when
            ``multiple_predictions`` is true.
        multiple_predictions: When true, ``xs`` is iterated and
            ``logic_forward`` is called once per element.

    Returns:
        The result of ``self.logic_forward(xs)``, or a list holding one
        result per element of ``xs``.
    """
    if not multiple_predictions:
        return self.logic_forward(xs)
    return [self.logic_forward(x) for x in xs]
| @abstractmethod | @abstractmethod | ||||
| def abduce_candidates(self): | def abduce_candidates(self): | ||||
| @@ -40,7 +49,7 @@ class KBBase(ABC): | |||||
| def address_by_idx(self): | def address_by_idx(self): | ||||
| pass | pass | ||||
| def _address(self, address_num, pred_res, key, multiple_predictions=False): | |||||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||||
| new_candidates = [] | new_candidates = [] | ||||
| if not multiple_predictions: | if not multiple_predictions: | ||||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | ||||
| @@ -52,12 +61,12 @@ class KBBase(ABC): | |||||
| new_candidates += candidates | new_candidates += candidates | ||||
| return new_candidates | return new_candidates | ||||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||||
| candidates = [] | candidates = [] | ||||
| for address_num in range(len(flatten(pred_res)) + 1): | for address_num in range(len(flatten(pred_res)) + 1): | ||||
| if address_num == 0: | if address_num == 0: | ||||
| if check_equal(self.logic_forward(pred_res), key): | |||||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||||
| candidates.append(pred_res) | candidates.append(pred_res) | ||||
| else: | else: | ||||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | ||||
| @@ -88,16 +97,14 @@ class ClsKB(KBBase): | |||||
| self.GKB_flag = GKB_flag | self.GKB_flag = GKB_flag | ||||
| self.pseudo_label_list = pseudo_label_list | self.pseudo_label_list = pseudo_label_list | ||||
| self.len_list = len_list | self.len_list = len_list | ||||
| self.max_err = 0 | |||||
| if GKB_flag: | if GKB_flag: | ||||
| self.base = {} | self.base = {} | ||||
| X, Y = self._get_GKB() | X, Y = self._get_GKB() | ||||
| for x, y in zip(X, Y): | for x, y in zip(X, Y): | ||||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | self.base.setdefault(len(x), defaultdict(list))[y].append(x) | ||||
| else: | |||||
| self.all_address_candidate_dict = {} | |||||
| for address_num in range(max(self.len_list) + 1): | |||||
| self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) | |||||
| # For parallel version of _get_GKB | # For parallel version of _get_GKB | ||||
| def _get_XY_list(self, args): | def _get_XY_list(self, args): | ||||
| @@ -133,33 +140,57 @@ class ClsKB(KBBase): | |||||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | ||||
| if self.GKB_flag: | if self.GKB_flag: | ||||
| # TODO: 这里有可能是multiple_predictions吗 | |||||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) | |||||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||||
| else: | else: | ||||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | ||||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): | |||||
| if self.base == {} or len(pred_res) not in self.len_list: | |||||
| return [] | |||||
| all_candidates = self.base[len(pred_res)][key] | |||||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||||
| if self.base == {}: | |||||
| return [], 0, 0 | |||||
| if len(all_candidates) == 0: | |||||
| candidates = [] | |||||
| min_address_num = 0 | |||||
| address_num = 0 | |||||
| if not multiple_predictions: | |||||
| if len(pred_res) not in self.len_list: | |||||
| return [], 0, 0 | |||||
| all_candidates = self.base[len(pred_res)][key] | |||||
| if len(all_candidates) == 0: | |||||
| return [], 0, 0 | |||||
| else: | |||||
| cost_list = hamming_dist(pred_res, all_candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||||
| idxs = np.where(cost_list <= address_num)[0] | |||||
| candidates = [all_candidates[idx] for idx in idxs] | |||||
| return candidates, min_address_num, address_num | |||||
| else: | else: | ||||
| cost_list = hamming_dist(pred_res, all_candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| min_address_num = 0 | |||||
| all_candidates_save = [] | |||||
| cost_list_save = [] | |||||
| for p_res, k in zip(pred_res, key): | |||||
| if len(p_res) not in self.len_list: | |||||
| return [], 0, 0 | |||||
| all_candidates = self.base[len(p_res)][k] | |||||
| if len(all_candidates) == 0: | |||||
| return [], 0, 0 | |||||
| else: | |||||
| all_candidates_save.append(all_candidates) | |||||
| cost_list = hamming_dist(p_res, all_candidates) | |||||
| min_address_num += np.min(cost_list) | |||||
| cost_list_save.append(cost_list) | |||||
| multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] | |||||
| assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) | |||||
| multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) | |||||
| assert len(multiple_all_candidates) == len(multiple_cost_list) | |||||
| address_num = min(max_address_num, min_address_num + require_more_address) | address_num = min(max_address_num, min_address_num + require_more_address) | ||||
| idxs = np.where(cost_list <= address_num)[0] | |||||
| candidates = [all_candidates[idx] for idx in idxs] | |||||
| return candidates, min_address_num, address_num | |||||
| idxs = np.where(multiple_cost_list <= address_num)[0] | |||||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | |||||
| return candidates, min_address_num, address_num | |||||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | ||||
| candidates = [] | candidates = [] | ||||
| abduce_c = self.all_address_candidate_dict[len(address_idx)] | |||||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||||
| if multiple_predictions: | if multiple_predictions: | ||||
| save_pred_res = pred_res | save_pred_res = pred_res | ||||
| @@ -173,7 +204,7 @@ class ClsKB(KBBase): | |||||
| if multiple_predictions: | if multiple_predictions: | ||||
| candidate = reform_idx(candidate, save_pred_res) | candidate = reform_idx(candidate, save_pred_res) | ||||
| if self.logic_forward(candidate) == key: | |||||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key): | |||||
| candidates.append(candidate) | candidates.append(candidate) | ||||
| return candidates | return candidates | ||||
| @@ -197,50 +228,13 @@ class add_KB(ClsKB): | |||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| # TODO:这是个回归任务(对于y而言),在logic_forward加round变成离散的分类任务固然可行,但最好还是用RegKB吧,作为例子示范。还需要对下面的ClsKB进行修改(见TODO) | |||||
| class HWF_KB(ClsKB): | |||||
| def __init__( | |||||
| self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] | |||||
| ): | |||||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||||
| def valid_candidate(self, formula): | |||||
| if len(formula) % 2 == 0: | |||||
| return False | |||||
| for i in range(len(formula)): | |||||
| if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||||
| return False | |||||
| if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | |||||
| return False | |||||
| return True | |||||
| def logic_forward(self, formula): | |||||
| if not self.valid_candidate(formula): | |||||
| return np.inf | |||||
| mapping = { | |||||
| '1': '1', | |||||
| '2': '2', | |||||
| '3': '3', | |||||
| '4': '4', | |||||
| '5': '5', | |||||
| '6': '6', | |||||
| '7': '7', | |||||
| '8': '8', | |||||
| '9': '9', | |||||
| '+': '+', | |||||
| '-': '-', | |||||
| 'times': '*', | |||||
| 'div': '/', | |||||
| } | |||||
| formula = [mapping[f] for f in formula] | |||||
| return round(eval(''.join(formula)), 2) | |||||
| class prolog_KB(KBBase): | class prolog_KB(KBBase): | ||||
| def __init__(self, pseudo_label_list): | def __init__(self, pseudo_label_list): | ||||
| super().__init__() | super().__init__() | ||||
| self.pseudo_label_list = pseudo_label_list | self.pseudo_label_list = pseudo_label_list | ||||
| self.prolog = pyswip.Prolog() | self.prolog = pyswip.Prolog() | ||||
| self.max_err = 0 | |||||
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| @@ -295,11 +289,11 @@ class add_prolog_KB(prolog_KB): | |||||
| class HED_prolog_KB(prolog_KB): | class HED_prolog_KB(prolog_KB): | ||||
| def __init__(self, pseudo_label_list=[0, 1, '+', '=']): | def __init__(self, pseudo_label_list=[0, 1, '+', '=']): | ||||
| super().__init__(pseudo_label_list) | super().__init__(pseudo_label_list) | ||||
| self.prolog.consult('../examples/datasets/hed/learn_add.pl') | |||||
| self.prolog.consult('./datasets/hed/learn_add.pl') | |||||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | ||||
| def logic_forward(self, exs): | def logic_forward(self, exs): | ||||
| return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 | |||||
| return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 | |||||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | def get_query_string_need_flatten(self, pred_res, key, address_idx): | ||||
| # flatten | # flatten | ||||
| @@ -329,93 +323,204 @@ class HED_prolog_KB(prolog_KB): | |||||
| rules.append(rule.value) | rules.append(rule.value) | ||||
| return rules | return rules | ||||
| # def consist_rules(self, pred_res, rules): | |||||
| # TODO:这里需要修改一下这个类,原本的RegKB是对GKB而言的,现在需要和ClsKB一样同时支持GKB和非GKB。需要补充非GKB部分(可能继承_abduce_by_search就行),以及修改GKB部分_abduce_by_GKB的逻辑(原本逻辑是找与key最近的y的abduce结果,现在改成与key在一定误差范围内的y的abduce结果) | |||||
| # TODO:我理解的RegKB是这样的: | |||||
| # TODO:1. 对GKB而言,即_abduce_by_GKB,给定key和length,还需要一个self.max_err,返回所有与key绝对值小于max_err的abduction结果 | |||||
| # TODO:比如GKB里的y有[1.3, 1.49, 1.50, 1.52, 1.6],若key=1.5,max_err=1e-5,则返回[y=1.50]的abduction结果;若key=1.5,max_err=0.05,则返回所有[y=1.49, 1.50, 1.52]的abduction结果 | |||||
| # TODO:因此在二分查找bisect_left后,需要分别往前和往后遍历,从GKB里找符合误差的y | |||||
| # TODO:self.max_err默认值取很小就行,比如HWF这类任务;但有些任务(比如法院刑期预测)的max_err需要大些,因此可以由用户自定义 | |||||
| # TODO:2. 对非GKB而言,估计直接用_abduce_by_search就行,check_equal那限定为数字且控制回归误差max_err | |||||
| class RegKB(KBBase): | class RegKB(KBBase): | ||||
| def __init__(self, GKB_flag=False, X=None, Y=None): | |||||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): | |||||
| super().__init__() | super().__init__() | ||||
| tmp_dict = {} | |||||
| for x, y in zip(X, Y): | |||||
| tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||||
| self.base = {} | |||||
| for l in tmp_dict.keys(): | |||||
| data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) | |||||
| X = [x for y, x in data] | |||||
| Y = [y for y, x in data] | |||||
| self.base[l] = (X, Y) | |||||
| def valid_candidate(self): | |||||
| pass | |||||
| self.GKB_flag = GKB_flag | |||||
| self.pseudo_label_list = pseudo_label_list | |||||
| self.len_list = len_list | |||||
| self.max_err = max_err | |||||
| if GKB_flag: | |||||
| self.base = {} | |||||
| X, Y = self._get_GKB() | |||||
| for x, y in zip(X, Y): | |||||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
| # For parallel version of _get_GKB | |||||
def _get_XY_list(self, args):
    """Evaluate every sequence that starts with a fixed first label.

    Args:
        args: Indexable pair; ``args[0]`` is the first label and ``args[1]``
            is an iterable of tuples holding the remaining labels.

    Returns:
        A list of ``(x, y)`` pairs where ``x`` is the full label tuple and
        ``y = self.logic_forward(x)``; pairs whose ``y`` equals ``np.inf``
        are left out.
    """
    head = args[0]
    tails = args[1]
    pairs = []
    for tail in tails:
        sequence = (head,) + tail
        outcome = self.logic_forward(sequence)
        if outcome != np.inf:
            pairs.append((sequence, outcome))
    return pairs
| # Parallel _get_GKB | |||||
def _get_GKB(self):
    """Enumerate all label sequences of the configured lengths in parallel.

    For every length in ``self.len_list`` the search space is split by first
    label; each split is handed to ``self._get_XY_list`` through a process
    pool with one worker per label.

    Returns:
        ``(X, Y)``: two parallel lists with the sequences kept by
        ``_get_XY_list`` and their evaluated outputs.
    """
    X, Y = [], []
    lengths = list(self.len_list)
    num_workers = len(self.pseudo_label_list)
    # Nothing to enumerate; also avoids Pool(processes=0), which raises.
    if not lengths or num_workers == 0:
        return X, Y

    # The worker count is the same for every length, so one pool is shared
    # instead of spawning a fresh set of processes per length.
    with Pool(processes=num_workers) as pool:
        for length in lengths:
            arg_list = [
                (pre_x, product(self.pseudo_label_list, repeat=length - 1))
                for pre_x in self.pseudo_label_list
            ]
            for XY_list in pool.map(self._get_XY_list, arg_list):
                if len(XY_list) == 0:
                    continue
                part_X, part_Y = zip(*XY_list)
                X.extend(part_X)
                Y.extend(part_Y)
    return X, Y
| def logic_forward(self): | def logic_forward(self): | ||||
| pass | pass | ||||
| def _abduce_by_GKB(self, key, length=None): | |||||
| if key is None: | |||||
| return self.get_all_candidates() | |||||
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
    """Abduce candidates via the pre-built knowledge base or via search.

    ``self.GKB_flag`` selects the strategy: ``_abduce_by_GKB`` when it is
    true, ``_abduce_by_search`` otherwise. All arguments are forwarded
    unchanged and the chosen strategy's result is returned as is.
    """
    strategy = self._abduce_by_GKB if self.GKB_flag else self._abduce_by_search
    return strategy(pred_res, key, max_address_num, require_more_address, multiple_predictions)
| length = self._length(length) | |||||
def _regression_find_candidate_GKB(self, pred_res, key):
    """Collect stored inputs whose output lies within ``self.max_err`` of ``key``.

    Only entries of ``self.base`` stored under ``len(pred_res)`` are
    considered. Their outputs are sorted, the insertion point of ``key`` is
    located with ``bisect``, and the scan moves outwards in both directions
    until an output is farther than ``self.max_err`` from ``key``.

    Args:
        pred_res: Prediction sequence; only its length is used here.
        key: Target output value.

    Returns:
        A list of stored inputs: matches below ``key`` first (nearest
        first), then matches at or above ``key`` in ascending order. Empty
        when nothing is stored for this length or nothing is close enough.
    """
    # A length without any stored entry simply has no candidates.
    potential_candidates = self.base.get(len(pred_res), {})
    key_list = sorted(potential_candidates)
    key_idx = bisect.bisect_left(key_list, key)

    all_candidates = []
    # Scan downwards. The stop value is -1 so that index 0 (the smallest
    # stored output) is examined as well.
    for idx in range(key_idx - 1, -1, -1):
        k = key_list[idx]
        if abs(k - key) > self.max_err:
            break
        all_candidates += potential_candidates[k]

    # Scan upwards from the insertion point.
    for idx in range(key_idx, len(key_list)):
        k = key_list[idx]
        if abs(k - key) > self.max_err:
            break
        all_candidates += potential_candidates[k]
    return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
    """Abduce candidates from the pre-built knowledge base ``self.base``.

    Candidates come from ``_regression_find_candidate_GKB``. They are kept
    when their cost, as computed by ``hamming_dist`` (defined elsewhere in
    this file), does not exceed
    ``min(max_address_num, minimum cost + require_more_address)``.

    Args:
        pred_res: One prediction sequence, or a sequence of them when
            ``multiple_predictions`` is true.
        key: Target output, or one target per prediction sequence.
        max_address_num: Upper bound on the accepted cost.
        require_more_address: Slack added to the minimum cost.
        multiple_predictions: Whether ``pred_res`` and ``key`` hold several
            examples that are abduced jointly.

    Returns:
        ``(candidates, min_address_num, address_num)``. ``([], 0, 0)`` is
        returned when the knowledge base is empty, when a sequence length
        is not in ``self.len_list``, or when no stored output is close
        enough to a key.
    """
    if self.base == {}:
        return [], 0, 0

    if not multiple_predictions:
        if len(pred_res) not in self.len_list:
            return [], 0, 0
        all_candidates = self._regression_find_candidate_GKB(pred_res, key)
        if len(all_candidates) == 0:
            return [], 0, 0
        cost_list = hamming_dist(pred_res, all_candidates)
        min_address_num = np.min(cost_list)
        address_num = min(max_address_num, min_address_num + require_more_address)
        idxs = np.where(cost_list <= address_num)[0]
        candidates = [all_candidates[idx] for idx in idxs]
        return candidates, min_address_num, address_num

    min_address_num = 0
    all_candidates_save = []
    cost_list_save = []
    for p_res, k in zip(pred_res, key):
        if len(p_res) not in self.len_list:
            return [], 0, 0
        all_candidates = self._regression_find_candidate_GKB(p_res, k)
        if len(all_candidates) == 0:
            return [], 0, 0
        all_candidates_save.append(all_candidates)
        cost_list = hamming_dist(p_res, all_candidates)
        min_address_num += np.min(cost_list)
        cost_list_save.append(cost_list)

    # Pair every combination of per-example candidates with the sum of the
    # per-example costs; both products iterate in the same order.
    multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
    assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
    multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
    assert len(multiple_all_candidates) == len(multiple_cost_list)

    address_num = min(max_address_num, min_address_num + require_more_address)
    idxs = np.where(multiple_cost_list <= address_num)[0]
    # Restore the nested layout of pred_res for each flattened candidate.
    candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
    return candidates, min_address_num, address_num
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||||
| candidates = [] | candidates = [] | ||||
| for l in length: | |||||
| X, Y = self.base[l] | |||||
| idx = bisect.bisect_left(Y, key) | |||||
| begin = max(0, idx - 1) | |||||
| end = min(idx + 2, len(X)) | |||||
| for idx in range(begin, end): | |||||
| err = abs(Y[idx] - key) | |||||
| if abs(err - min_err) < 1e-9: | |||||
| candidates.extend(X[idx]) | |||||
| elif err < min_err: | |||||
| candidates = copy.deepcopy(X[idx]) | |||||
| min_err = err | |||||
| return candidates | |||||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||||
| def get_all_candidates(self): | |||||
| return sum([sum(D[0], []) for D in self.base.values()], []) | |||||
| if multiple_predictions: | |||||
| save_pred_res = pred_res | |||||
| pred_res = flatten(pred_res) | |||||
| for c in abduce_c: | |||||
| candidate = pred_res.copy() | |||||
| for i, idx in enumerate(address_idx): | |||||
| candidate[idx] = c[i] | |||||
| if multiple_predictions: | |||||
| candidate = reform_idx(candidate, save_pred_res) | |||||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||||
| candidates.append(candidate) | |||||
| return candidates | |||||
def _dict_len(self, dic):
    """Return the total number of items across all values of ``dic``.

    Returns 0 without inspecting ``dic`` when ``self.GKB_flag`` is false.
    """
    if self.GKB_flag:
        return sum(len(entries) for entries in dic.values())
    return 0
| def __len__(self): | def __len__(self): | ||||
| return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) | |||||
| if not self.GKB_flag: | |||||
| return 0 | |||||
| else: | |||||
| return sum(self._dict_len(v) for v in self.base.values()) | |||||
class HWF_KB(RegKB):
    """Knowledge base that evaluates arithmetic formulas given as symbol lists.

    A formula is valid when it has odd length, with a digit ``'1'``-``'9'``
    at every even position and one of ``'+'``, ``'-'``, ``'times'``,
    ``'div'`` at every odd position.
    """

    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    _OPERATORS = ('+', '-', 'times', 'div')
    # Translation from pseudo-labels to Python expression tokens.
    _TO_PYTHON = {
        '1': '1',
        '2': '2',
        '3': '3',
        '4': '4',
        '5': '5',
        '6': '6',
        '7': '7',
        '8': '8',
        '9': '9',
        '+': '+',
        '-': '-',
        'times': '*',
        'div': '/',
    }

    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3):
        """Set up the knowledge base.

        Args:
            GKB_flag: Forwarded to ``RegKB.__init__``.
            pseudo_label_list: Symbols a formula may contain; defaults to
                the nine digits followed by the four operators.
            len_list: Formula lengths to consider; defaults to
                ``[1, 3, 5, 7]``.
            max_err: Forwarded to ``RegKB.__init__``.
        """
        # Fresh lists per instance: list literals as default arguments
        # would be shared between all instances.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(GKB_flag, pseudo_label_list, len_list, max_err)

    def valid_candidate(self, formula):
        """Return True when ``formula`` alternates digits and operators and has odd length."""
        if len(formula) % 2 == 0:
            return False
        for position, symbol in enumerate(formula):
            allowed = self._DIGITS if position % 2 == 0 else self._OPERATORS
            if symbol not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula`` and round the result to two decimals.

        Returns:
            ``np.inf`` when ``formula`` is not a valid candidate, otherwise
            the value of the arithmetic expression rounded with
            ``round(..., 2)``.
        """
        if not self.valid_candidate(formula):
            return np.inf
        expression = ''.join(self._TO_PYTHON[symbol] for symbol in formula)
        # NOTE(review): eval is reached only after valid_candidate has
        # restricted every symbol to the digit/operator whitelist above;
        # keep that guard in place if this method is ever changed.
        return round(eval(expression), 2)
| import time | import time | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| t1 = time.time() | t1 = time.time() | ||||
| kb = HWF_KB(True) | |||||
| kb = add_KB(True) | |||||
| t2 = time.time() | t2 = time.time() | ||||
| print(t2 - t1) | print(t2 - t1) | ||||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||||
| # Y = [2, 1, 1, 2, 2] | |||||
| # kb = ClsKB(X, Y) | |||||
| # print('len(kb):', len(kb)) | |||||
| # res = kb.get_candidates(2, 5) | |||||
| # print(res) | |||||
| # res = kb.get_candidates(2, 3) | |||||
| # print(res) | |||||
| # res = kb.get_candidates(None) | |||||
| # print(res) | |||||
| # print() | |||||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||||
| # Y = [2, 1, 1, 2, 1.5, 1.5] | |||||
| # kb = RegKB(X, Y) | |||||
| # print('len(kb):', len(kb)) | |||||
| # res = kb.get_candidates(1.6) | |||||
| # print(res) | |||||
| # res = kb.get_candidates(1.6, length = 9) | |||||
| # print(res) | |||||
| # res = kb.get_candidates(None) | |||||
| # print(res) | |||||