From bf04dd9c95dc3506a771717bd6f40c43322d07ac Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 22 Nov 2023 12:33:12 +0800 Subject: [PATCH] [ENH] change parameters passing in reasoning --- abl/reasoning/reasoner.py | 100 ++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 53 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 0786d06..aec9b01 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,6 +1,6 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import ( +from abl.utils.utils import ( confidence_dist, flatten, reform_list, @@ -50,7 +50,7 @@ class ReasonerBase: self.mapping = mapping self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) - def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): + def _get_one_candidate(self, data_sample, candidates): """ Due to the nondeterminism of abductive reasoning, there could be multiple candidates satisfying the knowledge base. When this happens, return one candidate that has the @@ -71,11 +71,11 @@ class ReasonerBase: elif len(candidates) == 1: return candidates[0] else: - cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) + cost_array = self._get_cost_list(data_sample, candidates) candidate = candidates[np.argmin(cost_array)] return candidate - def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates): + def _get_cost_list(self, data_sample, candidates): """ Get the list of costs between each candidate and the given prediction. The list is calculated based on one of the following distance functions: @@ -95,15 +95,15 @@ class ReasonerBase: Multiple consistent candidates. """ if self.dist_func == "hamming": - return hamming_dist(pred_pseudo_label, candidates) + return hamming_dist(data_sample.pred_pseudo_label, candidates) elif self.dist_func == "confidence": candidates = [[self.remapping[x] for x in c] for c in candidates] - return confidence_dist(pred_prob, candidates) + return confidence_dist(data_sample.pred_prob, candidates) def zoopt_get_solution( - self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num + self, symbol_num, data_sample, max_revision_num ): """ Get the optimal solution using the Zoopt library. The solution is a list of @@ -113,13 +113,8 @@ class ReasonerBase: ---------- symbol_num : int Number of total symbols. - pred_pseudo_label : List[Any] - Predicted pseudo label. - pred_prob : List[List[Any]] - Predicted probabilities of the prediction (Each sublist contains the probability - distribution over all pseudo labels). - y : Any - Ground truth for the logical result. + data_sample : ListData + max_revision_num : int Specifies the maximum number of revisions allowed. """ @@ -128,7 +123,7 @@ class ReasonerBase: ) objective = Objective( lambda sol: self.zoopt_revision_score( - symbol_num, pred_pseudo_label, pred_prob, y, sol + symbol_num, data_sample, sol ), dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), @@ -137,15 +132,15 @@ class ReasonerBase: solution = Opt.min(objective, parameter).get_x() return solution - def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol): + def zoopt_revision_score(self, symbol_num, data_sample, sol): """ Get the revision score for a solution. A lower score suggests that the Zoopt library has a higher preference for this solution. """ revision_idx = np.where(sol.get_x() != 0)[0] - candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx) + candidates = self.revise_at_idx(data_sample, revision_idx) if len(candidates) > 0: - return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates)) + return np.min(self._get_cost_list(data_sample, candidates)) else: return symbol_num @@ -157,7 +152,7 @@ class ReasonerBase: x = solution.get_x() return max_revision_num - x.sum() - def revise_at_idx(self, pred_pseudo_label, y, revision_idx): + def revise_at_idx(self, data_sample, revision_idx): """ Revise the predicted pseudo label at specified index positions. @@ -170,7 +165,15 @@ class ReasonerBase: revision_idx : array-like Indices of where revisions should be made to the predicted pseudo label. """ - return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx) + return self.kb.revise_at_idx(data_sample.pred_pseudo_label, + data_sample.Y, + revision_idx) + + def abduce_candidates(self, data_sample, max_revision_num, require_more_revision): + return self.kb.abduce_candidates(data_sample.pred_pseudo_label, + data_sample.Y, + max_revision_num, + require_more_revision) def _get_max_revision_num(self, max_revision, symbol_num): """ @@ -222,22 +225,18 @@ class ReasonerBase: symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - pred_pseudo_label = data_sample.pred_pseudo_label - pred_prob = data_sample.pred_prob - y = data_sample.Y - if self.use_zoopt: solution = self.zoopt_get_solution( - symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num + symbol_num, data_sample, max_revision_num ) revision_idx = np.where(solution != 0)[0] - candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx) + candidates = self.revise_at_idx(data_sample, revision_idx) else: - candidates = self.kb.abduce_candidates( - pred_pseudo_label, y, max_revision_num, require_more_revision + candidates = self.abduce_candidates( + data_sample, max_revision_num, require_more_revision ) - candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) + candidate = self._get_one_candidate(data_sample, candidates) return candidate def batch_abduce( @@ -254,15 +253,6 @@ class ReasonerBase: data_samples.abduced_pseudo_label = abduced_pseudo_label return abduced_pseudo_label - # def _batch_abduce_helper(self, args): - # z, prob, y, max_revision, require_more_revision = args - # return self.abduce((z, prob, y), max_revision, require_more_revision) - - # def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): - # with Pool(processes=os.cpu_count()) as pool: - # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) - # return results - def __call__( self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 ): @@ -486,8 +476,8 @@ if __name__ == "__main__": pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) return len(list(self.prolog.query(pl_query))) != 0 - def abduce_rules(self, pred_res): - pl_query = "consistent_inst_feature(%s, X)." % pred_res + def abduce_rules(self, pseudo_labels): + pl_query = "consistent_inst_feature(%s, X)." % pseudo_labels prolog_result = list(self.prolog.query(pl_query)) if len(prolog_result) == 0: return None @@ -499,32 +489,36 @@ if __name__ == "__main__": def __init__(self, kb, dist_func="hamming"): super().__init__(kb, dist_func, use_zoopt=True) - def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs): - pred = [] - k = [] + def _revise_at_idxs(self, pseudo_labels, ys, all_revision_flag, idxs): + data_sample = ListData() + data_sample.pred_pseudo_label = [] + data_sample.Y = [] revision_flag = [] for idx in idxs: - pred.append(pred_res[idx]) - k.append(y[idx]) + data_sample.pred_pseudo_label.append(pseudo_labels[idx]) + data_sample.Y.append(ys[idx]) revision_flag += list(all_revision_flag[idx]) revision_idx = np.where(np.array(revision_flag) != 0)[0] - candidate = self.revise_at_idx(pred, k, revision_idx) + candidate = self.revise_at_idx(data_sample, revision_idx) return candidate - def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): - all_revision_flag = reform_list(sol.get_x(), pred_res) - lefted_idxs = [i for i in range(len(pred_res))] + def zoopt_revision_score(self, symbol_num, data_sample, sol): + pseudo_labels = data_sample.pred_pseudo_label + ys = data_sample.Y + + all_revision_flag = reform_list(sol.get_x(), pseudo_labels) + lefted_idxs = [i for i in range(len(pseudo_labels))] candidate_size = [] while lefted_idxs: idxs = [] idxs.append(lefted_idxs.pop(0)) max_candidate_idxs = [] found = False - for idx in range(-1, len(pred_res)): + for idx in range(-1, len(pseudo_labels)): if (not idx in idxs) and (idx >= 0): idxs.append(idx) candidate = self._revise_at_idxs( - pred_res, y, all_revision_flag, idxs + pseudo_labels, ys, all_revision_flag, idxs ) if len(candidate) == 0: if len(idxs) > 1: @@ -547,8 +541,8 @@ if __name__ == "__main__": score -= math.exp(-i) * candidate_size[i] return score - def abduce_rules(self, pred_res): - return self.kb.abduce_rules(pred_res) + def abduce_rules(self, pseudo_labels): + return self.kb.abduce_rules(pseudo_labels) kb = HedKB( pseudo_label_list=[1, 0, "+", "="],