diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 8b71d9a..bc58699 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -15,7 +15,7 @@ class SimpleBridge(BaseBridge): self, model: ABLModel, abducer: ReasonerBase, - metric_list: BaseMetric, + metric_list: List[BaseMetric], ) -> None: super().__init__(model, abducer) self.metric_list = metric_list @@ -34,7 +34,7 @@ class SimpleBridge(BaseBridge): max_revision: int = -1, require_more_revision: int = 0, ) -> List[List[Any]]: - return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y, max_revision, require_more_revision) + return self.abducer.batch_abduce(pred_prob, pseudo_label, Y, max_revision, require_more_revision) def label_to_pseudo_label( self, label: List[List[Any]], mapping: Dict = None diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 631bfbd..93e005e 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -30,19 +30,19 @@ class ReasonerBase(): self.mapping = mapping self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) - def _get_cost_list(self, pseudo_label, pred_res_prob, candidates): + def _get_cost_list(self, pseudo_label, pred_prob, candidates): """ Get the list consisting of costs between each pseudo label and candidate. - Parameter `pred_res_prob` is needed while using confidence distance. + Parameter `pred_prob` is needed while using confidence distance. """ if self.dist_func == "hamming": return hamming_dist(pseudo_label, candidates) elif self.dist_func == "confidence": candidates = [[self.remapping[x] for x in c] for c in candidates] - return confidence_dist(pred_res_prob, candidates) + return confidence_dist(pred_prob, candidates) - def _get_one_candidate(self, pseudo_label, pred_res_prob, candidates): + def _get_one_candidate(self, pseudo_label, pred_prob, candidates): """ Get one candidate. If multiple candidates exist, return the one with minimum cost. """ @@ -51,11 +51,11 @@ class ReasonerBase(): elif len(candidates) == 1: return candidates[0] else: - cost_list = self._get_cost_list(pseudo_label, pred_res_prob, candidates) + cost_list = self._get_cost_list(pseudo_label, pred_prob, candidates) candidate = candidates[np.argmin(cost_list)] return candidate - def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, y, sol): + def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol): """ Get the revision score for a single solution. @@ -65,9 +65,9 @@ class ReasonerBase(): Solution to evaluate. pred_res : list List of predicted results. - pseudo_label : list + pred_pseudo_label : list List of predicted pseudo labels. - pred_res_prob : list + pred_prob : list List of probabilities for predicted results. y : str Ground truth for the predicted results. @@ -78,26 +78,26 @@ class ReasonerBase(): The revision score for the given solution. """ revision_idx = np.where(sol.get_x() != 0)[0] - candidates = self.revise_by_idx(pseudo_label, y, revision_idx) + candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) if len(candidates) > 0: - return np.min(self._get_cost_list(pseudo_label, pred_res_prob, candidates)) + return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates)) else: - return len(pred_res) + return symbol_num def _constrain_revision_num(self, solution, max_revision_num): x = solution.get_x() return max_revision_num - x.sum() - def zoopt_get_solution(self, pred_res, pseudo_label, pred_res_prob, y, max_revision_num): + def zoopt_get_solution(self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num): """Get the optimal solution using the Zoopt library. Parameters ---------- pred_res : list List of predicted results. - pseudo_label : list + pred_pseudo_label : list List of predicted pseudo labels. - pred_res_prob : list + pred_prob : list List of probabilities for predicted results. y : str Ground truth for the predicted results. @@ -109,10 +109,9 @@ class ReasonerBase(): array-like The optimal solution. """ - length = len(flatten(pred_res)) - dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) + dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) objective = Objective( - lambda sol: self.zoopt_revision_score(pred_res, pseudo_label, pred_res_prob, y, sol), + lambda sol: self.zoopt_revision_score(symbol_num, pred_pseudo_label, pred_prob, y, sol), dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), ) @@ -140,7 +139,7 @@ class ReasonerBase(): """ return self.kb.revise_by_idx(pseudo_label, y, revision_idx) - def abduce(self, data, max_revision=-1, require_more_revision=0): + def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0): """ Perform abduction on the given data. @@ -159,11 +158,11 @@ class ReasonerBase(): list The abduced revisions. """ - pred_label, pred_prob, pred_pseudo_label, y = data - max_revision_num = float_parameter(max_revision, len(flatten(pred_label))) + symbol_num = len(flatten(pred_pseudo_label)) + max_revision_num = float_parameter(max_revision, symbol_num) if self.use_zoopt: - solution = self.zoopt_get_solution(pred_label, pred_pseudo_label, pred_prob, y, max_revision_num) + solution = self.zoopt_get_solution(symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num) revision_idx = np.where(solution != 0)[0] candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) else: @@ -172,7 +171,7 @@ class ReasonerBase(): candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) return candidate - def batch_abduce(self, pred_label, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): + def batch_abduce(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): """Perform abduction on the given data in batches. Parameters @@ -192,10 +191,9 @@ class ReasonerBase(): list The abduced revisions. """ - return [ - self.abduce(data, max_revision, require_more_revision) - for data in zip(pred_label, pred_prob, pred_pseudo_label, Y) - ] + return [self.abduce(_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision) + for _pred_prob, _pred_pseudo_label, _Y in zip(pred_prob, pred_pseudo_label, Y)] + # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args @@ -206,8 +204,8 @@ class ReasonerBase(): # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) # return results - def __call__(self, Z, Y, max_revision=-1, require_more_revision=0): - return self.batch_abduce(Z, Y, max_revision, require_more_revision) + def __call__(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): + return self.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) if __name__ == "__main__": @@ -570,7 +568,7 @@ if __name__ == "__main__": candidate = self.revise_by_idx(pred, k, revision_idx) return candidate - def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): + def zoopt_revision_score(self, pred_res, pred_prob, y, sol): all_revision_flag = reform_idx(sol.get_x(), pred_res) lefted_idxs = [i for i in range(len(pred_res))] candidate_size = []