| @@ -1,6 +1,6 @@ | |||
| import numpy as np | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from ..utils.utils import ( | |||
| from abl.utils.utils import ( | |||
| confidence_dist, | |||
| flatten, | |||
| reform_list, | |||
| @@ -50,7 +50,7 @@ class ReasonerBase: | |||
| self.mapping = mapping | |||
| self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
| def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): | |||
| def _get_one_candidate(self, data_sample, candidates): | |||
| """ | |||
| Due to the nondeterminism of abductive reasoning, there could be multiple candidates | |||
| satisfying the knowledge base. When this happens, return one candidate that has the | |||
| @@ -71,11 +71,11 @@ class ReasonerBase: | |||
| elif len(candidates) == 1: | |||
| return candidates[0] | |||
| else: | |||
| cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) | |||
| cost_array = self._get_cost_list(data_sample, candidates) | |||
| candidate = candidates[np.argmin(cost_array)] | |||
| return candidate | |||
| def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates): | |||
| def _get_cost_list(self, data_sample, candidates): | |||
| """ | |||
| Get the list of costs between each candidate and the given prediction. The list is | |||
| calculated based on one of the following distance functions: | |||
| @@ -95,15 +95,15 @@ class ReasonerBase: | |||
| Multiple consistent candidates. | |||
| """ | |||
| if self.dist_func == "hamming": | |||
| return hamming_dist(pred_pseudo_label, candidates) | |||
| return hamming_dist(data_sample.pred_pseudo_label, candidates) | |||
| elif self.dist_func == "confidence": | |||
| candidates = [[self.remapping[x] for x in c] for c in candidates] | |||
| return confidence_dist(pred_prob, candidates) | |||
| return confidence_dist(data_sample.pred_prob, candidates) | |||
| def zoopt_get_solution( | |||
| self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
| self, symbol_num, data_sample, max_revision_num | |||
| ): | |||
| """ | |||
| Get the optimal solution using the Zoopt library. The solution is a list of | |||
| @@ -113,13 +113,8 @@ class ReasonerBase: | |||
| ---------- | |||
| symbol_num : int | |||
| Number of total symbols. | |||
| pred_pseudo_label : List[Any] | |||
| Predicted pseudo label. | |||
| pred_prob : List[List[Any]] | |||
| Predicted probabilities of the prediction (Each sublist contains the probability | |||
| distribution over all pseudo labels). | |||
| y : Any | |||
| Ground truth for the logical result. | |||
| data_sample : ListData | |||
| max_revision_num : int | |||
| Specifies the maximum number of revisions allowed. | |||
| """ | |||
| @@ -128,7 +123,7 @@ class ReasonerBase: | |||
| ) | |||
| objective = Objective( | |||
| lambda sol: self.zoopt_revision_score( | |||
| symbol_num, pred_pseudo_label, pred_prob, y, sol | |||
| symbol_num, data_sample, sol | |||
| ), | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
| @@ -137,15 +132,15 @@ class ReasonerBase: | |||
| solution = Opt.min(objective, parameter).get_x() | |||
| return solution | |||
def zoopt_revision_score(self, symbol_num, data_sample, sol):
    """
    Get the revision score for a solution. A lower score suggests that the Zoopt library
    has a higher preference for this solution.

    Parameters
    ----------
    symbol_num : int
        Number of total symbols. Returned unchanged as the score when the
        revision proposed by ``sol`` yields no candidate.
    data_sample : ListData
        Sample forwarded as-is to ``self.revise_at_idx`` and
        ``self._get_cost_list``.
    sol : zoopt solution
        Object exposing ``get_x()``; every non-zero entry marks a position
        to revise.

    Returns
    -------
    number
        The smallest cost among the candidates produced by the revision, or
        ``symbol_num`` when there are none.
    """
    # Convert first so the comparison stays elementwise even if ``get_x()``
    # hands back a plain list instead of an ndarray.
    revision_flags = np.asarray(sol.get_x())
    revision_idx = np.where(revision_flags != 0)[0]
    candidates = self.revise_at_idx(data_sample, revision_idx)
    if len(candidates) > 0:
        return np.min(self._get_cost_list(data_sample, candidates))
    # No consistent candidate for this revision: fall back to symbol_num.
    return symbol_num
| @@ -157,7 +152,7 @@ class ReasonerBase: | |||
| x = solution.get_x() | |||
| return max_revision_num - x.sum() | |||
| def revise_at_idx(self, pred_pseudo_label, y, revision_idx): | |||
| def revise_at_idx(self, data_sample, revision_idx): | |||
| """ | |||
| Revise the predicted pseudo label at specified index positions. | |||
| @@ -170,7 +165,15 @@ class ReasonerBase: | |||
| revision_idx : array-like | |||
| Indices of where revisions should be made to the predicted pseudo label. | |||
| """ | |||
| return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx) | |||
| return self.kb.revise_at_idx(data_sample.pred_pseudo_label, | |||
| data_sample.Y, | |||
| revision_idx) | |||
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision):
    """
    Ask the knowledge base for candidates for one data sample.

    Parameters
    ----------
    data_sample : ListData
        Sample whose ``pred_pseudo_label`` and ``Y`` attributes are passed to
        the knowledge base.
    max_revision_num : int
        Forwarded unchanged to ``self.kb.abduce_candidates``.
    require_more_revision : int
        Forwarded unchanged to ``self.kb.abduce_candidates``.

    Returns
    -------
    Whatever ``self.kb.abduce_candidates`` returns.
    """
    return self.kb.abduce_candidates(
        data_sample.pred_pseudo_label,
        data_sample.Y,
        max_revision_num,
        require_more_revision,
    )
| def _get_max_revision_num(self, max_revision, symbol_num): | |||
| """ | |||
| @@ -222,22 +225,18 @@ class ReasonerBase: | |||
| symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
| max_revision_num = self._get_max_revision_num(max_revision, symbol_num) | |||
| pred_pseudo_label = data_sample.pred_pseudo_label | |||
| pred_prob = data_sample.pred_prob | |||
| y = data_sample.Y | |||
| if self.use_zoopt: | |||
| solution = self.zoopt_get_solution( | |||
| symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
| symbol_num, data_sample, max_revision_num | |||
| ) | |||
| revision_idx = np.where(solution != 0)[0] | |||
| candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx) | |||
| candidates = self.revise_at_idx(data_sample, revision_idx) | |||
| else: | |||
| candidates = self.kb.abduce_candidates( | |||
| pred_pseudo_label, y, max_revision_num, require_more_revision | |||
| candidates = self.abduce_candidates( | |||
| data_sample, max_revision_num, require_more_revision | |||
| ) | |||
| candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) | |||
| candidate = self._get_one_candidate(data_sample, candidates) | |||
| return candidate | |||
| def batch_abduce( | |||
| @@ -254,15 +253,6 @@ class ReasonerBase: | |||
| data_samples.abduced_pseudo_label = abduced_pseudo_label | |||
| return abduced_pseudo_label | |||
| # def _batch_abduce_helper(self, args): | |||
| # z, prob, y, max_revision, require_more_revision = args | |||
| # return self.abduce((z, prob, y), max_revision, require_more_revision) | |||
| # def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): | |||
| # with Pool(processes=os.cpu_count()) as pool: | |||
| # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) | |||
| # return results | |||
| def __call__( | |||
| self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| @@ -486,8 +476,8 @@ if __name__ == "__main__": | |||
| pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | |||
| return len(list(self.prolog.query(pl_query))) != 0 | |||
| def abduce_rules(self, pred_res): | |||
| pl_query = "consistent_inst_feature(%s, X)." % pred_res | |||
| def abduce_rules(self, pseudo_labels): | |||
| pl_query = "consistent_inst_feature(%s, X)." % pseudo_labels | |||
| prolog_result = list(self.prolog.query(pl_query)) | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| @@ -499,32 +489,36 @@ if __name__ == "__main__": | |||
def __init__(self, kb, dist_func="hamming"):
    """
    Set up the reasoner with zoopt-based search always switched on.

    Parameters
    ----------
    kb
        Knowledge base handed to the parent constructor.
    dist_func : str, optional
        Distance function name handed to the parent constructor,
        by default "hamming".
    """
    # use_zoopt is fixed to True here; callers cannot turn it off.
    super().__init__(kb, dist_func, use_zoopt=True)
def _revise_at_idxs(self, pseudo_labels, ys, all_revision_flag, idxs):
    """
    Revise the selected examples together and return the resulting candidates.

    Parameters
    ----------
    pseudo_labels : sequence
        Pseudo labels of all examples, indexed by example position.
    ys : sequence
        Logical results of all examples, indexed by example position.
    all_revision_flag : sequence
        Per-example iterables of revision flags; a non-zero flag marks a
        symbol to revise.
    idxs : iterable of int
        Positions of the examples to revise jointly.

    Returns
    -------
    Whatever ``self.revise_at_idx`` returns for the selected examples.
    """
    # Pack only the selected examples into a fresh sample.
    data_sample = ListData()
    data_sample.pred_pseudo_label = [pseudo_labels[idx] for idx in idxs]
    data_sample.Y = [ys[idx] for idx in idxs]

    # Concatenate the flags of the selected examples in selection order, so
    # the non-zero positions index into the flattened selection.
    revision_flag = [flag for idx in idxs for flag in all_revision_flag[idx]]
    revision_idx = np.where(np.array(revision_flag) != 0)[0]
    return self.revise_at_idx(data_sample, revision_idx)
| def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): | |||
| all_revision_flag = reform_list(sol.get_x(), pred_res) | |||
| lefted_idxs = [i for i in range(len(pred_res))] | |||
| def zoopt_revision_score(self, symbol_num, data_sample, sol): | |||
| pseudo_labels = data_sample.pred_pseudo_label | |||
| ys = data_sample.Y | |||
| all_revision_flag = reform_list(sol.get_x(), pseudo_labels) | |||
| lefted_idxs = [i for i in range(len(pseudo_labels))] | |||
| candidate_size = [] | |||
| while lefted_idxs: | |||
| idxs = [] | |||
| idxs.append(lefted_idxs.pop(0)) | |||
| max_candidate_idxs = [] | |||
| found = False | |||
| for idx in range(-1, len(pred_res)): | |||
| for idx in range(-1, len(pseudo_labels)): | |||
| if (not idx in idxs) and (idx >= 0): | |||
| idxs.append(idx) | |||
| candidate = self._revise_at_idxs( | |||
| pred_res, y, all_revision_flag, idxs | |||
| pseudo_labels, ys, all_revision_flag, idxs | |||
| ) | |||
| if len(candidate) == 0: | |||
| if len(idxs) > 1: | |||
| @@ -547,8 +541,8 @@ if __name__ == "__main__": | |||
| score -= math.exp(-i) * candidate_size[i] | |||
| return score | |||
def abduce_rules(self, pseudo_labels):
    """
    Delegate rule abduction to the knowledge base.

    Parameters
    ----------
    pseudo_labels
        Pseudo labels passed unchanged to ``self.kb.abduce_rules``.

    Returns
    -------
    Whatever ``self.kb.abduce_rules`` returns.
    """
    return self.kb.abduce_rules(pseudo_labels)
| kb = HedKB( | |||
| pseudo_label_list=[1, 0, "+", "="], | |||