diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index ff33376..49daeab 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -21,39 +21,25 @@ class SimpleBridge(BaseBridge): super().__init__(model, reasoner) self.metric_list = metric_list - # TODO: add reasoner.mapping to the property of SimpleBridge - def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: self.model.predict(data_samples) return data_samples.pred_idx, data_samples.pred_prob - def abduce_pseudo_label( - self, - data_samples: ListData, - max_revision: int = -1, - require_more_revision: int = 0, - ) -> List[List[Any]]: - self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision) + def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: + self.reasoner.batch_abduce(data_samples) return data_samples.abduced_pseudo_label - def idx_to_pseudo_label( - self, data_samples: ListData, mapping: Optional[Dict] = None - ) -> List[List[Any]]: - if mapping is None: - mapping = self.reasoner.mapping + def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: pred_idx = data_samples.pred_idx data_samples.pred_pseudo_label = [ - [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx + [self.reasoner.mapping[_idx] for _idx in sub_list] + for sub_list in pred_idx ] return data_samples.pred_pseudo_label - def pseudo_label_to_idx( - self, data_samples: ListData, mapping: Optional[Dict] = None - ) -> List[List[Any]]: - if mapping is None: - mapping = self.reasoner.remapping + def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: abduced_idx = [ - [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] + [self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] for sub_list in data_samples.abduced_pseudo_label ] data_samples.abduced_idx = abduced_idx diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 
8f98904..2923828 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -24,17 +24,35 @@ class ReasonerBase: mapping : dict, optional A mapping from index in the base model to label. If not provided, a default order-based mapping is created. + max_revision : int or float, optional + The upper limit on the number of revisions for each data sample when + performing abductive reasoning. If float, denotes the fraction of the total + length that can be revised. A value of -1 implies no restriction on the + number of revisions. Defaults to -1. + require_more_revision : int, optional + Specifies additional number of revisions permitted beyond the minimum required + when performing abductive reasoning. Defaults to 0. use_zoopt : bool, optional Whether to use the Zoopt library during abductive reasoning. Defaults to False. """ - def __init__(self, kb, dist_func="confidence", mapping=None, use_zoopt=False): + def __init__(self, + kb, + dist_func="confidence", + mapping=None, + max_revision=-1, + require_more_revision=0, + use_zoopt=False, + ): if dist_func not in ["hamming", "confidence"]: raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"") self.kb = kb self.dist_func = dist_func self.use_zoopt = use_zoopt + self.max_revision = max_revision + self.require_more_revision = require_more_revision + if mapping is None: self.mapping = { index: label for index, label in enumerate(self.kb.pseudo_label_list) @@ -117,9 +135,7 @@ class ReasonerBase: size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num ) objective = Objective( - lambda sol: self.zoopt_revision_score( - symbol_num, data_sample, sol - ), + lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol), dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), ) @@ -133,7 +149,9 @@ class ReasonerBase: has a higher preference for this solution. 
""" revision_idx = np.where(sol.get_x() != 0)[0] - candidates = self.revise_at_idx(data_sample, revision_idx) + candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, + data_sample.Y, + revision_idx) if len(candidates) > 0: return np.min(self._get_cost_list(data_sample, candidates)) else: @@ -146,27 +164,6 @@ class ReasonerBase: """ x = solution.get_x() return max_revision_num - x.sum() - - def revise_at_idx(self, data_sample, revision_idx): - """ - Revise the pseudo label in the data sample at specified index positions. - - Parameters - ---------- - data_sample : ListData - Data sample. - revision_idx : array-like - Indices of where revisions should be made to the predicted pseudo label. - """ - return self.kb.revise_at_idx(data_sample.pred_pseudo_label, - data_sample.Y, - revision_idx) - - def abduce_candidates(self, data_sample, max_revision_num, require_more_revision): - return self.kb.abduce_candidates(data_sample.pred_pseudo_label, - data_sample.Y, - max_revision_num, - require_more_revision) def _get_max_revision_num(self, max_revision, symbol_num): """ @@ -186,9 +183,7 @@ class ReasonerBase: raise ValueError("If max_revision is an int, it must be non-negative.") return max_revision - def abduce( - self, data_sample, max_revision=-1, require_more_revision=0 - ): + def abduce(self, data_sample): """ Perform abductive reasoning on the given data sample. @@ -196,14 +191,7 @@ class ReasonerBase: ---------- data_sample : ListData Data sample. - max_revision : int or float, optional - The upper limit on the number of revisions. If float, denotes the fraction of the - total length that can be revised. A value of -1 implies no restriction on the number - of revisions. Defaults to -1. - require_more_revision : int, optional - Specifies additional number of revisions permitted beyond the minimum required. - Defaults to 0. - + Returns ------- List[Any] @@ -211,39 +199,33 @@ class ReasonerBase: knowledge base. 
""" symbol_num = data_sample.elements_num("pred_pseudo_label") - max_revision_num = self._get_max_revision_num(max_revision, symbol_num) + max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) if self.use_zoopt: - solution = self.zoopt_get_solution( - symbol_num, data_sample, max_revision_num - ) + solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) revision_idx = np.where(solution != 0)[0] - candidates = self.revise_at_idx(data_sample, revision_idx) + candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, + data_sample.Y, + revision_idx) else: - candidates = self.abduce_candidates( - data_sample, max_revision_num, require_more_revision - ) - + candidates = self.kb.abduce_candidates(data_sample.pred_pseudo_label, + data_sample.Y, + max_revision_num, + self.require_more_revision) + candidate = self._get_one_candidate(data_sample, candidates) return candidate - def batch_abduce( - self, data_samples, max_revision=-1, require_more_revision=0 - ): + def batch_abduce(self, data_samples): """ Perform abductive reasoning on the given prediction data samples. For detailed information, refer to `abduce`. """ abduced_pseudo_label = [ - self.abduce(data_sample, max_revision, require_more_revision) - for data_sample in data_samples + self.abduce(data_sample) for data_sample in data_samples ] data_samples.abduced_pseudo_label = abduced_pseudo_label return abduced_pseudo_label - def __call__( - self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 - ): - return self.batch_abduce( - pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision - ) + def __call__(self, data_samples): + return self.batch_abduce(data_samples)