diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 17a27ea..efdb241 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -2,25 +2,42 @@ import abc import numpy as np from multiprocessing import Pool from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, float_parameter +from ..utils.utils import ( + confidence_dist, + flatten, + reform_idx, + hamming_dist, + float_parameter, +) + class ReasonerBase(abc.ABC): - def __init__(self, kb, dist_func='hamming', zoopt=False): + def __init__(self, kb, dist_func="hamming", mapping=None, zoopt=False): + if not (dist_func == "hamming" or dist_func == "confidence"): + raise NotImplementedError + self.kb = kb - assert dist_func == 'hamming' or dist_func == 'confidence' self.dist_func = dist_func self.zoopt = zoopt - if dist_func == 'confidence': - self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) + if mapping is None: + self.mapping = dict( + zip( + list(range(len(self.kb.pseudo_label_list))), + self.kb.pseudo_label_list, + ) + ) + else: + self.mapping = mapping + self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) - def _get_cost_list(self, pred_res, pred_res_prob, candidates): + def _get_cost_list(self, pseudo_label, pred_res_prob, candidates): """ Get the cost list of candidates based on the distance function. Parameters ---------- - pred_res : list - The predicted result. + pseudo_label : list + List of predicted pseudo labels. pred_res_prob : list The predicted result probability. candidates : list @@ -31,21 +48,21 @@ class ReasonerBase(abc.ABC): list The cost list of candidates. """ - if self.dist_func == 'hamming': - return hamming_dist(pred_res, candidates) - - elif self.dist_func == 'confidence': - candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates] + if self.dist_func == "hamming": + return hamming_dist(pseudo_label, candidates) + + elif self.dist_func == "confidence": + candidates = [list(map(lambda x: self.remapping[x], c)) for c in candidates] return confidence_dist(pred_res_prob, candidates) - def _get_one_candidate(self, pred_res, pred_res_prob, candidates): + def _get_one_candidate(self, pseudo_label, pred_res_prob, candidates): """ Get the best candidate based on the distance function. Parameters ---------- - pred_res : list - The predicted result. + pseudo_label : list + List of predicted pseudo labels. pred_res_prob : list The predicted result probability. candidates : list @@ -58,23 +75,14 @@ class ReasonerBase(abc.ABC): """ if len(candidates) == 0: return [] - elif len(candidates) == 1 or self.zoopt: + elif len(candidates) == 1: return candidates[0] - else: - cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) + cost_list = self._get_cost_list(pseudo_label, pred_res_prob, candidates) candidate = candidates[np.argmin(cost_list)] return candidate - - def _zoopt_revision_score_single(self, sol_x, pred_res, pred_res_prob, y): - revision_idx = np.where(sol_x != 0)[0] - candidates = self.revise_by_idx(pred_res, y, revision_idx) - if len(candidates) > 0: - return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) - else: - return len(pred_res) - - def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): + + def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, y, sol): """ Get the revision score for a single solution. @@ -84,6 +92,8 @@ class ReasonerBase(abc.ABC): Solution to evaluate. pred_res : list List of predicted results. + pseudo_label : list + List of predicted pseudo labels. pred_res_prob : list List of probabilities for predicted results. y : str @@ -95,23 +105,27 @@ class ReasonerBase(abc.ABC): The revision score for the given solution. """ revision_idx = np.where(sol.get_x() != 0)[0] - candidates = self.revise_by_idx(pred_res, y, revision_idx) + candidates = self.revise_by_idx(pseudo_label, y, revision_idx) if len(candidates) > 0: - return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) + return np.min(self._get_cost_list(pseudo_label, pred_res_prob, candidates)) else: return len(pred_res) - + def _constrain_revision_num(self, solution, max_revision_num): x = solution.get_x() return max_revision_num - x.sum() - def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_revision_num): + def zoopt_get_solution( + self, pred_res, pseudo_label, pred_res_prob, y, max_revision_num + ): """Get the optimal solution using the Zoopt library. Parameters ---------- pred_res : list List of predicted results. + pseudo_label : list + List of predicted pseudo labels. pred_res_prob : list List of probabilities for predicted results. y : str @@ -127,21 +141,23 @@ class ReasonerBase(abc.ABC): length = len(flatten(pred_res)) dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) objective = Objective( - lambda sol: self.zoopt_revision_score(pred_res, pred_res_prob, y, sol), + lambda sol: self.zoopt_revision_score( + pred_res, pseudo_label, pred_res_prob, y, sol + ), dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), ) parameter = Parameter(budget=100, intermediate_result=False, autoset=True) solution = Opt.min(objective, parameter).get_x() return solution - - def revise_by_idx(self, pred_res, y, revision_idx): + + def revise_by_idx(self, pseudo_label, y, revision_idx): """Get the revisions corresponding to the given indices. Parameters ---------- - pred_res : list - List of predicted results. + pseudo_label : list + List of predicted pseudo labels. y : str Ground truth for the predicted results. revision_idx : array-like @@ -152,7 +168,7 @@ class ReasonerBase(abc.ABC): list The revisions corresponding to the given indices. """ - return self.kb.revise_by_idx(pred_res, y, revision_idx) + return self.kb.revise_by_idx(pseudo_label, y, revision_idx) def abduce(self, data, max_revision=-1, require_more_revision=0): """Perform abduction on the given data. @@ -162,7 +178,7 @@ class ReasonerBase(abc.ABC): data : tuple Tuple containing the predicted results, predicted result probabilities, and y. max_revision : int or float, optional - Maximum number of revisions to use. If float, represents the fraction of total revisions to use. + Maximum number of revisions to use. If float, represents the fraction of total revisions to use. If -1, use all revisions. Defaults to -1. require_more_revision : int, optional Number of additional revisions to require. Defaults to 0. @@ -173,16 +189,22 @@ class ReasonerBase(abc.ABC): The abduced revisions. """ pred_res, pred_res_prob, y = data + pseudo_label = [self.mapping[_idx] for _idx in pred_res] + max_revision_num = float_parameter(max_revision, len(flatten(pred_res))) if self.zoopt: - solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num) + solution = self.zoopt_get_solution( + pred_res, pseudo_label, pred_res_prob, y, max_revision_num + ) revision_idx = np.where(solution != 0)[0] - candidates = self.revise_by_idx(pred_res, y, revision_idx) + candidates = self.revise_by_idx(pseudo_label, y, revision_idx) else: - candidates = self.kb.abduce_candidates(pred_res, y, max_revision_num, require_more_revision) + candidates = self.kb.abduce_candidates( + pred_res, y, max_revision_num, require_more_revision + ) - candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) + candidate = self._get_one_candidate(pseudo_label, pred_res_prob, candidates) return candidate def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): @@ -195,7 +217,7 @@ class ReasonerBase(abc.ABC): Y : list List of ground truths. max_revision : int or float, optional - Maximum number of revisions to use. If float, represents the fraction of total revisions to use. + Maximum number of revisions to use. If float, represents the fraction of total revisions to use. If -1, use all revisions. Defaults to -1. require_more_revision : int, optional Number of additional revisions to require. Defaults to 0. @@ -205,8 +227,11 @@ class ReasonerBase(abc.ABC): list The abduced revisions. """ - return [self.abduce((z, prob, y), max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] - + return [ + self.abduce((z, prob, y), max_revision, require_more_revision) + for z, prob, y in zip(Z["label"], Z["prob"], Y) + ] + # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args # return self.abduce((z, prob, y), max_revision, require_more_revision) @@ -215,120 +240,224 @@ class ReasonerBase(abc.ABC): # with Pool(processes=os.cpu_count()) as pool: # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) # return results - + def __call__(self, Z, Y, max_revision=-1, require_more_revision=0): return self.batch_abduce(Z, Y, max_revision, require_more_revision) - -if __name__ == '__main__': + +if __name__ == "__main__": from kb import KBBase, prolog_KB - - prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] - prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + + prob1 = [ + [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + ] + prob2 = [ + [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + ] class add_KB(KBBase): - def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True): + def __init__( + self, + pseudo_label_list=list(range(10)), + len_list=[2], + GKB_flag=False, + max_err=0, + use_cache=True, + ): super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) def logic_forward(self, nums): return sum(nums) - - - print('add_KB with GKB:') + + print("add_KB with GKB:") kb = add_KB(GKB_flag=True) - reasoner = ReasonerBase(kb, 'confidence') - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) + reasoner = ReasonerBase(kb, "confidence") + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 + ) print(res) print() - - print('add_KB without GKB:') + + print("add_KB without GKB:") kb = add_KB() - reasoner = ReasonerBase(kb, 'confidence') - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) + reasoner = ReasonerBase(kb, "confidence") + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 + ) print(res) print() - - print('add_KB without GKB:, no cache') + + print("add_KB without GKB:, no cache") kb = add_KB(use_cache=False) - reasoner = ReasonerBase(kb, 'confidence') - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) + reasoner = ReasonerBase(kb, "confidence") + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 + ) print(res) print() - - print('prolog_KB with add.pl:') - kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl') - reasoner = ReasonerBase(kb, 'confidence') - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) + + print("prolog_KB with add.pl:") + kb = prolog_KB( + pseudo_label_list=list(range(10)), + pl_file="../examples/mnist_add/datasets/add.pl", + ) + reasoner = ReasonerBase(kb, "confidence") + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 + ) print(res) print() - print('prolog_KB with add.pl using zoopt:') - kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl') - reasoner = ReasonerBase(kb, 'confidence', zoopt=True) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) + print("prolog_KB with add.pl using zoopt:") + kb = prolog_KB( + pseudo_label_list=list(range(10)), + pl_file="../examples/mnist_add/datasets/add.pl", + ) + reasoner = ReasonerBase(kb, "confidence", zoopt=True) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 + ) print(res) print() - - print('add_KB with multiple inputs at once:') - multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], - [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] - + + print("add_KB with multiple inputs at once:") + multiple_prob = [ + [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ], + [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ], + ] + kb = add_KB() - reasoner = ReasonerBase(kb, 'confidence') - res = reasoner.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=1) + reasoner = ReasonerBase(kb, "confidence") + res = reasoner.batch_abduce( + {"cls": [[1, 1], [1, 2]], "prob": multiple_prob}, + [4, 8], + max_revision=2, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [[1, 1], [1, 2]], "prob": multiple_prob}, + [4, 8], + max_revision=2, + require_more_revision=1, + ) print(res) print() - + class HWF_KB(KBBase): def __init__( - self, - pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], + self, + pseudo_label_list=[ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "+", + "-", + "times", + "div", + ], len_list=[1, 3, 5, 7], GKB_flag=False, max_err=1e-3, - use_cache=True + use_cache=True, ): super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) @@ -336,9 +465,19 @@ if __name__ == '__main__': if len(formula) % 2 == 0: return False for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: + if i % 2 == 0 and formula[i] not in [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ]: return False - if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: return False return True @@ -346,91 +485,183 @@ if __name__ == '__main__': if not self._valid_candidate(formula): return np.inf mapping = {str(i): str(i) for i in range(1, 10)} - mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'}) + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) formula = [mapping[f] for f in formula] - return eval(''.join(formula)) - - print('HWF_KB with GKB, max_err=0.1') - kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) - reasoner = ReasonerBase(kb, 'hamming') - res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) + return eval("".join(formula)) + + print("HWF_KB with GKB, max_err=0.1") + kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err=0.1) + reasoner = ReasonerBase(kb, "hamming") + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"]], "prob": [None]}, + [3], + max_revision=2, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "9"]], "prob": [None]}, + [65], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, + [3.17], + max_revision=5, + require_more_revision=3, + ) print(res) print() - - print('HWF_KB without GKB, max_err=0.1') - kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) - reasoner = ReasonerBase(kb, 'hamming') - res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) + + print("HWF_KB without GKB, max_err=0.1") + kb = HWF_KB(len_list=[1, 3, 5], max_err=0.1) + reasoner = ReasonerBase(kb, "hamming") + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"]], "prob": [None]}, + [3], + max_revision=2, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "9"]], "prob": [None]}, + [65], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, + [3.17], + max_revision=5, + require_more_revision=3, + ) print(res) print() - - print('HWF_KB with GKB, max_err=1') - kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) - reasoner = ReasonerBase(kb, 'hamming') - res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) + + print("HWF_KB with GKB, max_err=1") + kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err=1) + reasoner = ReasonerBase(kb, "hamming") + res = reasoner.batch_abduce( + {"cls": [["5", "+", "9"]], "prob": [None]}, + [65], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"]], "prob": [None]}, + [1.67], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, + [3.17], + max_revision=5, + require_more_revision=3, + ) print(res) print() - - print('HWF_KB without GKB, max_err=1') - kb = HWF_KB(len_list=[1, 3, 5], max_err = 1) - reasoner = ReasonerBase(kb, 'hamming') - res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) + + print("HWF_KB without GKB, max_err=1") + kb = HWF_KB(len_list=[1, 3, 5], max_err=1) + reasoner = ReasonerBase(kb, "hamming") + res = reasoner.batch_abduce( + {"cls": [["5", "+", "9"]], "prob": [None]}, + [65], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"]], "prob": [None]}, + [1.67], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, + [3.17], + max_revision=5, + require_more_revision=3, + ) print(res) print() - - print('HWF_KB with multiple inputs at once:') - kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) - reasoner = ReasonerBase(kb, 'hamming') - res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=1, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=3, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_revision=3, require_more_revision=0) + + print("HWF_KB with multiple inputs at once:") + kb = HWF_KB(len_list=[1, 3, 5], max_err=0.1) + reasoner = ReasonerBase(kb, "hamming") + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, + [3, 64], + max_revision=1, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, + [3, 64], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, + [3, 65], + max_revision=3, + require_more_revision=0, + ) print(res) print() - print('max_revision is float') - res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.5, require_more_revision=0) - print(res) - res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.9, require_more_revision=0) + print("max_revision is float") + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, + [3, 64], + max_revision=0.5, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, + [3, 64], + max_revision=0.9, + require_more_revision=0, + ) print(res) print() - + class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list, pl_file) - + def consist_rule(self, exs, rules): - rules = str(rules).replace("\'","") - return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 + rules = str(rules).replace("'", "") + return ( + len( + list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules))) + ) + != 0 + ) def abduce_rules(self, pred_res): - prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) + prolog_result = list( + self.prolog.query("consistent_inst_feature(%s, X)." % pred_res) + ) if len(prolog_result) == 0: return None - prolog_rules = prolog_result[0]['X'] + prolog_rules = prolog_result[0]["X"] rules = [rule.value for rule in prolog_rules] return rules - + class HED_Reasoner(ReasonerBase): - def __init__(self, kb, dist_func='hamming'): + def __init__(self, kb, dist_func="hamming"): super().__init__(kb, dist_func, zoopt=True) - + def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs): pred = [] k = [] @@ -439,14 +670,14 @@ if __name__ == '__main__': pred.append(pred_res[idx]) k.append(y[idx]) revision_flag += list(all_revision_flag[idx]) - revision_idx = np.where(np.array(revision_flag) != 0)[0] + revision_idx = np.where(np.array(revision_flag) != 0)[0] candidate = self.revise_by_idx(pred, k, revision_idx) return candidate - - def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): + + def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): all_revision_flag = reform_idx(sol.get_x(), pred_res) lefted_idxs = [i for i in range(len(pred_res))] - candidate_size = [] + candidate_size = [] while lefted_idxs: idxs = [] idxs.append(lefted_idxs.pop(0)) @@ -455,21 +686,26 @@ if __name__ == '__main__': for idx in range(-1, len(pred_res)): if (not idx in idxs) and (idx >= 0): idxs.append(idx) - candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs) + candidate = self._revise_by_idxs( + pred_res, y, all_revision_flag, idxs + ) if len(candidate) == 0: if len(idxs) > 1: idxs.pop() else: if len(idxs) > len(max_candidate_idxs): found = True - max_candidate_idxs = idxs.copy() + max_candidate_idxs = idxs.copy() removed = [i for i in lefted_idxs if i in max_candidate_idxs] if found: candidate_size.append(len(removed) + 1) - lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] + lefted_idxs = [ + i for i in lefted_idxs if i not in max_candidate_idxs + ] candidate_size.sort() score = 0 import math + for i in range(0, len(candidate_size)): score -= math.exp(-i) * candidate_size[i] return score @@ -477,31 +713,49 @@ if __name__ == '__main__': def abduce_rules(self, pred_res): return self.kb.abduce_rules(pred_res) - kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/hed/datasets/learn_add.pl') + kb = HED_prolog_KB( + pseudo_label_list=[1, 0, "+", "="], + pl_file="../examples/hed/datasets/learn_add.pl", + ) reasoner = HED_Reasoner(kb) - consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] - inconsist_exs1 = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0], [0, '+', 0, '=', 1]] - inconsist_exs2 = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] - rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])'] - - print('HED_kb logic forward') + consist_exs = [ + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], + ] + inconsist_exs1 = [ + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], + [0, "+", 0, "=", 1], + ] + inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] + rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] + + print("HED_kb logic forward") print(kb.logic_forward(consist_exs)) print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) print() - print('HED_kb consist rule') - print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules)) - print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) + print("HED_kb consist rule") + print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) + print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) print() - print('HED_Reasoner abduce') - res = reasoner.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) + print("HED_Reasoner abduce") + res = reasoner.abduce( + (consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)) + ) print(res) - res = reasoner.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))) + res = reasoner.abduce( + (inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)) + ) print(res) - res = reasoner.abduce((inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2))) + res = reasoner.abduce( + (inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2)) + ) print(res) print() - print('HED_Reasoner abduce rules') + print("HED_Reasoner abduce rules") abduced_rules = reasoner.abduce_rules(consist_exs) - print(abduced_rules) \ No newline at end of file + print(abduced_rules)