| @@ -2,25 +2,42 @@ import abc | |||
| import numpy as np | |||
| from multiprocessing import Pool | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, float_parameter | |||
| from ..utils.utils import ( | |||
| confidence_dist, | |||
| flatten, | |||
| reform_idx, | |||
| hamming_dist, | |||
| float_parameter, | |||
| ) | |||
| class ReasonerBase(abc.ABC): | |||
| def __init__(self, kb, dist_func='hamming', zoopt=False): | |||
| def __init__(self, kb, dist_func="hamming", mapping=None, zoopt=False): | |||
| if not (dist_func == "hamming" or dist_func == "confidence"): | |||
| raise NotImplementedError | |||
| self.kb = kb | |||
| assert dist_func == 'hamming' or dist_func == 'confidence' | |||
| self.dist_func = dist_func | |||
| self.zoopt = zoopt | |||
| if dist_func == 'confidence': | |||
| self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) | |||
| if mapping is None: | |||
| self.mapping = dict( | |||
| zip( | |||
| list(range(len(self.kb.pseudo_label_list))), | |||
| self.kb.pseudo_label_list, | |||
| ) | |||
| ) | |||
| else: | |||
| self.mapping = mapping | |||
| self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
| def _get_cost_list(self, pred_res, pred_res_prob, candidates): | |||
| def _get_cost_list(self, pseudo_label, pred_res_prob, candidates): | |||
| """ | |||
| Get the cost list of candidates based on the distance function. | |||
| Parameters | |||
| ---------- | |||
| pred_res : list | |||
| The predicted result. | |||
| pseudo_label : list | |||
| List of predicted pseudo labels. | |||
| pred_res_prob : list | |||
| The predicted result probability. | |||
| candidates : list | |||
| @@ -31,21 +48,21 @@ class ReasonerBase(abc.ABC): | |||
| list | |||
| The cost list of candidates. | |||
| """ | |||
| if self.dist_func == 'hamming': | |||
| return hamming_dist(pred_res, candidates) | |||
| elif self.dist_func == 'confidence': | |||
| candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates] | |||
| if self.dist_func == "hamming": | |||
| return hamming_dist(pseudo_label, candidates) | |||
| elif self.dist_func == "confidence": | |||
| candidates = [list(map(lambda x: self.remapping[x], c)) for c in candidates] | |||
| return confidence_dist(pred_res_prob, candidates) | |||
| def _get_one_candidate(self, pred_res, pred_res_prob, candidates): | |||
| def _get_one_candidate(self, pseudo_label, pred_res_prob, candidates): | |||
| """ | |||
| Get the best candidate based on the distance function. | |||
| Parameters | |||
| ---------- | |||
| pred_res : list | |||
| The predicted result. | |||
| pseudo_label : list | |||
| List of predicted pseudo labels. | |||
| pred_res_prob : list | |||
| The predicted result probability. | |||
| candidates : list | |||
| @@ -58,23 +75,14 @@ class ReasonerBase(abc.ABC): | |||
| """ | |||
| if len(candidates) == 0: | |||
| return [] | |||
| elif len(candidates) == 1 or self.zoopt: | |||
| elif len(candidates) == 1: | |||
| return candidates[0] | |||
| else: | |||
| cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) | |||
| cost_list = self._get_cost_list(pseudo_label, pred_res_prob, candidates) | |||
| candidate = candidates[np.argmin(cost_list)] | |||
| return candidate | |||
| def _zoopt_revision_score_single(self, sol_x, pred_res, pred_res_prob, y): | |||
| revision_idx = np.where(sol_x != 0)[0] | |||
| candidates = self.revise_by_idx(pred_res, y, revision_idx) | |||
| if len(candidates) > 0: | |||
| return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) | |||
| else: | |||
| return len(pred_res) | |||
| def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): | |||
| def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, y, sol): | |||
| """ | |||
| Get the revision score for a single solution. | |||
| @@ -84,6 +92,8 @@ class ReasonerBase(abc.ABC): | |||
| Solution to evaluate. | |||
| pred_res : list | |||
| List of predicted results. | |||
| pseudo_label : list | |||
| List of predicted pseudo labels. | |||
| pred_res_prob : list | |||
| List of probabilities for predicted results. | |||
| y : str | |||
| @@ -95,23 +105,27 @@ class ReasonerBase(abc.ABC): | |||
| The revision score for the given solution. | |||
| """ | |||
| revision_idx = np.where(sol.get_x() != 0)[0] | |||
| candidates = self.revise_by_idx(pred_res, y, revision_idx) | |||
| candidates = self.revise_by_idx(pseudo_label, y, revision_idx) | |||
| if len(candidates) > 0: | |||
| return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) | |||
| return np.min(self._get_cost_list(pseudo_label, pred_res_prob, candidates)) | |||
| else: | |||
| return len(pred_res) | |||
| def _constrain_revision_num(self, solution, max_revision_num): | |||
| x = solution.get_x() | |||
| return max_revision_num - x.sum() | |||
| def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_revision_num): | |||
| def zoopt_get_solution( | |||
| self, pred_res, pseudo_label, pred_res_prob, y, max_revision_num | |||
| ): | |||
| """Get the optimal solution using the Zoopt library. | |||
| Parameters | |||
| ---------- | |||
| pred_res : list | |||
| List of predicted results. | |||
| pseudo_label : list | |||
| List of predicted pseudo labels. | |||
| pred_res_prob : list | |||
| List of probabilities for predicted results. | |||
| y : str | |||
| @@ -127,21 +141,23 @@ class ReasonerBase(abc.ABC): | |||
| length = len(flatten(pred_res)) | |||
| dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) | |||
| objective = Objective( | |||
| lambda sol: self.zoopt_revision_score(pred_res, pred_res_prob, y, sol), | |||
| lambda sol: self.zoopt_revision_score( | |||
| pred_res, pseudo_label, pred_res_prob, y, sol | |||
| ), | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
| ) | |||
| parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
| solution = Opt.min(objective, parameter).get_x() | |||
| return solution | |||
| def revise_by_idx(self, pred_res, y, revision_idx): | |||
| def revise_by_idx(self, pseudo_label, y, revision_idx): | |||
| """Get the revisions corresponding to the given indices. | |||
| Parameters | |||
| ---------- | |||
| pred_res : list | |||
| List of predicted results. | |||
| pseudo_label : list | |||
| List of predicted pseudo labels. | |||
| y : str | |||
| Ground truth for the predicted results. | |||
| revision_idx : array-like | |||
| @@ -152,7 +168,7 @@ class ReasonerBase(abc.ABC): | |||
| list | |||
| The revisions corresponding to the given indices. | |||
| """ | |||
| return self.kb.revise_by_idx(pred_res, y, revision_idx) | |||
| return self.kb.revise_by_idx(pseudo_label, y, revision_idx) | |||
| def abduce(self, data, max_revision=-1, require_more_revision=0): | |||
| """Perform abduction on the given data. | |||
| @@ -162,7 +178,7 @@ class ReasonerBase(abc.ABC): | |||
| data : tuple | |||
| Tuple containing the predicted results, predicted result probabilities, and y. | |||
| max_revision : int or float, optional | |||
| Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
| Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
| If -1, use all revisions. Defaults to -1. | |||
| require_more_revision : int, optional | |||
| Number of additional revisions to require. Defaults to 0. | |||
| @@ -173,16 +189,22 @@ class ReasonerBase(abc.ABC): | |||
| The abduced revisions. | |||
| """ | |||
| pred_res, pred_res_prob, y = data | |||
| pseudo_label = [self.mapping[_idx] for _idx in pred_res] | |||
| max_revision_num = float_parameter(max_revision, len(flatten(pred_res))) | |||
| if self.zoopt: | |||
| solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num) | |||
| solution = self.zoopt_get_solution( | |||
| pred_res, pseudo_label, pred_res_prob, y, max_revision_num | |||
| ) | |||
| revision_idx = np.where(solution != 0)[0] | |||
| candidates = self.revise_by_idx(pred_res, y, revision_idx) | |||
| candidates = self.revise_by_idx(pseudo_label, y, revision_idx) | |||
| else: | |||
| candidates = self.kb.abduce_candidates(pred_res, y, max_revision_num, require_more_revision) | |||
| candidates = self.kb.abduce_candidates( | |||
| pred_res, y, max_revision_num, require_more_revision | |||
| ) | |||
| candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) | |||
| candidate = self._get_one_candidate(pseudo_label, pred_res_prob, candidates) | |||
| return candidate | |||
| def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): | |||
| @@ -195,7 +217,7 @@ class ReasonerBase(abc.ABC): | |||
| Y : list | |||
| List of ground truths. | |||
| max_revision : int or float, optional | |||
| Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
| Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
| If -1, use all revisions. Defaults to -1. | |||
| require_more_revision : int, optional | |||
| Number of additional revisions to require. Defaults to 0. | |||
| @@ -205,8 +227,11 @@ class ReasonerBase(abc.ABC): | |||
| list | |||
| The abduced revisions. | |||
| """ | |||
| return [self.abduce((z, prob, y), max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] | |||
| return [ | |||
| self.abduce((z, prob, y), max_revision, require_more_revision) | |||
| for z, prob, y in zip(Z["label"], Z["prob"], Y) | |||
| ] | |||
| # def _batch_abduce_helper(self, args): | |||
| # z, prob, y, max_revision, require_more_revision = args | |||
| # return self.abduce((z, prob, y), max_revision, require_more_revision) | |||
| @@ -215,120 +240,224 @@ class ReasonerBase(abc.ABC): | |||
| # with Pool(processes=os.cpu_count()) as pool: | |||
| # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) | |||
| # return results | |||
| def __call__(self, Z, Y, max_revision=-1, require_more_revision=0): | |||
| return self.batch_abduce(Z, Y, max_revision, require_more_revision) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| from kb import KBBase, prolog_KB | |||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| prob1 = [ | |||
| [ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| ] | |||
| prob2 = [ | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| ] | |||
| class add_KB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=list(range(10)), | |||
| len_list=[2], | |||
| GKB_flag=False, | |||
| max_err=0, | |||
| use_cache=True, | |||
| ): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| print('add_KB with GKB:') | |||
| print("add_KB with GKB:") | |||
| kb = add_KB(GKB_flag=True) | |||
| reasoner = ReasonerBase(kb, 'confidence') | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('add_KB without GKB:') | |||
| print("add_KB without GKB:") | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, 'confidence') | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('add_KB without GKB:, no cache') | |||
| print("add_KB without GKB:, no cache") | |||
| kb = add_KB(use_cache=False) | |||
| reasoner = ReasonerBase(kb, 'confidence') | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('prolog_KB with add.pl:') | |||
| kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl') | |||
| reasoner = ReasonerBase(kb, 'confidence') | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) | |||
| print("prolog_KB with add.pl:") | |||
| kb = prolog_KB( | |||
| pseudo_label_list=list(range(10)), | |||
| pl_file="../examples/mnist_add/datasets/add.pl", | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('prolog_KB with add.pl using zoopt:') | |||
| kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl') | |||
| reasoner = ReasonerBase(kb, 'confidence', zoopt=True) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0) | |||
| print("prolog_KB with add.pl using zoopt:") | |||
| kb = prolog_KB( | |||
| pseudo_label_list=list(range(10)), | |||
| pl_file="../examples/mnist_add/datasets/add.pl", | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence", zoopt=True) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('add_KB with multiple inputs at once:') | |||
| multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], | |||
| [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| print("add_KB with multiple inputs at once:") | |||
| multiple_prob = [ | |||
| [ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| ] | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, 'confidence') | |||
| res = reasoner.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=1) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1], [1, 2]], "prob": multiple_prob}, | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [[1, 1], [1, 2]], "prob": multiple_prob}, | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=1, | |||
| ) | |||
| print(res) | |||
| print() | |||
| class HWF_KB(KBBase): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], | |||
| self, | |||
| pseudo_label_list=[ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| "+", | |||
| "-", | |||
| "times", | |||
| "div", | |||
| ], | |||
| len_list=[1, 3, 5, 7], | |||
| GKB_flag=False, | |||
| max_err=1e-3, | |||
| use_cache=True | |||
| use_cache=True, | |||
| ): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) | |||
| @@ -336,9 +465,19 @@ if __name__ == '__main__': | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||
| if i % 2 == 0 and formula[i] not in [ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| ]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
| return False | |||
| return True | |||
| @@ -346,91 +485,183 @@ if __name__ == '__main__': | |||
| if not self._valid_candidate(formula): | |||
| return np.inf | |||
| mapping = {str(i): str(i) for i in range(1, 10)} | |||
| mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'}) | |||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
| formula = [mapping[f] for f in formula] | |||
| return eval(''.join(formula)) | |||
| print('HWF_KB with GKB, max_err=0.1') | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) | |||
| reasoner = ReasonerBase(kb, 'hamming') | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) | |||
| return eval("".join(formula)) | |||
| print("HWF_KB with GKB, max_err=0.1") | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"]], "prob": [None]}, | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "9"]], "prob": [None]}, | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('HWF_KB without GKB, max_err=0.1') | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) | |||
| reasoner = ReasonerBase(kb, 'hamming') | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) | |||
| print("HWF_KB without GKB, max_err=0.1") | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"]], "prob": [None]}, | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "9"]], "prob": [None]}, | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('HWF_KB with GKB, max_err=1') | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) | |||
| reasoner = ReasonerBase(kb, 'hamming') | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) | |||
| print("HWF_KB with GKB, max_err=1") | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "9"]], "prob": [None]}, | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"]], "prob": [None]}, | |||
| [1.67], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('HWF_KB without GKB, max_err=1') | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err = 1) | |||
| reasoner = ReasonerBase(kb, 'hamming') | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3) | |||
| print("HWF_KB without GKB, max_err=1") | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "9"]], "prob": [None]}, | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"]], "prob": [None]}, | |||
| [1.67], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "8", "8", "8", "8"]], "prob": [None]}, | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('HWF_KB with multiple inputs at once:') | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) | |||
| reasoner = ReasonerBase(kb, 'hamming') | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=1, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_revision=3, require_more_revision=0) | |||
| print("HWF_KB with multiple inputs at once:") | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, | |||
| [3, 64], | |||
| max_revision=1, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, | |||
| [3, 64], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, | |||
| [3, 65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('max_revision is float') | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.5, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.9, require_more_revision=0) | |||
| print("max_revision is float") | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, | |||
| [3, 64], | |||
| max_revision=0.5, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| {"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]}, | |||
| [3, 64], | |||
| max_revision=0.9, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| def consist_rule(self, exs, rules): | |||
| rules = str(rules).replace("\'","") | |||
| return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||
| rules = str(rules).replace("'", "") | |||
| return ( | |||
| len( | |||
| list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules))) | |||
| ) | |||
| != 0 | |||
| ) | |||
| def abduce_rules(self, pred_res): | |||
| prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||
| prolog_result = list( | |||
| self.prolog.query("consistent_inst_feature(%s, X)." % pred_res) | |||
| ) | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| prolog_rules = prolog_result[0]['X'] | |||
| prolog_rules = prolog_result[0]["X"] | |||
| rules = [rule.value for rule in prolog_rules] | |||
| return rules | |||
| class HED_Reasoner(ReasonerBase): | |||
| def __init__(self, kb, dist_func='hamming'): | |||
| def __init__(self, kb, dist_func="hamming"): | |||
| super().__init__(kb, dist_func, zoopt=True) | |||
| def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs): | |||
| pred = [] | |||
| k = [] | |||
| @@ -439,14 +670,14 @@ if __name__ == '__main__': | |||
| pred.append(pred_res[idx]) | |||
| k.append(y[idx]) | |||
| revision_flag += list(all_revision_flag[idx]) | |||
| revision_idx = np.where(np.array(revision_flag) != 0)[0] | |||
| revision_idx = np.where(np.array(revision_flag) != 0)[0] | |||
| candidate = self.revise_by_idx(pred, k, revision_idx) | |||
| return candidate | |||
| def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): | |||
| def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): | |||
| all_revision_flag = reform_idx(sol.get_x(), pred_res) | |||
| lefted_idxs = [i for i in range(len(pred_res))] | |||
| candidate_size = [] | |||
| candidate_size = [] | |||
| while lefted_idxs: | |||
| idxs = [] | |||
| idxs.append(lefted_idxs.pop(0)) | |||
| @@ -455,21 +686,26 @@ if __name__ == '__main__': | |||
| for idx in range(-1, len(pred_res)): | |||
| if (not idx in idxs) and (idx >= 0): | |||
| idxs.append(idx) | |||
| candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs) | |||
| candidate = self._revise_by_idxs( | |||
| pred_res, y, all_revision_flag, idxs | |||
| ) | |||
| if len(candidate) == 0: | |||
| if len(idxs) > 1: | |||
| idxs.pop() | |||
| else: | |||
| if len(idxs) > len(max_candidate_idxs): | |||
| found = True | |||
| max_candidate_idxs = idxs.copy() | |||
| max_candidate_idxs = idxs.copy() | |||
| removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
| if found: | |||
| candidate_size.append(len(removed) + 1) | |||
| lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||
| lefted_idxs = [ | |||
| i for i in lefted_idxs if i not in max_candidate_idxs | |||
| ] | |||
| candidate_size.sort() | |||
| score = 0 | |||
| import math | |||
| for i in range(0, len(candidate_size)): | |||
| score -= math.exp(-i) * candidate_size[i] | |||
| return score | |||
| @@ -477,31 +713,49 @@ if __name__ == '__main__': | |||
| def abduce_rules(self, pred_res): | |||
| return self.kb.abduce_rules(pred_res) | |||
| kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/hed/datasets/learn_add.pl') | |||
| kb = HED_prolog_KB( | |||
| pseudo_label_list=[1, 0, "+", "="], | |||
| pl_file="../examples/hed/datasets/learn_add.pl", | |||
| ) | |||
| reasoner = HED_Reasoner(kb) | |||
| consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] | |||
| inconsist_exs1 = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0], [0, '+', 0, '=', 1]] | |||
| inconsist_exs2 = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] | |||
| rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])'] | |||
| print('HED_kb logic forward') | |||
| consist_exs = [ | |||
| [1, 1, "+", 0, "=", 1, 1], | |||
| [1, "+", 1, "=", 1, 0], | |||
| [0, "+", 0, "=", 0], | |||
| ] | |||
| inconsist_exs1 = [ | |||
| [1, 1, "+", 0, "=", 1, 1], | |||
| [1, "+", 1, "=", 1, 0], | |||
| [0, "+", 0, "=", 0], | |||
| [0, "+", 0, "=", 1], | |||
| ] | |||
| inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
| rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||
| print("HED_kb logic forward") | |||
| print(kb.logic_forward(consist_exs)) | |||
| print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||
| print() | |||
| print('HED_kb consist rule') | |||
| print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules)) | |||
| print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) | |||
| print("HED_kb consist rule") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||
| print() | |||
| print('HED_Reasoner abduce') | |||
| res = reasoner.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) | |||
| print("HED_Reasoner abduce") | |||
| res = reasoner.abduce( | |||
| (consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))) | |||
| res = reasoner.abduce( | |||
| (inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce((inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2))) | |||
| res = reasoner.abduce( | |||
| (inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2)) | |||
| ) | |||
| print(res) | |||
| print() | |||
| print('HED_Reasoner abduce rules') | |||
| print("HED_Reasoner abduce rules") | |||
| abduced_rules = reasoner.abduce_rules(consist_exs) | |||
| print(abduced_rules) | |||
| print(abduced_rules) | |||