From 7e4102c7091489e97648ff1f6b005f8d89f960c8 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:22:36 +0800 Subject: [PATCH] Update kb.py --- abl/abducer/kb.py | 373 +++++++++++++++++++++++++++++----------------- 1 file changed, 239 insertions(+), 134 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 9a66922..66a134f 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -31,6 +31,15 @@ class KBBase(ABC): @abstractmethod def logic_forward(self): pass + + def _logic_forward(self, xs, multiple_predictions=False): + if not multiple_predictions: + return self.logic_forward(xs) + else: + res = [] + for x in xs: + res.append(self.logic_forward(x)) + return res @abstractmethod def abduce_candidates(self): @@ -40,7 +49,7 @@ class KBBase(ABC): def address_by_idx(self): pass - def _address(self, address_num, pred_res, key, multiple_predictions=False): + def _address(self, address_num, pred_res, key, multiple_predictions): new_candidates = [] if not multiple_predictions: address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) @@ -52,12 +61,12 @@ class KBBase(ABC): new_candidates += candidates return new_candidates - def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): + def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): candidates = [] for address_num in range(len(flatten(pred_res)) + 1): if address_num == 0: - if check_equal(self.logic_forward(pred_res), key): + if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): candidates.append(pred_res) else: new_candidates = self._address(address_num, pred_res, key, multiple_predictions) @@ -88,16 +97,14 @@ class ClsKB(KBBase): self.GKB_flag = GKB_flag self.pseudo_label_list = pseudo_label_list self.len_list = len_list + self.max_err = 0 if GKB_flag: self.base = {} X, Y = self._get_GKB() for x, y in zip(X, Y): self.base.setdefault(len(x), defaultdict(list))[y].append(x) - else: - self.all_address_candidate_dict = {} - for address_num in range(max(self.len_list) + 1): - self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) + # For parallel version of _get_GKB def _get_XY_list(self, args): @@ -133,33 +140,57 @@ class ClsKB(KBBase): def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: - # TODO: 这里有可能是multiple_predictions吗 - return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) + return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) else: return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): - if self.base == {} or len(pred_res) not in self.len_list: - return [] - - all_candidates = self.base[len(pred_res)][key] + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + if self.base == {}: + return [], 0, 0 - if len(all_candidates) == 0: - candidates = [] - min_address_num = 0 - address_num = 0 + if not multiple_predictions: + if len(pred_res) not in self.len_list: + return [], 0, 0 + all_candidates = self.base[len(pred_res)][key] + if len(all_candidates) == 0: + return [], 0, 0 + else: + cost_list = hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates, min_address_num, address_num + else: - cost_list = hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) + min_address_num = 0 + all_candidates_save = [] + cost_list_save = [] + + for p_res, k in zip(pred_res, key): + if len(p_res) not in self.len_list: + return [], 0, 0 + all_candidates = self.base[len(p_res)][k] + if len(all_candidates) == 0: + return [], 0, 0 + else: + all_candidates_save.append(all_candidates) + cost_list = hamming_dist(p_res, all_candidates) + min_address_num += np.min(cost_list) + cost_list_save.append(cost_list) + + multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] + assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) + multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) + assert len(multiple_all_candidates) == len(multiple_cost_list) address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - - return candidates, min_address_num, address_num + idxs = np.where(multiple_cost_list <= address_num)[0] + candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] + return candidates, min_address_num, address_num def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] - abduce_c = self.all_address_candidate_dict[len(address_idx)] + abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) if multiple_predictions: save_pred_res = pred_res @@ -173,7 +204,7 @@ class ClsKB(KBBase): if multiple_predictions: candidate = reform_idx(candidate, save_pred_res) - if self.logic_forward(candidate) == key: + if check_equal(self._logic_forward(candidate, multiple_predictions), key): candidates.append(candidate) return candidates @@ -197,50 +228,13 @@ class add_KB(ClsKB): def logic_forward(self, nums): return sum(nums) -# TODO:这是个回归任务(对于y而言),在logic_forward加round变成离散的分类任务固然可行,但最好还是用RegKB吧,作为例子示范。还需要对下面的ClsKB进行修改(见TODO) -class HWF_KB(ClsKB): - def __init__( - self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] - ): - super().__init__(GKB_flag, pseudo_label_list, len_list) - - def valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: - return False - if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: - return False - return True - - def logic_forward(self, formula): - if not self.valid_candidate(formula): - return np.inf - mapping = { - '1': '1', - '2': '2', - '3': '3', - '4': '4', - '5': '5', - '6': '6', - '7': '7', - '8': '8', - '9': '9', - '+': '+', - '-': '-', - 'times': '*', - 'div': '/', - } - formula = [mapping[f] for f in formula] - return round(eval(''.join(formula)), 2) - class prolog_KB(KBBase): def __init__(self, pseudo_label_list): super().__init__() self.pseudo_label_list = pseudo_label_list self.prolog = pyswip.Prolog() + self.max_err = 0 def logic_forward(self): pass @@ -295,11 +289,11 @@ class add_prolog_KB(prolog_KB): class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list=[0, 1, '+', '=']): super().__init__(pseudo_label_list) - self.prolog.consult('../examples/datasets/hed/learn_add.pl') + self.prolog.consult('./datasets/hed/learn_add.pl') # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` def logic_forward(self, exs): - return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 + return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 def get_query_string_need_flatten(self, pred_res, key, address_idx): # flatten @@ -329,93 +323,204 @@ class HED_prolog_KB(prolog_KB): rules.append(rule.value) return rules - # def consist_rules(self, pred_res, rules): -# TODO:这里需要修改一下这个类,原本的RegKB是对GKB而言的,现在需要和ClsKB一样同时支持GKB和非GKB。需要补充非GKB部分(可能继承_abduce_by_search就行),以及修改GKB部分_abduce_by_GKB的逻辑(原本逻辑是找与key最近的y的abduce结果,现在改成与key在一定误差范围内的y的abduce结果) -# TODO:我理解的RegKB是这样的: -# TODO:1. 对GKB而言,即_abduce_by_GKB,给定key和length,还需要一个self.max_err,返回所有与key绝对值小于max_err的abduction结果 -# TODO:比如GKB里的y有[1.3, 1.49, 1.50, 1.52, 1.6],若key=1.5,max_err=1e-5,则返回[y=1.50]的abduction结果;若key=1.5,max_err=0.05,则返回所有[y=1.49, 1.50, 1.52]的abduction结果 -# TODO:因此在二分查找bisect_left后,需要分别往前和往后遍历,从GKB里找符合误差的y -# TODO:self.max_err默认值取很小就行,比如HWF这类任务;但有些任务(比如法院刑期预测)的max_err需要大些,因此可以由用户自定义 -# TODO:2. 对非GKB而言,估计直接用_abduce_by_search就行,check_equal那限定为数字且控制回归误差max_err class RegKB(KBBase): - def __init__(self, GKB_flag=False, X=None, Y=None): + def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): super().__init__() - tmp_dict = {} - for x, y in zip(X, Y): - tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) - - self.base = {} - for l in tmp_dict.keys(): - data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) - X = [x for y, x in data] - Y = [y for y, x in data] - self.base[l] = (X, Y) - - def valid_candidate(self): - pass + self.GKB_flag = GKB_flag + self.pseudo_label_list = pseudo_label_list + self.len_list = len_list + self.max_err = max_err + + if GKB_flag: + self.base = {} + X, Y = self._get_GKB() + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(x) + + # For parallel version of _get_GKB + def _get_XY_list(self, args): + pre_x, post_x_it = args[0], args[1] + XY_list = [] + for post_x in post_x_it: + x = (pre_x,) + post_x + y = self.logic_forward(x) + if y != np.inf: + XY_list.append((x, y)) + return XY_list + + # Parallel _get_GKB + def _get_GKB(self): + X, Y = [], [] + for length in self.len_list: + arg_list = [] + for pre_x in self.pseudo_label_list: + post_x_it = product(self.pseudo_label_list, repeat=length - 1) + arg_list.append((pre_x, post_x_it)) + with Pool(processes=len(arg_list)) as pool: + ret_list = pool.map(self._get_XY_list, arg_list) + for XY_list in ret_list: + if len(XY_list) == 0: + continue + part_X, part_Y = zip(*XY_list) + X.extend(part_X) + Y.extend(part_Y) + return X, Y def logic_forward(self): pass - def _abduce_by_GKB(self, key, length=None): - if key is None: - return self.get_all_candidates() + def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): + if self.GKB_flag: + return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) + else: + return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - length = self._length(length) + def _regression_find_candidate_GKB(self, pred_res, key): + potential_candidates = self.base[len(pred_res)] + key_list = sorted(potential_candidates) + key_idx = bisect.bisect_left(key_list, key) + + all_candidates = [] + for idx in range(key_idx - 1, 0, -1): + k = key_list[idx] + if abs(k - key) <= self.max_err: + all_candidates += potential_candidates[k] + else: + break + + for idx in range(key_idx, len(key_list)): + k = key_list[idx] + if abs(k - key) <= self.max_err: + all_candidates += potential_candidates[k] + else: + break + return all_candidates + + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + if self.base == {}: + return [], 0, 0 - min_err = 999999 + if not multiple_predictions: + if len(pred_res) not in self.len_list: + return [], 0, 0 + all_candidates = self._regression_find_candidate_GKB(pred_res, key) + if len(all_candidates) == 0: + return [], 0, 0 + else: + cost_list = hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates, min_address_num, address_num + + else: + min_address_num = 0 + all_candidates_save = [] + cost_list_save = [] + + for p_res, k in zip(pred_res, key): + if len(p_res) not in self.len_list: + return [], 0, 0 + all_candidates = self._regression_find_candidate_GKB(p_res, k) + if len(all_candidates) == 0: + return [], 0, 0 + else: + all_candidates_save.append(all_candidates) + cost_list = hamming_dist(p_res, all_candidates) + min_address_num += np.min(cost_list) + cost_list_save.append(cost_list) + + multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] + assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) + multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) + assert len(multiple_all_candidates) == len(multiple_cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(multiple_cost_list <= address_num)[0] + candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] + return candidates, min_address_num, address_num + + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] - for l in length: - X, Y = self.base[l] - - idx = bisect.bisect_left(Y, key) - begin = max(0, idx - 1) - end = min(idx + 2, len(X)) - - for idx in range(begin, end): - err = abs(Y[idx] - key) - if abs(err - min_err) < 1e-9: - candidates.extend(X[idx]) - elif err < min_err: - candidates = copy.deepcopy(X[idx]) - min_err = err - return candidates + abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) - def get_all_candidates(self): - return sum([sum(D[0], []) for D in self.base.values()], []) + if multiple_predictions: + save_pred_res = pred_res + pred_res = flatten(pred_res) + + for c in abduce_c: + candidate = pred_res.copy() + for i, idx in enumerate(address_idx): + candidate[idx] = c[i] + + if multiple_predictions: + candidate = reform_idx(candidate, save_pred_res) + + if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): + candidates.append(candidate) + return candidates + + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) def __len__(self): - return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) + + +class HWF_KB(RegKB): + def __init__( + self, GKB_flag=False, + pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], + len_list=[1, 3, 5, 7], + max_err=1e-3 + ): + super().__init__(GKB_flag, pseudo_label_list, len_list, max_err) + + def valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: + return False + if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: + return False + return True + + def logic_forward(self, formula): + if not self.valid_candidate(formula): + return np.inf + mapping = { + '1': '1', + '2': '2', + '3': '3', + '4': '4', + '5': '5', + '6': '6', + '7': '7', + '8': '8', + '9': '9', + '+': '+', + '-': '-', + 'times': '*', + 'div': '/', + } + formula = [mapping[f] for f in formula] + return round(eval(''.join(formula)), 2) import time if __name__ == "__main__": t1 = time.time() - kb = HWF_KB(True) + kb = add_KB(True) t2 = time.time() print(t2 - t1) - # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] - # Y = [2, 1, 1, 2, 2] - # kb = ClsKB(X, Y) - # print('len(kb):', len(kb)) - # res = kb.get_candidates(2, 5) - # print(res) - # res = kb.get_candidates(2, 3) - # print(res) - # res = kb.get_candidates(None) - # print(res) - # print() - - # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] - # Y = [2, 1, 1, 2, 1.5, 1.5] - # kb = RegKB(X, Y) - # print('len(kb):', len(kb)) - # res = kb.get_candidates(1.6) - # print(res) - # res = kb.get_candidates(1.6, length = 9) - # print(res) - # res = kb.get_candidates(None) - # print(res) +