| @@ -1,5 +1,4 @@ | |||
| import numpy as np | |||
| from multiprocessing import Pool | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from ..utils.utils import ( | |||
| confidence_dist, | |||
| @@ -10,7 +9,7 @@ from ..utils.utils import ( | |||
| ) | |||
| class ReasonerBase(): | |||
| class ReasonerBase: | |||
| def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): | |||
| """ | |||
| Root class for all reasoner in the ABL system. | |||
| @@ -31,15 +30,17 @@ class ReasonerBase(): | |||
| NotImplementedError | |||
| If the specified distance function is neither "hamming" nor "confidence". | |||
| """ | |||
| if not (dist_func == "hamming" or dist_func == "confidence"): | |||
| raise NotImplementedError # Only hamming or confidence distance is available. | |||
| raise NotImplementedError # Only hamming or confidence distance is available. | |||
| self.kb = kb | |||
| self.dist_func = dist_func | |||
| self.use_zoopt = use_zoopt | |||
| if mapping is None: | |||
| self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} | |||
| self.mapping = { | |||
| index: label for index, label in enumerate(self.kb.pseudo_label_list) | |||
| } | |||
| else: | |||
| self.mapping = mapping | |||
| self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
| @@ -130,7 +131,9 @@ class ReasonerBase(): | |||
| x = solution.get_x() | |||
| return max_revision_num - x.sum() | |||
| def zoopt_get_solution(self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num): | |||
| def zoopt_get_solution( | |||
| self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
| ): | |||
| """Get the optimal solution using the Zoopt library. | |||
| Parameters | |||
| @@ -151,9 +154,13 @@ class ReasonerBase(): | |||
| array-like | |||
| The optimal solution, i.e., where to revise predict pseudo label. | |||
| """ | |||
| dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||
| dimension = Dimension( | |||
| size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num | |||
| ) | |||
| objective = Objective( | |||
| lambda sol: self.zoopt_revision_score(symbol_num, pred_pseudo_label, pred_prob, y, sol), | |||
| lambda sol: self.zoopt_revision_score( | |||
| symbol_num, pred_pseudo_label, pred_prob, y, sol | |||
| ), | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
| ) | |||
| @@ -181,7 +188,9 @@ class ReasonerBase(): | |||
| """ | |||
| return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
| def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0): | |||
| def abduce( | |||
| self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| """ | |||
| Perform revision by abduction on the given data. | |||
| @@ -208,16 +217,22 @@ class ReasonerBase(): | |||
| max_revision_num = float_parameter(max_revision, symbol_num) | |||
| if self.use_zoopt: | |||
| solution = self.zoopt_get_solution(symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num) | |||
| solution = self.zoopt_get_solution( | |||
| symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
| ) | |||
| revision_idx = np.where(solution != 0)[0] | |||
| candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
| else: | |||
| candidates = self.kb.abduce_candidates(pred_pseudo_label, y, max_revision_num, require_more_revision) | |||
| candidates = self.kb.abduce_candidates( | |||
| pred_pseudo_label, y, max_revision_num, require_more_revision | |||
| ) | |||
| candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) | |||
| return candidate | |||
| def batch_abduce(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): | |||
| def batch_abduce( | |||
| self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| """ | |||
| Perform abduction on the given data in batches. | |||
| @@ -240,9 +255,14 @@ class ReasonerBase(): | |||
| list | |||
| The abduced revisions in batches. | |||
| """ | |||
| return [self.abduce(_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision) | |||
| for _pred_prob, _pred_pseudo_label, _Y in zip(pred_prob, pred_pseudo_label, Y)] | |||
| return [ | |||
| self.abduce( | |||
| _pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision | |||
| ) | |||
| for _pred_prob, _pred_pseudo_label, _Y in zip( | |||
| pred_prob, pred_pseudo_label, Y | |||
| ) | |||
| ] | |||
| # def _batch_abduce_helper(self, args): | |||
| # z, prob, y, max_revision, require_more_revision = args | |||
| @@ -253,8 +273,12 @@ class ReasonerBase(): | |||
| # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) | |||
| # return results | |||
| def __call__(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): | |||
| return self.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) | |||
| def __call__( | |||
| self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| return self.batch_abduce( | |||
| pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision | |||
| ) | |||
| if __name__ == "__main__": | |||
| @@ -282,7 +306,9 @@ if __name__ == "__main__": | |||
| max_err=0, | |||
| use_cache=True, | |||
| ): | |||
| super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache) | |||
| super().__init__( | |||
| pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache | |||
| ) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| @@ -290,45 +316,75 @@ if __name__ == "__main__": | |||
| print("add_KB with GKB:") | |||
| kb = add_KB(prebuild_GKB=True) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print("add_KB without GKB:") | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| print("add_KB without GKB:, no cache") | |||
| kb = add_KB(use_cache=False) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| @@ -338,15 +394,25 @@ if __name__ == "__main__": | |||
| pl_file="examples/mnist_add/datasets/add.pl", | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| @@ -356,15 +422,25 @@ if __name__ == "__main__": | |||
| pl_file="examples/mnist_add/datasets/add.pl", | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| @@ -383,13 +459,19 @@ if __name__ == "__main__": | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8], | |||
| [[1, 1], [1, 2]], | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8], | |||
| [[1, 1], [1, 2]], | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=1, | |||
| ) | |||
| @@ -419,7 +501,9 @@ if __name__ == "__main__": | |||
| max_err=1e-3, | |||
| use_cache=True, | |||
| ): | |||
| super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache) | |||
| super().__init__( | |||
| pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache | |||
| ) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| @@ -453,19 +537,28 @@ if __name__ == "__main__": | |||
| kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,2]], [3], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,9]], [65], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| @@ -476,19 +569,28 @@ if __name__ == "__main__": | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,2]], [3], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,9]], [65], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| @@ -499,19 +601,28 @@ if __name__ == "__main__": | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,2]], [3], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,9]], [65], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| @@ -522,19 +633,28 @@ if __name__ == "__main__": | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,2]], [3], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], [None], [[5,10,9]], [65], | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| @@ -545,21 +665,27 @@ if __name__ == "__main__": | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=1, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| @@ -568,14 +694,18 @@ if __name__ == "__main__": | |||
| print() | |||
| print("max_revision is float") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=0.5, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=0.9, | |||
| require_more_revision=0, | |||
| @@ -629,7 +759,9 @@ if __name__ == "__main__": | |||
| for idx in range(-1, len(pred_res)): | |||
| if (not idx in idxs) and (idx >= 0): | |||
| idxs.append(idx) | |||
| candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs) | |||
| candidate = self._revise_by_idxs( | |||
| pred_res, y, all_revision_flag, idxs | |||
| ) | |||
| if len(candidate) == 0: | |||
| if len(idxs) > 1: | |||
| idxs.pop() | |||
| @@ -640,7 +772,9 @@ if __name__ == "__main__": | |||
| removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
| if found: | |||
| candidate_size.append(len(removed) + 1) | |||
| lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||
| lefted_idxs = [ | |||
| i for i in lefted_idxs if i not in max_candidate_idxs | |||
| ] | |||
| candidate_size.sort() | |||
| score = 0 | |||
| import math | |||
| @@ -681,11 +815,17 @@ if __name__ == "__main__": | |||
| print() | |||
| print("HED_Reasoner abduce") | |||
| res = reasoner.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) | |||
| res = reasoner.abduce( | |||
| (consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))) | |||
| res = reasoner.abduce( | |||
| (inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce((inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2))) | |||
| res = reasoner.abduce( | |||
| (inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2)) | |||
| ) | |||
| print(res) | |||
| print() | |||