diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index 03054f7..b535211 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -1,52 +1,64 @@ from abc import ABCMeta, abstractmethod -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Optional, Union from ..learning import ABLModel from ..reasoning import ReasonerBase +from ..structures import ListData -class BaseBridge(metaclass=ABCMeta): +DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] + +class BaseBridge(metaclass=ABCMeta): def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: if not isinstance(model, ABLModel): - raise TypeError("Expected an ABLModel") + raise TypeError( + "Expected an instance of ABLModel, but received type: {}".format( + type(model) + ) + ) if not isinstance(abducer, ReasonerBase): - raise TypeError("Expected an ReasonerBase") - + raise TypeError( + "Expected an instance of ReasonerBase, but received type: {}".format( + type(abducer) + ) + ) + self.model = model self.abducer = abducer @abstractmethod - def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: + def predict( + self, data_samples: ListData + ) -> Tuple[List[List[Any]], List[List[Any]]]: """Placeholder for predicting labels from the input.""" pass @abstractmethod - def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for abducing pseudo labels.""" @abstractmethod - def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: + def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for mapping from label space to symbol space.""" pass @abstractmethod - def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for mapping from symbol space to label space.""" pass - + @abstractmethod - def train(self, train_data): + def train(self, train_data: Union[ListData, DataSet]): """Placeholder for the training loop of Abductive Learning.""" pass @abstractmethod - def test(self, test_data): + def valid(self, valid_data: Union[ListData, DataSet]) -> None: """Placeholder for model validation.""" pass @abstractmethod - def valid(self, valid_data): + def test(self, test_data: Union[ListData, DataSet]) -> None: """Placeholder for model test.""" pass - \ No newline at end of file diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 3aecffd..4ca5628 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -1,12 +1,13 @@ -from ..learning import ABLModel -from ..reasoning import ReasonerBase -from ..evaluation import BaseMetric -from .base_bridge import BaseBridge from typing import List, Union, Any, Tuple, Dict, Optional + from numpy import ndarray -from torch.utils.data import DataLoader -from ..dataset import BridgeDataset +from .base_bridge import BaseBridge, DataSet + +from ..learning import ABLModel +from ..reasoning import ReasonerBase +from ..evaluation import BaseMetric +from ..structures import ListData from ..utils.logger import print_log @@ -20,64 +21,77 @@ class SimpleBridge(BaseBridge): super().__init__(model, abducer) self.metric_list = metric_list - def predict(self, X) -> Tuple[List[List[Any]], ndarray]: - pred_res = self.model.predict(X) - pred_idx, pred_prob = pred_res["label"], pred_res["prob"] - return pred_idx, pred_prob - + def predict(self, 
data_samples: ListData) -> Tuple[List[ndarray], ndarray]: + pred_res = self.model.predict(data_samples) + data_samples.pred_idx = pred_res["label"] + data_samples.pred_prob = pred_res["prob"] + return data_samples["pred_idx"], data_samples["pred_prob"] + def abduce_pseudo_label( self, - pred_prob: ndarray, - pred_pseudo_label: List[List[Any]], - Y: List[Any], + data_samples: ListData, max_revision: int = -1, require_more_revision: int = 0, ) -> List[List[Any]]: - return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) + self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) + return data_samples["abduced_pseudo_label"] def idx_to_pseudo_label( - self, idx: List[List[Any]], mapping: Dict = None + self, data_samples: ListData, mapping: Dict = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.mapping - return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] + pred_idx = data_samples.pred_idx + data_samples.pred_pseudo_label = [ + [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx + ] + return data_samples["pred_pseudo_label"] def pseudo_label_to_idx( - self, pseudo_label: List[List[Any]], mapping: Dict = None + self, data_samples: ListData, mapping: Dict = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.remapping - return [ - [mapping[_pseudo_label] for _pseudo_label in sub_list] - for sub_list in pseudo_label + abduced_idx = [ + [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] + for sub_list in data_samples.abduced_pseudo_label ] + data_samples.abduced_idx = abduced_idx + return data_samples["abduced_idx"] + + def data_preprocess( + self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any] + ) -> ListData: + data_samples = ListData() + + data_samples.X = X + data_samples.gt_pseudo_label = gt_pseudo_label + data_samples.Y = Y + + return data_samples def train( self, - train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], + train_data: DataSet, epochs: int = 50, batch_size: Union[int, float] = -1, eval_interval: int = 1, ): - dataset = BridgeDataset(*train_data) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) + data_samples = self.data_preprocess(*train_data) + # a non-positive batch_size means full-batch training + if batch_size <= 0: + batch_size = len(data_samples) for epoch in range(epochs): - for seg_idx, (X, Z, Y) in enumerate(data_loader): - pred_idx, pred_prob = self.predict(X) - pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) - abduced_pseudo_label = self.abduce_pseudo_label( - pred_prob, pred_pseudo_label, Y - ) - abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) - loss = self.model.train(X, abduced_label) + for seg_idx in range((len(data_samples) - 1) // batch_size + 1): + sub_data_samples = data_samples[ + seg_idx * batch_size : (seg_idx + 1) * batch_size + ] + self.predict(sub_data_samples) + self.idx_to_pseudo_label(sub_data_samples) + self.abduce_pseudo_label(sub_data_samples) + self.pseudo_label_to_idx(sub_data_samples) + loss = self.model.train(sub_data_samples) print_log( - f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", + f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{(len(data_samples) - 1) // batch_size + 1}] model loss is {loss:.5f}", logger="current", ) @@ -85,20 +99,19 @@ class SimpleBridge(BaseBridge): print_log(f"Evaluation start: Epoch(val) [{epoch + 1}]", logger="current") self.valid(train_data) - def _valid(self, data_loader): - for X, Z, 
Y in data_loader: - pred_idx, pred_prob = self.predict(X) - pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) - data_samples = dict( - pred_idx=pred_idx, - pred_prob=pred_prob, - pred_pseudo_label=pred_pseudo_label, - gt_pseudo_label=Z, - Y=Y, - logic_forward=self.abducer.kb.logic_forward, + def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: + for seg_idx in range((len(data_samples) - 1) // batch_size + 1): + sub_data_samples = data_samples[ + seg_idx * batch_size : (seg_idx + 1) * batch_size + ] + self.predict(sub_data_samples) + self.idx_to_pseudo_label(sub_data_samples) + + sub_data_samples.set_metainfo( + dict(logic_forward=self.abducer.kb.logic_forward) ) for metric in self.metric_list: - metric.process(data_samples) + metric.process(sub_data_samples) res = dict() for metric in self.metric_list: @@ -108,14 +121,12 @@ class SimpleBridge(BaseBridge): msg += k + f": {v:.3f} " print_log(msg, logger="current") - def valid(self, valid_data, batch_size=1000): - dataset = BridgeDataset(*valid_data) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) - self._valid(data_loader) - - def test(self, test_data, batch_size=1000): - self.valid(test_data, batch_size) + def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None: + if not isinstance(valid_data, ListData): + data_samples = self.data_preprocess(*valid_data) + else: + data_samples = valid_data + self._valid(data_samples, batch_size=batch_size) + + def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None: + self.valid(test_data, batch_size=batch_size) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 3333daf..1bacca4 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,8 +1,6 @@ from typing import Optional, Sequence from .base_metric import BaseMetric -class ABLMetric(): - pass class SemanticsMetric(BaseMetric): def __init__(self, prefix: Optional[str] = None) -> None: diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 8c8b2f9..ae853bc 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -9,9 +9,13 @@ # Description : # # ================================================================# +from typing import List, Any, Optional + import pickle + +from ..structures import ListData from ..utils import flatten, reform_idx -from typing import List, Any, Optional + class ABLModel: @@ -55,7 +59,7 @@ class ABLModel: "base_model should have fit, predict and score methods." ) - def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: + def predict(self, data_samples: ListData, mapping: Optional[dict] = None) -> dict: """ Predict the labels and probabilities for the given data. @@ -72,11 +76,11 @@ class ABLModel: A dictionary containing the predicted labels and probabilities. 
""" model = self.classifier_list[0] - data_X = flatten(X) + data_X = flatten(data_samples["X"]) if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) label = prob.argmax(axis=1) - prob = reform_idx(prob, X) + prob = reform_idx(prob, data_samples["X"]) else: prob = None label = model.predict(X=data_X) @@ -84,7 +88,7 @@ class ABLModel: if mapping is not None: label = [mapping[y] for y in label] - label = reform_idx(label, X) + label = reform_idx(label, data_samples["X"]) return {"label": label, "prob": prob} @@ -109,7 +113,7 @@ class ABLModel: score = self.classifier_list[0].score(X=data_X, y=data_Y) return score - def train(self, X: List[List[Any]], Y: List[Any]) -> float: + def train(self, data_samples: ListData) -> float: """ Train the model on the given data. @@ -125,9 +129,9 @@ class ABLModel: float The loss value of the trained model. """ - data_X = flatten(X) - data_Y = flatten(Y) - return self.classifier_list[0].fit(X=data_X, y=data_Y) + data_X = flatten(data_samples["X"]) + data_y = flatten(data_samples["abduced_idx"]) + return self.classifier_list[0].fit(X=data_X, y=data_y) def _model_operation(self, operation: str, *args, **kwargs): model = self.classifier_list[0] diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index b551df0..614e454 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -5,13 +5,21 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from ..utils.utils import ( + flatten, + reform_idx, + hamming_dist, + check_equal, + to_hashable, + hashable_to_list, +) from multiprocessing import Pool from functools import lru_cache import pyswip + class KBBase(ABC): def __init__(self, pseudo_label_list, max_err=0, use_cache=True): # TODO:添加一下类型检查,比如 @@ -20,7 +28,7 @@ class KBBase(ABC): self.pseudo_label_list = pseudo_label_list self.max_err = max_err - self.use_cache = use_cache + self.use_cache = use_cache @abstractmethod def logic_forward(self, pseudo_labels): @@ -28,10 +36,17 @@ class KBBase(ABC): def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): if not self.use_cache: - return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) - else: - return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision) - + return self._abduce_by_search( + pred_res, y, max_revision_num, require_more_revision + ) + else: + return self._abduce_by_search_cache( + to_hashable(pred_res), + to_hashable(y), + max_revision_num, + require_more_revision, + ) + def revise_by_idx(self, pred_res, y, revision_idx): candidates = [] abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) @@ -52,10 +67,12 @@ class KBBase(ABC): new_candidates.extend(candidates) return new_candidates - def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): + def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): candidates = [] for revision_num in range(len(pred_res) + 1): - if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): + if revision_num == 0 and check_equal( + self.logic_forward(pred_res), y, self.max_err + ): candidates.append(pred_res) elif revision_num > 0: candidates.extend(self._revision(revision_num, pred_res, y)) @@ -65,18 +82,24 @@ class KBBase(ABC): if revision_num >= max_revision_num: return [] - for revision_num in 
range(min_revision_num + 1, min_revision_num + require_more_revision + 1): + for revision_num in range( + min_revision_num + 1, min_revision_num + require_more_revision + 1 + ): if revision_num > max_revision_num: return candidates candidates.extend(self._revision(revision_num, pred_res, y)) return candidates - + @lru_cache(maxsize=None) - def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision): + def _abduce_by_search_cache( + self, pred_res, y, max_revision_num, require_more_revision + ): pred_res = hashable_to_list(pred_res) y = hashable_to_list(y) - return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) - + return self._abduce_by_search( + pred_res, y, max_revision_num, require_more_revision + ) + def _dict_len(self, dic): if not self.GKB_flag: return 0 @@ -88,17 +111,18 @@ class KBBase(ABC): return 0 else: return sum(self._dict_len(v) for v in self.base.values()) - + + class ground_KB(KBBase): def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): super().__init__(pseudo_label_list, max_err) - + self.GKB_len_list = GKB_len_list self.base = {} X, Y = self._get_GKB() for x, y in zip(X, Y): self.base.setdefault(len(x), defaultdict(list))[y].append(x) - + # For parallel version of _get_GKB def _get_XY_list(self, args): pre_x, post_x_it = args[0], args[1] @@ -114,6 +138,7 @@ class ground_KB(KBBase): def _get_GKB(self): X, Y = [], [] for length in self.GKB_len_list: + print("Generating GKB of length %d" % length) arg_list = [] for pre_x in self.pseudo_label_list: post_x_it = product(self.pseudo_label_list, repeat=length - 1) @@ -126,21 +151,24 @@ class ground_KB(KBBase): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - if Y and isinstance(Y[0], (int, float)): + if Y and isinstance(Y[0], (int, float)): X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) return X, Y - - def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): - return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision) - - def _find_candidate_GKB(self, pred_res, y): + + def abduce_candidates(self, data_sample, max_revision_num, require_more_revision=0): + return self._abduce_by_GKB( + data_sample, max_revision_num, require_more_revision=require_more_revision + ) + + def _find_candidate_GKB(self, cache_key, data_sample): + y = data_sample["Y"][0] if self.max_err == 0: - return self.base[len(pred_res)][y] + return self.base[cache_key][y] else: - potential_candidates = self.base[len(pred_res)] + potential_candidates = self.base[cache_key] key_list = list(potential_candidates.keys()) key_idx = bisect.bisect_left(key_list, y) - + all_candidates = [] for idx in range(key_idx - 1, -1, -1): k = key_list[idx] @@ -148,7 +176,7 @@ class ground_KB(KBBase): all_candidates.extend(potential_candidates[k]) else: break - + for idx in range(key_idx, len(key_list)): k = key_list[idx] if abs(k - y) <= self.max_err: @@ -156,19 +184,20 @@ class ground_KB(KBBase): else: break return all_candidates - - def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): - if self.base == {} or len(pred_res) not in self.GKB_len_list: + + def _abduce_by_GKB(self, data_sample, max_revision_num, require_more_revision=0): + cache_key = len(data_sample["pred_pseudo_label"][0]) + if self.base == {} or cache_key not in self.GKB_len_list: return [] - - all_candidates = self._find_candidate_GKB(pred_res, y) + + all_candidates = self._find_candidate_GKB(cache_key, data_sample) if len(all_candidates) == 
0: return [] - cost_list = hamming_dist(pred_res, all_candidates) - min_revision_num = np.min(cost_list) + cost_array = hamming_dist(data_sample["pred_pseudo_label"][0], all_candidates) + min_revision_num = np.min(cost_array) revision_num = min(max_revision_num, min_revision_num + require_more_revision) - idxs = np.where(cost_list <= revision_num)[0] + idxs = np.where(cost_array <= revision_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates @@ -180,33 +209,38 @@ class prolog_KB(KBBase): self.prolog.consult(pl_file) def logic_forward(self, pseudo_labels): - result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] - if result == 'true': + result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0][ + "Res" + ] + if result == "true": return True - elif result == 'false': + elif result == "false": return False return result - + def _revision_pred_res(self, pred_res, revision_idx): import re + revision_pred_res = pred_res.copy() revision_pred_res = flatten(revision_pred_res) - + for idx in revision_idx: - revision_pred_res[idx] = 'P' + str(idx) + revision_pred_res[idx] = "P" + str(idx) revision_pred_res = reform_idx(revision_pred_res, pred_res) - + # TODO: not sure whether there is a more concise way to do this regex = r"'P\d+'" - return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res)) + return re.sub( + regex, lambda x: x.group().replace("'", ""), str(revision_pred_res) + ) + def get_query_string(self, pred_res, y, revision_idx): query_string = "logic_forward(" query_string += self._revision_pred_res(pred_res, revision_idx) key_is_none_flag = y is None or (type(y) == list and y[0] is None) query_string += ",%s)." % y if not key_is_none_flag else ")." return query_string - + def revise_by_idx(self, pred_res, y, revision_idx): candidates = [] query_string = self.get_query_string(pred_res, y, revision_idx) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 968b19a..d1ca4bb 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -70,7 +70,7 @@ class ReasonerBase: candidates = [[self.remapping[x] for x in c] for c in candidates] return confidence_dist(pred_prob, candidates) - def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): + def _get_one_candidate(self, data_sample, candidates): """ Get one candidate. If multiple candidates exist, return the one with minimum cost. @@ -94,7 +94,9 @@ elif len(candidates) == 1: return candidates[0] else: - cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) + cost_array = self._get_cost_list( + data_sample["pred_pseudo_label"][0], data_sample["pred_prob"][0], candidates + ) candidate = candidates[np.argmin(cost_array)] return candidate @@ -188,9 +190,7 @@ class ReasonerBase: """ return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) - def abduce( - self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 - ): + def abduce(self, data_sample, max_revision=-1, require_more_revision=0): """ Perform revision by abduction on the given data. @@ -213,26 +213,24 @@ list The abduced revisions. 
""" - symbol_num = len(flatten(pred_pseudo_label)) + symbol_num = len(flatten(data_sample.pred_pseudo_label)) max_revision_num = calculate_revision_num(max_revision, symbol_num) if self.use_zoopt: solution = self.zoopt_get_solution( - symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num + symbol_num, data_sample, max_revision_num ) revision_idx = np.where(solution != 0)[0] - candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) + candidates = self.revise_by_idx(data_sample, revision_idx) else: candidates = self.kb.abduce_candidates( - pred_pseudo_label, y, max_revision_num, require_more_revision + data_sample, max_revision_num, require_more_revision ) - candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) + candidate = self._get_one_candidate(data_sample, candidates) return candidate - def batch_abduce( - self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 - ): + def batch_abduce(self, data_samples, max_revision=-1, require_more_revision=0): """ Perform abduction on the given data in batches. @@ -255,14 +253,15 @@ class ReasonerBase: list The abduced revisions in batches. """ - return [ + abduced_pseudo_label = [ self.abduce( - _pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision - ) - for _pred_prob, _pred_pseudo_label, _Y in zip( - pred_prob, pred_pseudo_label, Y + data_sample, + max_revision=max_revision, + require_more_revision=require_more_revision, ) + for data_sample in data_samples ] + data_samples.abduced_pseudo_label = abduced_pseudo_label # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args @@ -281,43 +280,57 @@ class ReasonerBase: ) - - if __name__ == "__main__": from kb import KBBase, ground_KB, prolog_KB - prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] - - prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + prob1 = [ + [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + ] + + prob2 = [ + [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + ] class add_KB(KBBase): - def __init__(self, pseudo_label_list=list(range(10)), - use_cache=True): + def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): super().__init__(pseudo_label_list, use_cache=use_cache) def logic_forward(self, nums): return sum(nums) - + class add_ground_KB(ground_KB): - def __init__(self, pseudo_label_list=list(range(10)), - GKB_len_list=[2]): + def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): super().__init__(pseudo_label_list, GKB_len_list) def logic_forward(self, nums): return sum(nums) - + def test_add(reasoner): - res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + res = 
reasoner.batch_abduce( + prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 + ) print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce( + prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 + ) print(res) print() @@ -337,8 +350,9 @@ if __name__ == "__main__": test_add(reasoner) print("prolog_KB with add.pl:") - kb = prolog_KB(pseudo_label_list=list(range(10)), - pl_file="examples/mnist_add/datasets/add.pl") + kb = prolog_KB( + pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" + ) reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) @@ -351,14 +365,16 @@ if __name__ == "__main__": test_add(reasoner) print("add_KB with multiple inputs at once:") - multiple_prob = [[ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ], - [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ]] + multiple_prob = [ + [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ], + [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ], + ] kb = add_KB() reasoner = ReasonerBase(kb, "confidence") @@ -383,8 +399,21 @@ if __name__ == "__main__": class HWF_KB(KBBase): def __init__( self, - pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", - "+", "-", "times", "div"], + pseudo_label_list=[ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "+", + "-", + "times", + "div", + ], max_err=1e-3, ): super().__init__(pseudo_label_list, max_err) @@ -393,7 +422,17 @@ if __name__ == "__main__": if len(formula) % 2 == 0: return False for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: + if i % 2 == 0 and formula[i] not in [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ]: return False if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: return False @@ -406,12 +445,25 @@ if __name__ == "__main__": mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) formula = [mapping[f] for f in formula] return eval("".join(formula)) - + class HWF_ground_KB(ground_KB): def __init__( self, - pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", - "+", "-", "times", "div"], + pseudo_label_list=[ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "+", + "-", + "times", + "div", + ], GKB_len_list=[1, 3, 5, 7], max_err=1e-3, ): @@ -421,7 +473,17 @@ if __name__ == "__main__": if len(formula) % 2 == 0: return False for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: + if i % 2 == 0 and formula[i] not in [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ]: return False if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: return False @@ -434,7 +496,7 @@ if __name__ == "__main__": mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) formula = [mapping[f] for f in formula] return eval("".join(formula)) - + def test_hwf(reasoner): res = reasoner.batch_abduce( [None], @@ -461,7 +523,7 @@ if __name__ == "__main__": ) print(res) print() - + def test_hwf_multiple(reasoner, max_revisions): res = reasoner.batch_abduce( [None, None], @@ -512,10 +574,10 @@ if __name__ == "__main__": print("HWF_KB with multiple inputs at once:") kb = HWF_KB(max_err=0.1) reasoner = ReasonerBase(kb, "hamming") - 
test_hwf_multiple(reasoner, max_revisions=[1,3,3]) - + test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) + print("max_revision is float") - test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) + test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list, pl_file): diff --git a/abl/structures/__init__.py b/abl/structures/__init__.py new file mode 100644 index 0000000..52b5af3 --- /dev/null +++ b/abl/structures/__init__.py @@ -0,0 +1,2 @@ +from .base_data_element import BaseDataElement +from .list_data import ListData \ No newline at end of file diff --git a/abl/structures/base_data_element.py b/abl/structures/base_data_element.py new file mode 100644 index 0000000..f88dd1d --- /dev/null +++ b/abl/structures/base_data_element.py @@ -0,0 +1,629 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Any, Iterator, Optional, Tuple, Type, Union + +import numpy as np +import torch + + +class BaseDataElement: + """A base data interface that supports Tensor-like and dict-like + operations. + + A typical data elements refer to predicted results or ground truth labels + on a task, such as predicted bboxes, instance masks, semantic + segmentation masks, etc. Because groundtruth labels and predicted results + often have similar properties (for example, the predicted bboxes and the + groundtruth bboxes), MMEngine uses the same abstract data interface to + encapsulate predicted results and groundtruth labels, and it is recommended + to use different name conventions to distinguish them, such as using + ``gt_instances`` and ``pred_instances`` to distinguish between labels and + predicted results. Additionally, we distinguish data elements at instance + level, pixel level, and label level. Each of these types has its own + characteristics. Therefore, MMEngine defines the base class + ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and + ``LabelData`` inheriting from ``BaseDataElement`` to represent different + types of ground truth labels or predictions. + + Another common data element is sample data. A sample data consists of input + data (such as an image) and its annotations and predictions. In general, + an image can have multiple types of annotations and/or predictions at the + same time (for example, both pixel-level semantic segmentation annotations + and instance-level detection bboxes annotations). All labels and + predictions of a training sample are often passed between Dataset, Model, + Visualizer, and Evaluator components. In order to simplify the interface + between components, we can treat them as a large data element and + encapsulate them. Such data elements are generally called XXDataSample in + the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` + allows `BaseDataElement` as its attribute. Such a class generally + encapsulates all the data of a sample in the algorithm library, and its + attributes generally are various types of data elements. For example, + MMDetection is assigned by the BaseDataElement to encapsulate all the data + elements of the sample labeling and prediction of a sample in the + algorithm library. + + The attributes in ``BaseDataElement`` are divided into two parts, + the ``metainfo`` and the ``data`` respectively. + + - ``metainfo``: Usually contains the + information about the image such as filename, + image_shape, pad_shape, etc. 
The attributes can be accessed or + modified by dict-like or object-like operations, such as + ``.`` (for data access and modification), ``in``, ``del``, + ``pop(str)``, ``get(str)``, ``metainfo_keys()``, + ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for + set or change key-value pairs in metainfo). + + - ``data``: Annotations or model predictions are + stored. The attributes can be accessed or modified by + dict-like or object-like operations, such as + ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, + ``values()``, ``items()``. Users can also apply tensor-like + methods to all :obj:`torch.Tensor` in the ``data_fields``, + such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, + ``to_tensor()``, ``.detach()``. + + Args: + metainfo (dict, optional): A dict contains the meta information + of single image, such as ``dict(img_shape=(512, 512, 3), + scale_factor=(1, 1, 1, 1))``. Defaults to None. + kwargs (dict, optional): A dict contains annotations of single image or + model predictions. Defaults to None. + + Examples: + >>> import torch + >>> from mmengine.structures import BaseDataElement + >>> gt_instances = BaseDataElement() + >>> bboxes = torch.rand((5, 4)) + >>> scores = torch.rand((5,)) + >>> img_id = 0 + >>> img_shape = (800, 1333) + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=img_id, img_shape=img_shape), + ... bboxes=bboxes, scores=scores) + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) + + >>> # new + >>> gt_instances1 = gt_instances.new( + ... metainfo=dict(img_id=1, img_shape=(640, 640)), + ... bboxes=torch.rand((5, 4)), + ... scores=torch.rand((5,))) + >>> gt_instances2 = gt_instances1.new() + + >>> # add and process property + >>> gt_instances = BaseDataElement() + >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) + >>> assert 'img_shape' in gt_instances.metainfo_keys() + >>> assert 'img_shape' in gt_instances + >>> assert 'img_shape' not in gt_instances.keys() + >>> assert 'img_shape' in gt_instances.all_keys() + >>> print(gt_instances.img_shape) + (100, 100) + >>> gt_instances.scores = torch.rand((5,)) + >>> assert 'scores' in gt_instances.keys() + >>> assert 'scores' in gt_instances + >>> assert 'scores' in gt_instances.all_keys() + >>> assert 'scores' not in gt_instances.metainfo_keys() + >>> print(gt_instances.scores) + tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) + >>> gt_instances.bboxes = torch.rand((5, 4)) + >>> assert 'bboxes' in gt_instances.keys() + >>> assert 'bboxes' in gt_instances + >>> assert 'bboxes' in gt_instances.all_keys() + >>> assert 'bboxes' not in gt_instances.metainfo_keys() + >>> print(gt_instances.bboxes) + tensor([[0.0900, 0.0424, 0.1755, 0.4469], + [0.8648, 0.0592, 0.3484, 0.0913], + [0.5808, 0.1909, 0.6165, 0.7088], + [0.5490, 0.4209, 0.9416, 0.2374], + [0.3652, 0.1218, 0.8805, 0.7523]]) + + >>> # delete and change property + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=0, img_shape=(640, 640)), + ... 
bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) + >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) + >>> gt_instances.img_shape # (1280, 1280) + >>> gt_instances.bboxes = gt_instances.bboxes * 2 + >>> gt_instances.get('img_shape', None) # (1280, 1280) + >>> gt_instances.get('bboxes', None) # 6x4 tensor + >>> del gt_instances.img_shape + >>> del gt_instances.bboxes + >>> assert 'img_shape' not in gt_instances + >>> assert 'bboxes' not in gt_instances + >>> gt_instances.pop('img_shape', None) # None + >>> gt_instances.pop('bboxes', None) # None + + >>> # Tensor-like + >>> cuda_instances = gt_instances.cuda() + >>> cuda_instances = gt_instances.to('cuda:0') + >>> cpu_instances = cuda_instances.cpu() + >>> cpu_instances = cuda_instances.to('cpu') + >>> fp16_instances = cuda_instances.to( + ... device=None, dtype=torch.float16, non_blocking=False, + ... copy=False, memory_format=torch.preserve_format) + >>> cpu_instances = cuda_instances.detach() + >>> np_instances = cpu_instances.numpy() + + >>> # print + >>> metainfo = dict(img_shape=(800, 1196, 3)) + >>> gt_instances = BaseDataElement( + ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + >>> sample = BaseDataElement(metainfo=metainfo, + ... gt_instances=gt_instances) + >>> print(sample) + <BaseDataElement( ... ) at 0x7f0fea49e130> + + >>> # inheritance + >>> class DetDataSample(BaseDataElement): + ... @property + ... def proposals(self): + ... return self._proposals + ... @proposals.setter + ... def proposals(self, value): + ... self.set_field(value, '_proposals', dtype=BaseDataElement) + ... @proposals.deleter + ... def proposals(self): + ... del self._proposals + ... @property + ... def gt_instances(self): + ... return self._gt_instances + ... @gt_instances.setter + ... def gt_instances(self, value): + ... self.set_field(value, '_gt_instances', + ... dtype=BaseDataElement) + ... @gt_instances.deleter + ... def gt_instances(self): + ... del self._gt_instances + ... @property + ... def pred_instances(self): + ... return self._pred_instances + ... @pred_instances.setter + ... def pred_instances(self, value): + ... self.set_field(value, '_pred_instances', + ... dtype=BaseDataElement) + ... @pred_instances.deleter + ... def pred_instances(self): + ... del self._pred_instances + >>> det_sample = DetDataSample() + >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) + >>> det_sample.proposals = proposals + >>> assert 'proposals' in det_sample + >>> assert det_sample.proposals == proposals + >>> del det_sample.proposals + >>> assert 'proposals' not in det_sample + >>> with self.assertRaises(AssertionError): + ... det_sample.proposals = torch.rand((5, 4)) + """ + + def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: + self._metainfo_fields: set = set() + self._data_fields: set = set() + + if metainfo is not None: + self.set_metainfo(metainfo=metainfo) + if kwargs: + self.set_data(kwargs) + + def set_metainfo(self, metainfo: dict) -> None: + """Set or change key-value pairs in ``metainfo_field`` by parameter + ``metainfo``. + + Args: + metainfo (dict): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. + """ + assert isinstance( + metainfo, dict + ), f"metainfo should be a ``dict`` but got {type(metainfo)}" + # meta = copy.deepcopy(metainfo) + for k, v in metainfo.items(): + self.set_field(name=k, value=v, field_type="metainfo", dtype=None) + + def set_data(self, data: dict) -> None: + """Set or change key-value pairs in ``data_field`` by parameter + ``data``. 
+ + Args: + data (dict): A dict contains annotations of image or + model predictions. + """ + assert isinstance(data, dict), f"data should be a `dict` but got {data}" + for k, v in data.items(): + # Use `setattr()` rather than `self.set_field` to allow `set_data` + # to set property method. + setattr(self, k, v) + + def update(self, instance: "BaseDataElement") -> None: + """The update() method updates the BaseDataElement with the elements + from another BaseDataElement object. + + Args: + instance (BaseDataElement): Another BaseDataElement object for + update the current object. + """ + assert isinstance( + instance, BaseDataElement + ), f"instance should be a `BaseDataElement` but got {type(instance)}" + self.set_metainfo(dict(instance.metainfo_items())) + self.set_data(dict(instance.items())) + + def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": + """Return a new data element with same type. If ``metainfo`` and + ``data`` are None, the new data element will have same metainfo and + data. If metainfo or data is not None, the new result will overwrite it + with the input value. + + Args: + metainfo (dict, optional): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. + Defaults to None. + kwargs (dict): A dict contains annotations of image or + model predictions. + + Returns: + BaseDataElement: A new data element with same type. + """ + new_data = self.__class__() + + if metainfo is not None: + new_data.set_metainfo(metainfo) + else: + new_data.set_metainfo(dict(self.metainfo_items())) + if kwargs: + new_data.set_data(kwargs) + else: + new_data.set_data(dict(self.items())) + return new_data + + def clone(self): + """Deep copy the current data element. + + Returns: + BaseDataElement: The copy of current data element. + """ + clone_data = self.__class__() + clone_data.set_metainfo(dict(self.metainfo_items())) + clone_data.set_data(dict(self.items())) + return clone_data + + def keys(self) -> list: + """ + Returns: + list: Contains all keys in data_fields. + """ + # We assume that the name of the attribute related to property is + # '_' + the name of the property. We use this rule to filter out + # private keys. + # TODO: Use a more robust way to solve this problem + private_keys = { + "_" + key + for key in self._data_fields + if isinstance(getattr(type(self), key, None), property) + } + return list(self._data_fields - private_keys) + + def metainfo_keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo_fields. + """ + return list(self._metainfo_fields) + + def values(self) -> list: + """ + Returns: + list: Contains all values in data. + """ + return [getattr(self, k) for k in self.keys()] + + def metainfo_values(self) -> list: + """ + Returns: + list: Contains all values in metainfo. + """ + return [getattr(self, k) for k in self.metainfo_keys()] + + def all_keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo and data. + """ + return self.metainfo_keys() + self.keys() + + def all_values(self) -> list: + """ + Returns: + list: Contains all values in metainfo and data. + """ + return self.metainfo_values() + self.values() + + def all_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: An iterator object whose element is (key, value) tuple + pairs for ``metainfo`` and ``data``. 
+ """ + for k in self.all_keys(): + yield (k, getattr(self, k)) + + def items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: An iterator object whose element is (key, value) tuple + pairs for ``data``. + """ + for k in self.keys(): + yield (k, getattr(self, k)) + + def metainfo_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: An iterator object whose element is (key, value) tuple + pairs for ``metainfo``. + """ + for k in self.metainfo_keys(): + yield (k, getattr(self, k)) + + @property + def metainfo(self) -> dict: + """dict: A dict contains metainfo of current data element.""" + return dict(self.metainfo_items()) + + def __setattr__(self, name: str, value: Any): + """setattr is only used to set data.""" + if name in ("_metainfo_fields", "_data_fields"): + if not hasattr(self, name): + super().__setattr__(name, value) + else: + raise AttributeError( + f"{name} has been used as a " + "private attribute, which is immutable." + ) + else: + self.set_field(name=name, value=value, field_type="data", dtype=None) + + def __delattr__(self, item: str): + """Delete the item in dataelement. + + Args: + item (str): The key to delete. + """ + if item in ("_metainfo_fields", "_data_fields"): + raise AttributeError( + f"{item} has been used as a " "private attribute, which is immutable." + ) + super().__delattr__(item) + if item in self._metainfo_fields: + self._metainfo_fields.remove(item) + elif item in self._data_fields: + self._data_fields.remove(item) + + # dict-like methods + __delitem__ = __delattr__ + + def get(self, key, default=None) -> Any: + """Get property in data and metainfo as the same as python.""" + # Use `getattr()` rather than `self.__dict__.get()` to allow getting + # properties. + return getattr(self, key, default) + + def pop(self, *args) -> Any: + """Pop property in data and metainfo as the same as python.""" + assert len(args) < 3, "``pop`` get more than 2 arguments" + name = args[0] + if name in self._metainfo_fields: + self._metainfo_fields.remove(args[0]) + return self.__dict__.pop(*args) + + elif name in self._data_fields: + self._data_fields.remove(args[0]) + return self.__dict__.pop(*args) + + # with default value + elif len(args) == 2: + return args[1] + else: + # don't just use 'self.__dict__.pop(*args)' for only popping key in + # metainfo or data + raise KeyError(f"{args[0]} is not contained in metainfo or data") + + def __contains__(self, item: str) -> bool: + """Whether the item is in dataelement. + + Args: + item (str): The key to inquire. 
+ """ + return item in self._data_fields or item in self._metainfo_fields + + def set_field( + self, + value: Any, + name: str, + dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, + field_type: str = "data", + ) -> None: + """Special method for set union field, used as property.setter + functions.""" + assert field_type in ["metainfo", "data"] + if dtype is not None: + assert isinstance( + value, dtype + ), f"{value} should be a {dtype} but got {type(value)}" + + if field_type == "metainfo": + if name in self._data_fields: + raise AttributeError( + f"Cannot set {name} to be a field of metainfo " + f"because {name} is already a data field" + ) + self._metainfo_fields.add(name) + else: + if name in self._metainfo_fields: + raise AttributeError( + f"Cannot set {name} to be a field of data " + f"because {name} is already a metainfo field" + ) + self._data_fields.add(name) + super().__setattr__(name, value) + + # Tensor-like methods + def to(self, *args, **kwargs) -> "BaseDataElement": + """Apply same name function to all tensors in data_fields.""" + new_data = self.new() + for k, v in self.items(): + if hasattr(v, "to"): + v = v.to(*args, **kwargs) + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def cpu(self) -> "BaseDataElement": + """Convert all tensors to CPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cpu() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def cuda(self) -> "BaseDataElement": + """Convert all tensors to GPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cuda() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def npu(self) -> "BaseDataElement": + """Convert all tensors to NPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.npu() + data = {k: v} + new_data.set_data(data) + return new_data + + def mlu(self) -> "BaseDataElement": + """Convert all tensors to MLU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.mlu() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def detach(self) -> "BaseDataElement": + """Detach all tensors in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def numpy(self) -> "BaseDataElement": + """Convert all tensors to np.ndarray in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach().cpu().numpy() + data = {k: v} + new_data.set_data(data) + return new_data + + def to_tensor(self) -> "BaseDataElement": + """Convert all np.ndarray to tensor in data.""" + new_data = self.new() + for k, v in self.items(): + data = {} + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + data[k] = v + elif isinstance(v, BaseDataElement): + v = v.to_tensor() + data[k] = v + new_data.set_data(data) + return new_data + + def to_dict(self) -> dict: + """Convert BaseDataElement to dict.""" + return { + k: v.to_dict() if isinstance(v, BaseDataElement) else v + for k, v in self.all_items() + } + + def __repr__(self) -> str: + """Represent the object.""" + + def 
_addindent(s_: str, num_spaces: int) -> str: + """This func is modified from `pytorch` https://github.com/pytorch/ + pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu + les/module.py#L29. + + Args: + s_ (str): The string to add spaces. + num_spaces (int): The num of space to add. + + Returns: + str: The string after add indent. + """ + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) # type: ignore + s = first + "\n" + s # type: ignore + return s # type: ignore + + def dump(obj: Any) -> str: + """Represent the object. + + Args: + obj (Any): The obj to represent. + + Returns: + str: The represented str. + """ + _repr = "" + if isinstance(obj, dict): + for k, v in obj.items(): + _repr += f"\n{k}: {_addindent(dump(v), 4)}" + elif isinstance(obj, BaseDataElement): + _repr += "\n\n META INFORMATION" + metainfo_items = dict(obj.metainfo_items()) + _repr += _addindent(dump(metainfo_items), 4) + _repr += "\n\n DATA FIELDS" + items = dict(obj.items()) + _repr += _addindent(dump(items), 4) + classname = obj.__class__.__name__ + _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" + else: + _repr += repr(obj) + return _repr + + return dump(self) diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py new file mode 100644 index 0000000..0feed6c --- /dev/null +++ b/abl/structures/list_data.py @@ -0,0 +1,301 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from collections.abc import Sized +from typing import Any, List, Union + +import numpy as np +import torch + +from .base_data_element import BaseDataElement + +BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] +LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] + +IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] + + +# Modified from +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa +class ListData(BaseDataElement): + """Data structure for instance-level annotations or predictions. + + Subclass of :class:`BaseDataElement`. All value in `data_fields` + should have the same length. This design refer to + https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 + ListData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value + in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, + and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. + + Examples: + >>> # custom data structure + >>> class TmpObject: + ... def __init__(self, tmp) -> None: + ... assert isinstance(tmp, list) + ... self.tmp = tmp + ... def __len__(self): + ... return len(self.tmp) + ... def __getitem__(self, item): + ... if isinstance(item, int): + ... if item >= len(self) or item < -len(self): # type:ignore + ... raise IndexError(f'Index {item} out of range!') + ... else: + ... # keep the dimension + ... item = slice(item, None, len(self)) + ... return TmpObject(self.tmp[item]) + ... @staticmethod + ... def cat(tmp_objs): + ... assert all(isinstance(results, TmpObject) for results in tmp_objs) + ... if len(tmp_objs) == 1: + ... return tmp_objs[0] + ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] + ... tmp_list = list(itertools.chain(*tmp_list)) + ... new_data = TmpObject(tmp_list) + ... 
return new_data + ... def __repr__(self): + ... return str(self.tmp) + >>> from abl.structures import ListData + >>> import numpy as np + >>> import torch + >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) + >>> instance_data = ListData(metainfo=img_meta) + >>> 'img_shape' in instance_data True + >>> instance_data.det_labels = torch.LongTensor([2, 3]) + >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) + >>> instance_data.bboxes = torch.rand((2, 4)) + >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) + >>> len(instance_data) 2 + >>> print(instance_data) + + >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] + >>> sorted_results.det_scores + tensor([0.7000, 0.8000]) + >>> print(instance_data[instance_data.det_scores > 0.75]) + + >>> print(instance_data[instance_data.det_scores > 1]) + + >>> print(instance_data.cat([instance_data, instance_data])) + + """ + + def __setattr__(self, name: str, value: Sized): + """setattr is only used to set data. + + The value must have the attribute of `__len__` and have the same length + of `ListData`. + """ + if name in ("_metainfo_fields", "_data_fields"): + if not hasattr(self, name): + super().__setattr__(name, value) + else: + raise AttributeError( + f"{name} has been used as a " + "private attribute, which is immutable." + ) + + else: + assert isinstance(value, Sized), "value must contain `__len__` attribute" + + if len(self) > 0: + assert len(value) == len(self), ( + "The length of " + f"values {len(value)} is " + "not consistent with " + "the length of this " + ":obj:`ListData` " + f"{len(self)}" + ) + super().__setattr__(name, value) + + __setitem__ = __setattr__ + + def __getitem__(self, item: IndexType) -> "ListData": + """ + Args: + item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, + :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): + Get the corresponding values according to item. + + Returns: + :obj:`ListData`: Corresponding values. + """ + assert isinstance(item, IndexType.__args__) + if isinstance(item, list): + item = np.array(item) + if isinstance(item, np.ndarray): + # The default int type of numpy is platform dependent, int32 for + # windows and int64 for linux. `torch.Tensor` requires the index + # should be int64, therefore we simply convert it to int64 here. + # More details in https://github.com/numpy/numpy/issues/9464 + item = item.astype(np.int64) if item.dtype == np.int32 else item + item = torch.from_numpy(item) + + if isinstance(item, str): + return getattr(self, item) + + if isinstance(item, int): + if item >= len(self) or item < -len(self): # type:ignore + raise IndexError(f"Index {item} out of range!") + else: + # keep the dimension + item = slice(item, None, len(self)) + + new_data = self.__class__(metainfo=self.metainfo) + if isinstance(item, torch.Tensor): + assert item.dim() == 1, ( + "Only support to get the" " values along the first dimension." + ) + if isinstance(item, BoolTypeTensor.__args__): + assert len(item) == len(self), ( + "The shape of the " + "input(BoolTensor) " + f"{len(item)} " + "does not match the shape " + "of the indexed tensor " + "in results_field " + f"{len(self)} at " + "first dimension. 
+ ) + + for k, v in self.items(): + if isinstance(v, torch.Tensor): + new_data[k] = v[item] + elif isinstance(v, np.ndarray): + new_data[k] = v[item.cpu().numpy()] + elif isinstance(v, (str, list, tuple)) or ( + hasattr(v, "__getitem__") and hasattr(v, "cat") + ): + # convert to indexes from BoolTensor + if isinstance(item, BoolTypeTensor.__args__): + indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() + else: + indexes = item.cpu().numpy().tolist() + slice_list = [] + if indexes: + for index in indexes: + slice_list.append(slice(index, None, len(v))) + else: + slice_list.append(slice(None, 0, None)) + r_list = [v[s] for s in slice_list] + if isinstance(v, (str, list, tuple)): + new_value = r_list[0] + for r in r_list[1:]: + new_value = new_value + r + else: + new_value = v.cat(r_list) + new_data[k] = new_value + else: + raise ValueError( + f"The type of `{k}` is `{type(v)}`, which has no " + "attribute of `cat`, so it does not " + "support slice with `bool`" + ) + + else: + # item is a slice + for k, v in self.items(): + new_data[k] = v[item] + return new_data # type:ignore + + @staticmethod + def cat(instances_list: List["ListData"]) -> "ListData": + """Concat the instances of all :obj:`ListData` in the list. + + Note: To ensure that cat returns as expected, make sure that + all elements in the list must have exactly the same keys. + + Args: + instances_list (list[:obj:`ListData`]): A list + of :obj:`ListData`. + + Returns: + :obj:`ListData` + """ + assert all(isinstance(results, ListData) for results in instances_list) + assert len(instances_list) > 0 + if len(instances_list) == 1: + return instances_list[0] + + # metainfo and data_fields must be exactly the + # same for each element to avoid exceptions. + field_keys_list = [instances.all_keys() for instances in instances_list] + assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( + set(itertools.chain(*field_keys_list)) + ) == len(field_keys_list[0]), ( + "There are different keys in " + "`instances_list`, which may " + "cause the cat operation " + "to fail. Please make sure all " + "elements in `instances_list` " + "have the exact same key." + ) + + new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) + for k in instances_list[0].keys(): + values = [results[k] for results in instances_list] + v0 = values[0] + if isinstance(v0, torch.Tensor): + new_values = torch.cat(values, dim=0) + elif isinstance(v0, np.ndarray): + new_values = np.concatenate(values, axis=0) + elif isinstance(v0, (str, list, tuple)): + new_values = v0[:] + for v in values[1:]: + new_values += v + elif hasattr(v0, "cat"): + new_values = v0.cat(values) + else: + raise ValueError( + f"The type of `{k}` is `{type(v0)}` which has no " + "attribute of `cat`" + ) + new_data[k] = new_values + return new_data # type:ignore + + def __len__(self) -> int: + """int: The length of ListData.""" + if len(self._data_fields) > 0: + return len(self.values()[0]) + else: + return 0
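
Note on the reasoner interface after this refactor: the `if __name__ == "__main__"` demo in `abl/reasoning/reasoner.py` above still drives `batch_abduce(prob, pseudo_label, Y, ...)` with the pre-refactor positional signature, whereas the refactored `batch_abduce` takes a single `ListData` batch and writes its result back into the `abduced_pseudo_label` field. Below is a minimal sketch of the new calling convention; it is not part of this diff, the import paths are illustrative, and the field names (`pred_pseudo_label`, `pred_prob`, `Y`) are simply the ones that `ReasonerBase.abduce` and `ground_KB._abduce_by_GKB` actually read:

```python
import numpy as np

from abl.reasoning.kb import ground_KB        # illustrative import path
from abl.reasoning.reasoner import ReasonerBase
from abl.structures import ListData


class add_ground_KB(ground_KB):
    """Toy addition KB, mirroring the one in the __main__ demo above."""

    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        return sum(nums)


if __name__ == "__main__":
    # Every data field of a ListData must have the same length: one entry per sample.
    data_samples = ListData()
    data_samples.pred_pseudo_label = [[1, 1], [1, 1]]  # read as data_sample["pred_pseudo_label"][0]
    data_samples.pred_prob = [                         # read as data_sample["pred_prob"][0]
        np.array([[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
                  [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]),
        np.array([[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
                  [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]),
    ]
    data_samples.Y = [8, 8]                            # read as data_sample["Y"][0]

    reasoner = ReasonerBase(add_ground_KB(), "confidence")
    reasoner.batch_abduce(data_samples, max_revision=2, require_more_revision=0)
    # With the probabilities above this should mirror the old demo output,
    # i.e. likely [[1, 7], [7, 1]].
    print(data_samples.abduced_pseudo_label)
```

Iterating a `ListData` yields per-sample slices that keep their batch dimension (each field of the slice is a length-1 list), which is why the reasoner indexes `data_sample["pred_pseudo_label"][0]` rather than `data_sample["pred_pseudo_label"]`.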