| @@ -31,6 +31,15 @@ class KBBase(ABC): | |||
| @abstractmethod | |||
| def logic_forward(self): | |||
| pass | |||
def _logic_forward(self, xs, multiple_predictions=False):
    """Apply ``logic_forward`` to one input or to each of several inputs.

    Args:
        xs: A single input for ``logic_forward`` or, when
            ``multiple_predictions`` is true, an iterable of such inputs.
        multiple_predictions: Whether ``xs`` holds several inputs.

    Returns:
        ``self.logic_forward(xs)`` for a single input, otherwise a list
        with one ``logic_forward`` result per element of ``xs``.
    """
    if multiple_predictions:
        return [self.logic_forward(x) for x in xs]
    return self.logic_forward(xs)
| @abstractmethod | |||
| def abduce_candidates(self): | |||
| @@ -40,7 +49,7 @@ class KBBase(ABC): | |||
| def address_by_idx(self): | |||
| pass | |||
| def _address(self, address_num, pred_res, key, multiple_predictions=False): | |||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||
| new_candidates = [] | |||
| if not multiple_predictions: | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| @@ -52,12 +61,12 @@ class KBBase(ABC): | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key): | |||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| @@ -88,16 +97,14 @@ class ClsKB(KBBase): | |||
| self.GKB_flag = GKB_flag | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.len_list = len_list | |||
| self.max_err = 0 | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| else: | |||
| self.all_address_candidate_dict = {} | |||
| for address_num in range(max(self.len_list) + 1): | |||
| self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| @@ -133,33 +140,57 @@ class ClsKB(KBBase): | |||
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
    """Abduce candidates with the GKB lookup or with search.

    Both back ends receive the same five arguments, including
    ``multiple_predictions``; the earlier four-argument GKB call no longer
    matches the ``_abduce_by_GKB`` signature and has been removed.

    Args:
        pred_res: Predicted pseudo-labels (a list of such predictions when
            ``multiple_predictions`` is true).
        key: Target value(s) the candidates have to reproduce.
        max_address_num: Forwarded unchanged to the back end.
        require_more_address: Forwarded unchanged to the back end.
        multiple_predictions: Whether ``pred_res``/``key`` hold several
            predictions.

    Returns:
        Whatever the selected back end returns.
    """
    if self.GKB_flag:
        return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
    return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
    """Abduce candidates by lookup in the pre-built ground knowledge base.

    Args:
        pred_res: Predicted pseudo-labels, or a list of such predictions
            when ``multiple_predictions`` is true.
        key: Label to reproduce, or one label per prediction.
        max_address_num: Upper bound applied to the address budget.
        require_more_address: Slack added to the minimum cost before the
            bound is applied.
        multiple_predictions: Whether ``pred_res``/``key`` hold several
            predictions that are abduced jointly.

    Returns:
        ``(candidates, min_address_num, address_num)``.  ``candidates`` are
        the stored inputs whose cost (as computed by ``hamming_dist``) does
        not exceed ``address_num``; ``min_address_num`` is the smallest cost
        found (summed over predictions in the joint case).  ``([], 0, 0)``
        is returned when the GKB is empty, a prediction length is not in
        ``self.len_list``, or no stored input carries the requested label.
    """
    if self.base == {}:
        return [], 0, 0

    if not multiple_predictions:
        if len(pred_res) not in self.len_list:
            return [], 0, 0
        # Read-only lookup: indexing the inner defaultdict would insert an
        # empty entry for every label that is not in the GKB.
        all_candidates = self.base.get(len(pred_res), {}).get(key, [])
        if len(all_candidates) == 0:
            return [], 0, 0
        cost_list = hamming_dist(pred_res, all_candidates)
        min_address_num = np.min(cost_list)
        address_num = min(max_address_num, min_address_num + require_more_address)
        idxs = np.where(cost_list <= address_num)[0]
        candidates = [all_candidates[idx] for idx in idxs]
        return candidates, min_address_num, address_num

    min_address_num = 0
    all_candidates_save = []
    cost_list_save = []
    for p_res, k in zip(pred_res, key):
        if len(p_res) not in self.len_list:
            return [], 0, 0
        all_candidates = self.base.get(len(p_res), {}).get(k, [])
        if len(all_candidates) == 0:
            return [], 0, 0
        all_candidates_save.append(all_candidates)
        cost_list = hamming_dist(p_res, all_candidates)
        min_address_num += np.min(cost_list)
        cost_list_save.append(cost_list)

    # Joint candidates are the cartesian product of the per-prediction
    # candidates; their cost is the sum of the per-prediction costs, taken
    # in the same product order so both sequences stay aligned.
    multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
    assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
    multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
    assert len(multiple_all_candidates) == len(multiple_cost_list)

    address_num = min(max_address_num, min_address_num + require_more_address)
    idxs = np.where(multiple_cost_list <= address_num)[0]
    candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
    return candidates, min_address_num, address_num
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| abduce_c = self.all_address_candidate_dict[len(address_idx)] | |||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| @@ -173,7 +204,7 @@ class ClsKB(KBBase): | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| if self.logic_forward(candidate) == key: | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -197,50 +228,13 @@ class add_KB(ClsKB): | |||
def logic_forward(self, nums):
    """Return the sum of the numbers in ``nums``."""
    total = sum(nums)
    return total
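| # TODO: This is a regression task (with respect to y). Adding round() to logic_forward to turn it into a discrete classification task does work, but it would be better to use RegKB here so that it serves as a demonstration example. The ClsKB below also needs to be modified (see TODO). | |||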
class HWF_KB(ClsKB):
    """Knowledge base for handwritten-formula (HWF) expressions.

    A valid formula is an odd-length sequence of symbols that alternates
    between the digits ``'1'``-``'9'`` and the operators ``'+'``, ``'-'``,
    ``'times'`` and ``'div'``, starting and ending with a digit.  Its label
    is the value of the expression rounded to two decimals.
    """

    # Symbols allowed at even (operand) and odd (operator) positions.
    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    _OPERATORS = ('+', '-', 'times', 'div')
    # Symbols whose Python spelling differs from the pseudo-label.
    _PYTHON_TOKEN = {'times': '*', 'div': '/'}

    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
        """Build the knowledge base.

        Args:
            GKB_flag: Passed through to ``ClsKB``.
            pseudo_label_list: Symbols a formula may contain; defaults to
                the nine digits followed by the four operators.
            len_list: Formula lengths to support; defaults to
                ``[1, 3, 5, 7]``.
        """
        # Fresh lists per instance: list literals used as defaults would be
        # shared by every instance created without explicit arguments.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(GKB_flag, pseudo_label_list, len_list)

    def valid_candidate(self, formula):
        """Return whether ``formula`` alternates digit, operator, ..., digit."""
        if len(formula) % 2 == 0:
            return False
        for position, symbol in enumerate(formula):
            allowed = self._DIGITS if position % 2 == 0 else self._OPERATORS
            if symbol not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula``.

        Returns:
            The value of the expression rounded to two decimals, or
            ``np.inf`` when ``formula`` is not a valid candidate.
        """
        if not self.valid_candidate(formula):
            return np.inf
        expression = ''.join(self._PYTHON_TOKEN.get(symbol, symbol) for symbol in formula)
        # NOTE: eval() only ever sees the whitelisted digits and operators
        # checked by valid_candidate above; keep that guard in front of it.
        # Division by zero cannot occur because '0' is not a valid digit.
        return round(eval(expression), 2)
| class prolog_KB(KBBase): | |||
| def __init__(self, pseudo_label_list): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.prolog = pyswip.Prolog() | |||
| self.max_err = 0 | |||
| def logic_forward(self): | |||
| pass | |||
| @@ -295,11 +289,11 @@ class add_prolog_KB(prolog_KB): | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list=[0, 1, '+', '=']): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog.consult('../examples/datasets/hed/learn_add.pl') | |||
| self.prolog.consult('./datasets/hed/learn_add.pl') | |||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||
| def logic_forward(self, exs): | |||
| return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 | |||
| return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 | |||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | |||
| # flatten | |||
| @@ -329,93 +323,204 @@ class HED_prolog_KB(prolog_KB): | |||
| rules.append(rule.value) | |||
| return rules | |||
| # def consist_rules(self, pred_res, rules): | |||
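| # TODO: This class needs to be revised. The original RegKB only handled the GKB case; it now has to support both the GKB and the non-GKB case, like ClsKB. The non-GKB part must be added (inheriting _abduce_by_search may be enough), and the logic of the GKB part, _abduce_by_GKB, must change (it used to return the abduction results of the y closest to key; it should now return the abduction results of every y within a given error range of key). | |||
| # TODO: My understanding of RegKB is as follows: | |||
| # TODO: 1. For the GKB case, i.e. _abduce_by_GKB: given key and length, plus a self.max_err, return all abduction results whose y differs from key by no more than max_err in absolute value. | |||
| # TODO: For example, if the y values in the GKB are [1.3, 1.49, 1.50, 1.52, 1.6], then with key=1.5 and max_err=1e-5 the abduction results of [y=1.50] are returned; with key=1.5 and max_err=0.05 the abduction results of all of [y=1.49, 1.50, 1.52] are returned. | |||
| # TODO: Therefore, after the bisect_left binary search, the GKB has to be scanned both backward and forward to find the y values within the error tolerance. | |||
| # TODO: A very small default for self.max_err is fine, e.g. for tasks like HWF; but some tasks (such as predicting court sentence lengths) need a larger max_err, so it should be user-configurable. | |||
| # TODO: 2. For the non-GKB case, using _abduce_by_search directly is probably enough, with check_equal restricted to numbers and bounding the regression error by max_err. | |||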
| class RegKB(KBBase): | |||
| def __init__(self, GKB_flag=False, X=None, Y=None): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): | |||
| super().__init__() | |||
| tmp_dict = {} | |||
| for x, y in zip(X, Y): | |||
| tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| self.base = {} | |||
| for l in tmp_dict.keys(): | |||
| data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) | |||
| X = [x for y, x in data] | |||
| Y = [y for y, x in data] | |||
| self.base[l] = (X, Y) | |||
| def valid_candidate(self): | |||
| pass | |||
| self.GKB_flag = GKB_flag | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.len_list = len_list | |||
| self.max_err = max_err | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| # For parallel version of _get_GKB | |||
def _get_XY_list(self, args):
    """Label every input that starts with one fixed leading symbol.

    Worker used by the parallel ``_get_GKB``.

    Args:
        args: Indexable whose item 0 is the leading symbol and whose
            item 1 is an iterable of tuples holding the remaining symbols.

    Returns:
        A list of ``(x, y)`` pairs, where ``x`` is the leading symbol
        prepended to one tail tuple and ``y`` is ``self.logic_forward(x)``.
        Pairs whose ``y`` equals ``np.inf`` are left out.
    """
    head = (args[0],)
    inputs = (head + tail for tail in args[1])
    labelled = ((x, self.logic_forward(x)) for x in inputs)
    return [pair for pair in labelled if pair[1] != np.inf]
| # Parallel _get_GKB | |||
def _get_GKB(self):
    """Enumerate the ground knowledge base with a process pool.

    For every length in ``self.len_list`` the enumeration is split by
    leading symbol: one task per entry of ``self.pseudo_label_list``, each
    handed to ``_get_XY_list`` through a pool with one process per task.

    Returns:
        A pair ``(X, Y)`` of parallel lists with the inputs and labels
        collected from all workers.
    """
    inputs, labels = [], []
    for length in self.len_list:
        tasks = [
            (first, product(self.pseudo_label_list, repeat=length - 1))
            for first in self.pseudo_label_list
        ]
        with Pool(processes=len(tasks)) as pool:
            chunks = pool.map(self._get_XY_list, tasks)
        for chunk in chunks:
            if len(chunk) == 0:
                continue
            xs, ys = zip(*chunk)
            inputs.extend(xs)
            labels.extend(ys)
    return inputs, labels
| def logic_forward(self): | |||
| pass | |||
| def _abduce_by_GKB(self, key, length=None): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
    """Abduce candidates with the GKB lookup or with search.

    ``self.GKB_flag`` selects ``_abduce_by_GKB``; otherwise
    ``_abduce_by_search`` is used.  All arguments are forwarded unchanged
    and the back end's result is returned as is.
    """
    backend = self._abduce_by_GKB if self.GKB_flag else self._abduce_by_search
    return backend(pred_res, key, max_address_num, require_more_address, multiple_predictions)
| length = self._length(length) | |||
def _regression_find_candidate_GKB(self, pred_res, key):
    """Collect GKB inputs whose label lies within ``self.max_err`` of ``key``.

    The labels stored for inputs of length ``len(pred_res)`` are sorted and
    ``key`` is located with a binary search; the neighbours on both sides
    are then scanned until the first label outside the tolerance.

    Args:
        pred_res: Prediction whose length selects the GKB table to search.
        key: Target label.

    Returns:
        The stored inputs of every label ``k`` with
        ``abs(k - key) <= self.max_err``: labels below ``key`` first, in
        descending order, followed by labels at or above ``key`` in
        ascending order.  Empty when nothing matches or when the GKB holds
        no inputs of that length.
    """
    potential_candidates = self.base.get(len(pred_res))
    if not potential_candidates:
        return []

    key_list = sorted(potential_candidates)
    key_idx = bisect.bisect_left(key_list, key)

    all_candidates = []
    # Labels below ``key``.  The stop value is -1 so that index 0 is
    # inspected too; stopping at 0 would silently skip the smallest label.
    for idx in range(key_idx - 1, -1, -1):
        k = key_list[idx]
        if abs(k - key) > self.max_err:
            break
        all_candidates += potential_candidates[k]

    # Labels at or above ``key``.
    for idx in range(key_idx, len(key_list)):
        k = key_list[idx]
        if abs(k - key) > self.max_err:
            break
        all_candidates += potential_candidates[k]
    return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
    """Abduce candidates from the GKB within the regression tolerance.

    Candidate inputs come from ``_regression_find_candidate_GKB``; their
    costs are computed with ``hamming_dist``.

    Args:
        pred_res: Predicted pseudo-labels, or a list of such predictions
            when ``multiple_predictions`` is true.
        key: Target label, or one target label per prediction.
        max_address_num: Upper bound applied to the address budget.
        require_more_address: Slack added to the minimum cost before the
            bound is applied.
        multiple_predictions: Whether several predictions are abduced
            jointly.

    Returns:
        ``(candidates, min_address_num, address_num)``; ``([], 0, 0)`` when
        the GKB is empty, a prediction length is not in ``self.len_list``,
        or no stored input matches a target label.
    """
    if self.base == {}:
        return [], 0, 0

    if not multiple_predictions:
        if len(pred_res) not in self.len_list:
            return [], 0, 0
        pool = self._regression_find_candidate_GKB(pred_res, key)
        if len(pool) == 0:
            return [], 0, 0
        costs = hamming_dist(pred_res, pool)
        fewest = np.min(costs)
        budget = min(max_address_num, fewest + require_more_address)
        keep = np.where(costs <= budget)[0]
        return [pool[i] for i in keep], fewest, budget

    fewest = 0
    pools = []
    cost_lists = []
    for single_pred, single_key in zip(pred_res, key):
        if len(single_pred) not in self.len_list:
            return [], 0, 0
        pool = self._regression_find_candidate_GKB(single_pred, single_key)
        if len(pool) == 0:
            return [], 0, 0
        pools.append(pool)
        costs = hamming_dist(single_pred, pool)
        fewest += np.min(costs)
        cost_lists.append(costs)

    # Joint candidates and joint costs are built from the same product
    # order, so position i of one belongs to position i of the other.
    joint_candidates = [flatten(combo) for combo in product(*pools)]
    assert len(joint_candidates[0]) == len(flatten(pred_res))
    joint_costs = np.array([sum(parts) for parts in product(*cost_lists)])
    assert len(joint_candidates) == len(joint_costs)

    budget = min(max_address_num, fewest + require_more_address)
    keep = np.where(joint_costs <= budget)[0]
    return [reform_idx(joint_candidates[i], pred_res) for i in keep], fewest, budget
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| for l in length: | |||
| X, Y = self.base[l] | |||
| idx = bisect.bisect_left(Y, key) | |||
| begin = max(0, idx - 1) | |||
| end = min(idx + 2, len(X)) | |||
| for idx in range(begin, end): | |||
| err = abs(Y[idx] - key) | |||
| if abs(err - min_err) < 1e-9: | |||
| candidates.extend(X[idx]) | |||
| elif err < min_err: | |||
| candidates = copy.deepcopy(X[idx]) | |||
| min_err = err | |||
| return candidates | |||
| abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) | |||
| def get_all_candidates(self): | |||
| return sum([sum(D[0], []) for D in self.base.values()], []) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||
| candidates.append(candidate) | |||
| return candidates | |||
def _dict_len(self, dic):
    """Return the total number of items held in the values of ``dic``.

    A knowledge base built without a GKB (``self.GKB_flag`` false) always
    reports 0.
    """
    if self.GKB_flag:
        return sum(map(len, dic.values()))
    return 0
def __len__(self):
    """Return the number of inputs stored in the ground knowledge base.

    Sums ``_dict_len`` over every per-length table in ``self.base``.  A
    knowledge base built without a GKB reports 0 and never touches
    ``self.base``.

    The earlier unconditional ``return`` indexed each table as ``D[0]``,
    which belongs to the old ``(X, Y)`` tuple layout and made the
    ``GKB_flag`` handling unreachable; it has been removed.
    """
    if not self.GKB_flag:
        return 0
    return sum(self._dict_len(table) for table in self.base.values())
class HWF_KB(RegKB):
    """Regression knowledge base for handwritten-formula (HWF) expressions.

    A valid formula is an odd-length sequence of symbols that alternates
    between the digits ``'1'``-``'9'`` and the operators ``'+'``, ``'-'``,
    ``'times'`` and ``'div'``, starting and ending with a digit.  Its label
    is the value of the expression rounded to two decimals.
    """

    # Symbols allowed at even (operand) and odd (operator) positions.
    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    _OPERATORS = ('+', '-', 'times', 'div')
    # Symbols whose Python spelling differs from the pseudo-label.
    _PYTHON_TOKEN = {'times': '*', 'div': '/'}

    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3):
        """Build the knowledge base.

        Args:
            GKB_flag: Passed through to ``RegKB``.
            pseudo_label_list: Symbols a formula may contain; defaults to
                the nine digits followed by the four operators.
            len_list: Formula lengths to support; defaults to
                ``[1, 3, 5, 7]``.
            max_err: Passed through to ``RegKB``.
        """
        # Fresh lists per instance: list literals used as defaults would be
        # shared by every instance created without explicit arguments.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(GKB_flag, pseudo_label_list, len_list, max_err)

    def valid_candidate(self, formula):
        """Return whether ``formula`` alternates digit, operator, ..., digit."""
        if len(formula) % 2 == 0:
            return False
        for position, symbol in enumerate(formula):
            allowed = self._DIGITS if position % 2 == 0 else self._OPERATORS
            if symbol not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula``.

        Returns:
            The value of the expression rounded to two decimals, or
            ``np.inf`` when ``formula`` is not a valid candidate.
        """
        if not self.valid_candidate(formula):
            return np.inf
        expression = ''.join(self._PYTHON_TOKEN.get(symbol, symbol) for symbol in formula)
        # NOTE: eval() only ever sees the whitelisted digits and operators
        # checked by valid_candidate above; keep that guard in front of it.
        # Division by zero cannot occur because '0' is not a valid digit.
        return round(eval(expression), 2)
| import time | |||
| if __name__ == "__main__": | |||
| t1 = time.time() | |||
| kb = HWF_KB(True) | |||
| kb = add_KB(True) | |||
| t2 = time.time() | |||
| print(t2 - t1) | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
| # Y = [2, 1, 1, 2, 2] | |||
| # kb = ClsKB(X, Y) | |||
| # print('len(kb):', len(kb)) | |||
| # res = kb.get_candidates(2, 5) | |||
| # print(res) | |||
| # res = kb.get_candidates(2, 3) | |||
| # print(res) | |||
| # res = kb.get_candidates(None) | |||
| # print(res) | |||
| # print() | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
| # Y = [2, 1, 1, 2, 1.5, 1.5] | |||
| # kb = RegKB(X, Y) | |||
| # print('len(kb):', len(kb)) | |||
| # res = kb.get_candidates(1.6) | |||
| # print(res) | |||
| # res = kb.get_candidates(1.6, length = 9) | |||
| # print(res) | |||
| # res = kb.get_candidates(None) | |||
| # print(res) | |||