From 4213fa00637f750f51cc60a64d9b6dc1df88bb9d Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Mon, 18 Dec 2023 14:06:44 +0800 Subject: [PATCH] [FIX] change name of mapping --- abl/bridge/simple_bridge.py | 4 ++-- abl/reasoning/reasoner.py | 32 ++++++++++++++++---------------- examples/hed/hed_bridge.py | 12 ++++++------ 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 5835619..06be00e 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -95,7 +95,7 @@ class SimpleBridge(BaseBridge): """ pred_idx = data_samples.pred_idx data_samples.pred_pseudo_label = [ - [self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx + [self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx ] return data_samples.pred_pseudo_label @@ -114,7 +114,7 @@ class SimpleBridge(BaseBridge): A list of indices converted from pseudo labels. """ abduced_idx = [ - [self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] + [self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] for sub_list in data_samples.abduced_pseudo_label ] data_samples.abduced_idx = abduced_idx diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 8431604..d311e7c 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -32,9 +32,9 @@ class Reasoner: in this cost list should be a numerical value representing the cost for each candidate, and the list should have the same length as candidates. Defaults to 'confidence'. - mapping : Optional[dict], optional + idx_to_label : Optional[dict], optional A mapping from index in the base model to label. If not provided, a default - order-based mapping is created. Defaults to None. + order-based index to label mapping is created. Defaults to None. max_revision : Union[int, float], optional The upper limit on the number of revisions for each data sample when performing abductive reasoning. If float, denotes the fraction of the total @@ -51,7 +51,7 @@ class Reasoner: self, kb: KBBase, dist_func: Union[str, Callable] = "confidence", - mapping: Optional[dict] = None, + idx_to_label: Optional[dict] = None, max_revision: Union[int, float] = -1, require_more_revision: int = 0, use_zoopt: bool = False, @@ -63,12 +63,12 @@ class Reasoner: self.max_revision = max_revision self.require_more_revision = require_more_revision - if mapping is None: - self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} + if idx_to_label is None: + self.idx_to_label = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} else: - self._check_valid_mapping(mapping) - self.mapping = mapping - self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) + self._check_valid_idx_to_label(idx_to_label) + self.idx_to_label = idx_to_label + self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys())) def _check_valid_dist(self, dist_func): if isinstance(dist_func, str): @@ -87,15 +87,15 @@ class Reasoner: f"dist_func must be a string or a callable function, but got {type(dist_func)}." ) - def _check_valid_mapping(self, mapping): - if not isinstance(mapping, dict): - raise TypeError(f"mapping should be dict, but got {type(mapping)}.") - for key, value in mapping.items(): + def _check_valid_idx_to_label(self, idx_to_label): + if not isinstance(idx_to_label, dict): + raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.") + for key, value in idx_to_label.items(): if not isinstance(key, int): - raise ValueError(f"All keys in the mapping must be integers, but got {key}.") + raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.") if value not in self.kb.pseudo_label_list: raise ValueError( - f"All values in the mapping must be in the pseudo_label_list, but got {value}." + f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}." ) def _get_one_candidate( @@ -158,10 +158,10 @@ class Reasoner: if self.dist_func == "hamming": return hamming_dist(data_sample.pred_pseudo_label, candidates) elif self.dist_func == "confidence": - candidates = [[self.remapping[x] for x in c] for c in candidates] + candidates = [[self.label_to_idx[x] for x in c] for c in candidates] return confidence_dist(data_sample.pred_prob, candidates) else: - candidate_idxs = [[self.remapping[x] for x in c] for c in candidates] + candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] cost_list = self.dist_func(data_sample, candidates, candidate_idxs, reasoning_results) if len(cost_list) != len(candidates): raise ValueError( diff --git a/examples/hed/hed_bridge.py b/examples/hed/hed_bridge.py index 1bbca0c..04ea1f5 100644 --- a/examples/hed/hed_bridge.py +++ b/examples/hed/hed_bridge.py @@ -67,8 +67,8 @@ class HEDBridge(SimpleBridge): mapping_score = [] abduced_pseudo_label_list = [] for _mapping in candidate_mappings: - self.reasoner.mapping = _mapping - self.reasoner.remapping = dict(zip(_mapping.values(), _mapping.keys())) + self.reasoner.idx_to_label = _mapping + self.reasoner.label_to_idx = dict(zip(_mapping.values(), _mapping.keys())) self.idx_to_pseudo_label(data_samples) abduced_pseudo_label = self.reasoner.abduce(data_samples) mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([])) @@ -76,9 +76,9 @@ class HEDBridge(SimpleBridge): max_revisible_instances = max(mapping_score) return_idx = mapping_score.index(max_revisible_instances) - self.reasoner.mapping = candidate_mappings[return_idx] - self.reasoner.remapping = dict( - zip(self.reasoner.mapping.values(), self.reasoner.mapping.keys()) + self.reasoner.idx_to_label = candidate_mappings[return_idx] + self.reasoner.label_to_idx = dict( + zip(self.reasoner.idx_to_label.values(), self.reasoner.idx_to_label.keys()) ) self.idx_to_pseudo_label(data_samples) data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx] @@ -236,7 +236,7 @@ class HEDBridge(SimpleBridge): else: if equation_len == min_len: print_log( - "Learned mapping is: " + str(self.reasoner.mapping), + "Learned mapping is: " + str(self.reasoner.idx_to_label), logger="current", ) self.model.load(load_path="./weights/pretrain_weights.pth")