From a816bf0b744610e7b2c7b29518aea0cfb781aa93 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 13 Oct 2023 10:18:46 +0800 Subject: [PATCH] [MNT] use black to reformat reasoner.py --- abl/reasoning/reasoner.py | 276 ++++++++++++++++++++++++++++---------- 1 file changed, 208 insertions(+), 68 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index f42ee51..2dc1cab 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,5 +1,4 @@ import numpy as np -from multiprocessing import Pool from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import ( confidence_dist, @@ -10,7 +9,7 @@ from ..utils.utils import ( ) -class ReasonerBase(): +class ReasonerBase: def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): """ Root class for all reasoner in the ABL system. @@ -31,15 +30,17 @@ class ReasonerBase(): NotImplementedError If the specified distance function is neither "hamming" nor "confidence". """ - + if not (dist_func == "hamming" or dist_func == "confidence"): - raise NotImplementedError # Only hamming or confidence distance is available. + raise NotImplementedError # Only hamming or confidence distance is available. self.kb = kb self.dist_func = dist_func self.use_zoopt = use_zoopt if mapping is None: - self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} + self.mapping = { + index: label for index, label in enumerate(self.kb.pseudo_label_list) + } else: self.mapping = mapping self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) @@ -130,7 +131,9 @@ class ReasonerBase(): x = solution.get_x() return max_revision_num - x.sum() - def zoopt_get_solution(self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num): + def zoopt_get_solution( + self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num + ): """Get the optimal solution using the Zoopt library. Parameters @@ -151,9 +154,13 @@ class ReasonerBase(): array-like The optimal solution, i.e., where to revise predict pseudo label. """ - dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) + dimension = Dimension( + size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num + ) objective = Objective( - lambda sol: self.zoopt_revision_score(symbol_num, pred_pseudo_label, pred_prob, y, sol), + lambda sol: self.zoopt_revision_score( + symbol_num, pred_pseudo_label, pred_prob, y, sol + ), dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), ) @@ -181,7 +188,9 @@ class ReasonerBase(): """ return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) - def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0): + def abduce( + self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 + ): """ Perform revision by abduction on the given data. @@ -208,16 +217,22 @@ class ReasonerBase(): max_revision_num = float_parameter(max_revision, symbol_num) if self.use_zoopt: - solution = self.zoopt_get_solution(symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num) + solution = self.zoopt_get_solution( + symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num + ) revision_idx = np.where(solution != 0)[0] candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) else: - candidates = self.kb.abduce_candidates(pred_pseudo_label, y, max_revision_num, require_more_revision) + candidates = self.kb.abduce_candidates( + pred_pseudo_label, y, max_revision_num, require_more_revision + ) candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) return candidate - def batch_abduce(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): + def batch_abduce( + self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 + ): """ Perform abduction on the given data in batches. @@ -240,9 +255,14 @@ class ReasonerBase(): list The abduced revisions in batches. """ - return [self.abduce(_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision) - for _pred_prob, _pred_pseudo_label, _Y in zip(pred_prob, pred_pseudo_label, Y)] - + return [ + self.abduce( + _pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision + ) + for _pred_prob, _pred_pseudo_label, _Y in zip( + pred_prob, pred_pseudo_label, Y + ) + ] # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args @@ -253,8 +273,12 @@ class ReasonerBase(): # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) # return results - def __call__(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): - return self.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) + def __call__( + self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 + ): + return self.batch_abduce( + pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision + ) if __name__ == "__main__": @@ -282,7 +306,9 @@ if __name__ == "__main__": max_err=0, use_cache=True, ): - super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache) + super().__init__( + pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache + ) def logic_forward(self, nums): return sum(nums) @@ -290,45 +316,75 @@ if __name__ == "__main__": print("add_KB with GKB:") kb = add_KB(prebuild_GKB=True) reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 + ) print(res) print() print("add_KB without GKB:") kb = add_KB() reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 + ) print(res) print() print("add_KB without GKB:, no cache") kb = add_KB(use_cache=False) reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 + ) print(res) print() @@ -338,15 +394,25 @@ if __name__ == "__main__": pl_file="examples/mnist_add/datasets/add.pl", ) reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 + ) print(res) print() @@ -356,15 +422,25 @@ if __name__ == "__main__": pl_file="examples/mnist_add/datasets/add.pl", ) reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 + ) print(res) print() @@ -383,13 +459,19 @@ if __name__ == "__main__": kb = add_KB() reasoner = ReasonerBase(kb, "confidence") res = reasoner.batch_abduce( - [[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8], + [[1, 1], [1, 2]], + multiple_prob, + [[1, 1], [1, 2]], + [4, 8], max_revision=2, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8], + [[1, 1], [1, 2]], + multiple_prob, + [[1, 1], [1, 2]], + [4, 8], max_revision=2, require_more_revision=1, ) @@ -419,7 +501,9 @@ if __name__ == "__main__": max_err=1e-3, use_cache=True, ): - super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache) + super().__init__( + pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache + ) def _valid_candidate(self, formula): if len(formula) % 2 == 0: @@ -453,19 +537,28 @@ if __name__ == "__main__": kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,2]], [3], + [["5", "+", "2"]], + [None], + [[5, 10, 2]], + [3], max_revision=2, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,9]], [65], + [["5", "+", "2"]], + [None], + [[5, 10, 9]], + [65], max_revision=3, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], + [["5", "8", "8", "8", "8"]], + [None], + [[5, 8, 8, 8, 8]], + [3.17], max_revision=5, require_more_revision=3, ) @@ -476,19 +569,28 @@ if __name__ == "__main__": kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,2]], [3], + [["5", "+", "2"]], + [None], + [[5, 10, 2]], + [3], max_revision=2, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,9]], [65], + [["5", "+", "2"]], + [None], + [[5, 10, 9]], + [65], max_revision=3, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], + [["5", "8", "8", "8", "8"]], + [None], + [[5, 8, 8, 8, 8]], + [3.17], max_revision=5, require_more_revision=3, ) @@ -499,19 +601,28 @@ if __name__ == "__main__": kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1) reasoner = ReasonerBase(kb, "hamming") res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,2]], [3], + [["5", "+", "2"]], + [None], + [[5, 10, 2]], + [3], max_revision=2, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,9]], [65], + [["5", "+", "2"]], + [None], + [[5, 10, 9]], + [65], max_revision=3, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], + [["5", "8", "8", "8", "8"]], + [None], + [[5, 8, 8, 8, 8]], + [3.17], max_revision=5, require_more_revision=3, ) @@ -522,19 +633,28 @@ if __name__ == "__main__": kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1) reasoner = ReasonerBase(kb, "hamming") res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,2]], [3], + [["5", "+", "2"]], + [None], + [[5, 10, 2]], + [3], max_revision=2, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"]], [None], [[5,10,9]], [65], + [["5", "+", "2"]], + [None], + [[5, 10, 9]], + [65], max_revision=3, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17], + [["5", "8", "8", "8", "8"]], + [None], + [[5, 8, 8, 8, 8]], + [3.17], max_revision=5, require_more_revision=3, ) @@ -545,21 +665,27 @@ if __name__ == "__main__": kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], + [["5", "+", "2"], ["5", "+", "9"]], + [None, None], + [[5, 10, 2], [5, 10, 9]], [3, 64], max_revision=1, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], + [["5", "+", "2"], ["5", "+", "9"]], + [None, None], + [[5, 10, 2], [5, 10, 9]], [3, 64], max_revision=3, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], + [["5", "+", "2"], ["5", "+", "9"]], + [None, None], + [[5, 10, 2], [5, 10, 9]], [3, 65], max_revision=3, require_more_revision=0, @@ -568,14 +694,18 @@ if __name__ == "__main__": print() print("max_revision is float") res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], + [["5", "+", "2"], ["5", "+", "9"]], + [None, None], + [[5, 10, 2], [5, 10, 9]], [3, 64], max_revision=0.5, require_more_revision=0, ) print(res) res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]], + [["5", "+", "2"], ["5", "+", "9"]], + [None, None], + [[5, 10, 2], [5, 10, 9]], [3, 64], max_revision=0.9, require_more_revision=0, @@ -629,7 +759,9 @@ if __name__ == "__main__": for idx in range(-1, len(pred_res)): if (not idx in idxs) and (idx >= 0): idxs.append(idx) - candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs) + candidate = self._revise_by_idxs( + pred_res, y, all_revision_flag, idxs + ) if len(candidate) == 0: if len(idxs) > 1: idxs.pop() @@ -640,7 +772,9 @@ if __name__ == "__main__": removed = [i for i in lefted_idxs if i in max_candidate_idxs] if found: candidate_size.append(len(removed) + 1) - lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] + lefted_idxs = [ + i for i in lefted_idxs if i not in max_candidate_idxs + ] candidate_size.sort() score = 0 import math @@ -681,11 +815,17 @@ if __name__ == "__main__": print() print("HED_Reasoner abduce") - res = reasoner.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) + res = reasoner.abduce( + (consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)) + ) print(res) - res = reasoner.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))) + res = reasoner.abduce( + (inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)) + ) print(res) - res = reasoner.abduce((inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2))) + res = reasoner.abduce( + (inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2)) + ) print(res) print()