| @@ -1,52 +1,64 @@ | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, List, Tuple | |||
| from typing import Any, List, Tuple, Optional, Union | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..structures import ListData | |||
| class BaseBridge(metaclass=ABCMeta): | |||
| DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] | |||
| class BaseBridge(metaclass=ABCMeta): | |||
| def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: | |||
| if not isinstance(model, ABLModel): | |||
| raise TypeError("Expected an ABLModel") | |||
| raise TypeError( | |||
| "Expected an instance of ABLModel, but received type: {}".format( | |||
| type(model) | |||
| ) | |||
| ) | |||
| if not isinstance(abducer, ReasonerBase): | |||
| raise TypeError("Expected an ReasonerBase") | |||
| raise TypeError( | |||
| "Expected an instance of ReasonerBase, but received type: {}".format( | |||
| type(abducer) | |||
| ) | |||
| ) | |||
| self.model = model | |||
| self.abducer = abducer | |||
| @abstractmethod | |||
| def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
| def predict( | |||
| self, data_samples: ListData | |||
| ) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
| """Placeholder for predict labels from input.""" | |||
| pass | |||
| @abstractmethod | |||
| def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: | |||
| def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
| """Placeholder for abduce pseudo labels.""" | |||
| @abstractmethod | |||
| def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: | |||
| def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
| """Placeholder for map label space to symbol space.""" | |||
| pass | |||
| @abstractmethod | |||
| def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: | |||
| def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: | |||
| """Placeholder for map symbol space to label space.""" | |||
| pass | |||
| @abstractmethod | |||
| def train(self, train_data): | |||
| def train(self, train_data: Union[ListData, DataSet]): | |||
| """Placeholder for train loop of ABductive Learning.""" | |||
| pass | |||
| @abstractmethod | |||
| def test(self, test_data): | |||
| def valid(self, valid_data: Union[ListData, DataSet]) -> None: | |||
| """Placeholder for model test.""" | |||
| pass | |||
| @abstractmethod | |||
| def valid(self, valid_data): | |||
| def test(self, test_data: Union[ListData, DataSet]) -> None: | |||
| """Placeholder for model validation.""" | |||
| pass | |||
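# Illustrative sketch of the DataSet alias above (shapes are an assumption,
# inferred from data_preprocess(X, gt_pseudo_label, Y) in SimpleBridge):
# one sub-list per example, with gt_pseudo_label optional.
img_a1 = img_a2 = img_b1 = img_b2 = None  # stand-ins for raw inputs
X = [[img_a1, img_a2], [img_b1, img_b2]]  # raw inputs, one sub-list per example
gt_pseudo_label = [[1, 1], [1, 7]]        # or None when unlabeled
Y = [[2], [8]]                            # reasoning targets
train_data = (X, gt_pseudo_label, Y)      # a DataSet tuple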
| @@ -1,12 +1,13 @@ | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..evaluation import BaseMetric | |||
| from .base_bridge import BaseBridge | |||
| from typing import List, Union, Any, Tuple, Dict, Optional | |||
| from numpy import ndarray | |||
| from torch.utils.data import DataLoader | |||
| from ..dataset import BridgeDataset | |||
| from .base_bridge import BaseBridge, DataSet | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..evaluation import BaseMetric | |||
| from ..structures import ListData | |||
| from ..utils.logger import print_log | |||
| @@ -20,64 +21,77 @@ class SimpleBridge(BaseBridge): | |||
| super().__init__(model, abducer) | |||
| self.metric_list = metric_list | |||
| def predict(self, X) -> Tuple[List[List[Any]], ndarray]: | |||
| pred_res = self.model.predict(X) | |||
| pred_idx, pred_prob = pred_res["label"], pred_res["prob"] | |||
| return pred_idx, pred_prob | |||
| def predict(self, data_samples: ListData) -> Tuple[List[ndarray], ndarray]: | |||
| pred_res = self.model.predict(data_samples) | |||
| data_samples.pred_idx = pred_res["label"] | |||
| data_samples.pred_prob = pred_res["prob"] | |||
| return data_samples["pred_idx"], ["data_samples.pred_prob"] | |||
| def abduce_pseudo_label( | |||
| self, | |||
| pred_prob: ndarray, | |||
| pred_pseudo_label: List[List[Any]], | |||
| Y: List[Any], | |||
| data_samples: ListData, | |||
| max_revision: int = -1, | |||
| require_more_revision: int = 0, | |||
| ) -> List[List[Any]]: | |||
| return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) | |||
| self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) | |||
| return data_samples["abduced_pseudo_label"] | |||
| def idx_to_pseudo_label( | |||
| self, idx: List[List[Any]], mapping: Dict = None | |||
| self, data_samples: ListData, mapping: Dict = None | |||
| ) -> List[List[Any]]: | |||
| if mapping is None: | |||
| mapping = self.abducer.mapping | |||
| return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] | |||
| pred_idx = data_samples.pred_idx | |||
| data_samples.pred_pseudo_label = [ | |||
| [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx | |||
| ] | |||
| return data_samples["pred_pseudo_label"] | |||
| def pseudo_label_to_idx( | |||
| self, pseudo_label: List[List[Any]], mapping: Dict = None | |||
| self, data_samples: ListData, mapping: Dict = None | |||
| ) -> List[List[Any]]: | |||
| if mapping is None: | |||
| mapping = self.abducer.remapping | |||
| return [ | |||
| [mapping[_pseudo_label] for _pseudo_label in sub_list] | |||
| for sub_list in pseudo_label | |||
| abduced_idx = [ | |||
| [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] | |||
| for sub_list in data_samples.abduced_pseudo_label | |||
| ] | |||
| data_samples.abduced_idx = abduced_idx | |||
| return data_samples["abduced_idx"] | |||
| def data_preprocess( | |||
| self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any] | |||
| ) -> ListData: | |||
| data_samples = ListData() | |||
| data_samples.X = X | |||
| data_samples.gt_pseudo_label = gt_pseudo_label | |||
| data_samples.Y = Y | |||
| return data_samples | |||
| def train( | |||
| self, | |||
| train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], | |||
| train_data: DataSet, | |||
| epochs: int = 50, | |||
| batch_size: Union[int, float] = -1, | |||
| eval_interval: int = 1, | |||
| ): | |||
| dataset = BridgeDataset(*train_data) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=batch_size, | |||
| collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
| ) | |||
| data_samples = self.data_preprocess(*train_data) | |||
| for epoch in range(epochs): | |||
| for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||
| pred_idx, pred_prob = self.predict(X) | |||
| pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||
| abduced_pseudo_label = self.abduce_pseudo_label( | |||
| pred_prob, pred_pseudo_label, Y | |||
| ) | |||
| abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) | |||
| loss = self.model.train(X, abduced_label) | |||
| for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||
| sub_data_samples = data_samples[ | |||
| seg_idx * batch_size : (seg_idx + 1) * batch_size | |||
| ] | |||
| self.predict(sub_data_samples) | |||
| self.idx_to_pseudo_label(sub_data_samples) | |||
| self.abduce_pseudo_label(sub_data_samples) | |||
| self.pseudo_label_to_idx(sub_data_samples) | |||
| loss = self.model.train(sub_data_samples) | |||
| print_log( | |||
| f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", | |||
| f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{(len(data_samples) - 1) // batch_size + 1}] model loss is {loss:.5f}", | |||
| logger="current", | |||
| ) | |||
| @@ -85,20 +99,19 @@ class SimpleBridge(BaseBridge): | |||
| print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | |||
| self.valid(train_data) | |||
| def _valid(self, data_loader): | |||
| for X, Z, Y in data_loader: | |||
| pred_idx, pred_prob = self.predict(X) | |||
| pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||
| data_samples = dict( | |||
| pred_idx=pred_idx, | |||
| pred_prob=pred_prob, | |||
| pred_pseudo_label=pred_pseudo_label, | |||
| gt_pseudo_label=Z, | |||
| Y=Y, | |||
| logic_forward=self.abducer.kb.logic_forward, | |||
| def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: | |||
| for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||
| sub_data_samples = data_samples[ | |||
| seg_idx * batch_size : (seg_idx + 1) * batch_size | |||
| ] | |||
| self.predict(sub_data_samples) | |||
| self.idx_to_pseudo_label(sub_data_samples) | |||
| sub_data_samples.set_metainfo( | |||
| dict(logic_forward=self.abducer.kb.logic_forward) | |||
| ) | |||
| for metric in self.metric_list: | |||
| metric.process(data_samples) | |||
| metric.process(sub_data_samples) | |||
| res = dict() | |||
| for metric in self.metric_list: | |||
| @@ -108,14 +121,12 @@ class SimpleBridge(BaseBridge): | |||
| msg += k + f": {v:.3f} " | |||
| print_log(msg, logger="current") | |||
| def valid(self, valid_data, batch_size=1000): | |||
| dataset = BridgeDataset(*valid_data) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=batch_size, | |||
| collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
| ) | |||
| self._valid(data_loader) | |||
| def test(self, test_data, batch_size=1000): | |||
| self.valid(test_data, batch_size) | |||
| def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None: | |||
| if not isinstance(valid_data, ListData): | |||
| data_samples = self.data_preprocess(*valid_data) | |||
| else: | |||
| data_samples = valid_data | |||
| self._valid(data_samples, batch_size=batch_size) | |||
| def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None: | |||
| self.valid(test_data, batch_size=batch_size) | |||
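# Hypothetical end-to-end use of SimpleBridge (my_model, my_reasoner and
# my_metric are placeholder names, not defined in this repo): wrap a learner
# and a reasoner, then train/evaluate on DataSet tuples.
#
#   bridge = SimpleBridge(model=my_model, abducer=my_reasoner,
#                         metric_list=[my_metric])
#   bridge.train(train_data, epochs=50, batch_size=32)
#   bridge.test(test_data, batch_size=128)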
| @@ -1,8 +1,6 @@ | |||
| from typing import Optional, Sequence | |||
| from .base_metric import BaseMetric | |||
| class ABLMetric(): | |||
| pass | |||
| class SemanticsMetric(BaseMetric): | |||
| def __init__(self, prefix: Optional[str] = None) -> None: | |||
| @@ -9,9 +9,13 @@ | |||
| # Description : | |||
| # | |||
| # ================================================================# | |||
| from typing import List, Any, Optional | |||
| import pickle | |||
| from ..structures import ListData | |||
| from ..utils import flatten, reform_idx | |||
| from typing import List, Any, Optional | |||
| class ABLModel: | |||
| @@ -55,7 +59,7 @@ class ABLModel: | |||
| "base_model should have fit, predict and score methods." | |||
| ) | |||
| def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: | |||
| def predict(self, data_samples: ListData, mapping: Optional[dict] = None) -> dict: | |||
| """ | |||
| Predict the labels and probabilities for the given data. | |||
| @@ -72,11 +76,11 @@ class ABLModel: | |||
| A dictionary containing the predicted labels and probabilities. | |||
| """ | |||
| model = self.classifier_list[0] | |||
| data_X = flatten(X) | |||
| data_X = flatten(data_samples["X"]) | |||
| if hasattr(model, "predict_proba"): | |||
| prob = model.predict_proba(X=data_X) | |||
| label = prob.argmax(axis=1) | |||
| prob = reform_idx(prob, X) | |||
| prob = reform_idx(prob, data_samples["X"]) | |||
| else: | |||
| prob = None | |||
| label = model.predict(X=data_X) | |||
| @@ -84,7 +88,7 @@ class ABLModel: | |||
| if mapping is not None: | |||
| label = [mapping[y] for y in label] | |||
| label = reform_idx(label, X) | |||
| label = reform_idx(label, data_samples["X"]) | |||
| return {"label": label, "prob": prob} | |||
| @@ -109,7 +113,7 @@ class ABLModel: | |||
| score = self.classifier_list[0].score(X=data_X, y=data_Y) | |||
| return score | |||
| def train(self, X: List[List[Any]], Y: List[Any]) -> float: | |||
| def train(self, data_samples: ListData) -> float: | |||
| """ | |||
| Train the model on the given data. | |||
| @@ -125,9 +129,9 @@ class ABLModel: | |||
| float | |||
| The loss value of the trained model. | |||
| """ | |||
| data_X = flatten(X) | |||
| data_Y = flatten(Y) | |||
| return self.classifier_list[0].fit(X=data_X, y=data_Y) | |||
| data_X = flatten(data_samples["X"]) | |||
| data_y = flatten(data_samples["abduced_idx"]) | |||
| return self.classifier_list[0].fit(X=data_X, y=data_y) | |||
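# A minimal sketch of the list utilities used above (an assumption about
# ..utils; the real implementations may differ): flatten drops one level of
# nesting so the base classifier sees a flat batch, and reform_idx re-nests
# a flat sequence to match the shape of a reference nested list.
def flatten(nested):
    """Concatenate the sub-lists of a list of lists into one flat list."""
    return [item for sub in nested for item in sub]

def reform_idx(flat, reference):
    """Split `flat` back into sub-lists shaped like `reference`."""
    out, i = [], 0
    for sub in reference:
        out.append(list(flat[i:i + len(sub)]))
        i += len(sub)
    return out

assert flatten([[1, 2], [3]]) == [1, 2, 3]
assert reform_idx([1, 2, 3], [[0, 0], [0]]) == [[1, 2], [3]]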
| def _model_operation(self, operation: str, *args, **kwargs): | |||
| model = self.classifier_list[0] | |||
| @@ -5,13 +5,21 @@ import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
| from ..utils.utils import ( | |||
| flatten, | |||
| reform_idx, | |||
| hamming_dist, | |||
| check_equal, | |||
| to_hashable, | |||
| hashable_to_list, | |||
| ) | |||
| from multiprocessing import Pool | |||
| from functools import lru_cache | |||
| import pyswip | |||
| class KBBase(ABC): | |||
| def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | |||
# TODO: add some type checks here, e.g.
| @@ -20,7 +28,7 @@ class KBBase(ABC): | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.max_err = max_err | |||
self.use_cache = use_cache
| @abstractmethod | |||
| def logic_forward(self, pseudo_labels): | |||
| @@ -28,10 +36,17 @@ class KBBase(ABC): | |||
| def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||
| if not self.use_cache: | |||
| return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) | |||
| else: | |||
| return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision) | |||
| return self._abduce_by_search( | |||
| pred_res, y, max_revision_num, require_more_revision | |||
| ) | |||
| else: | |||
| return self._abduce_by_search_cache( | |||
| to_hashable(pred_res), | |||
| to_hashable(y), | |||
| max_revision_num, | |||
| require_more_revision, | |||
| ) | |||
| def revise_by_idx(self, pred_res, y, revision_idx): | |||
| candidates = [] | |||
| abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||
| @@ -52,10 +67,12 @@ class KBBase(ABC): | |||
| new_candidates.extend(candidates) | |||
| return new_candidates | |||
def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
| candidates = [] | |||
| for revision_num in range(len(pred_res) + 1): | |||
| if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): | |||
| if revision_num == 0 and check_equal( | |||
| self.logic_forward(pred_res), y, self.max_err | |||
| ): | |||
| candidates.append(pred_res) | |||
| elif revision_num > 0: | |||
| candidates.extend(self._revision(revision_num, pred_res, y)) | |||
| @@ -65,18 +82,24 @@ class KBBase(ABC): | |||
| if revision_num >= max_revision_num: | |||
| return [] | |||
| for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): | |||
| for revision_num in range( | |||
| min_revision_num + 1, min_revision_num + require_more_revision + 1 | |||
| ): | |||
| if revision_num > max_revision_num: | |||
| return candidates | |||
| candidates.extend(self._revision(revision_num, pred_res, y)) | |||
| return candidates | |||
| @lru_cache(maxsize=None) | |||
| def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision): | |||
| def _abduce_by_search_cache( | |||
| self, pred_res, y, max_revision_num, require_more_revision | |||
| ): | |||
| pred_res = hashable_to_list(pred_res) | |||
| y = hashable_to_list(y) | |||
| return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) | |||
| return self._abduce_by_search( | |||
| pred_res, y, max_revision_num, require_more_revision | |||
| ) | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| @@ -88,17 +111,18 @@ class KBBase(ABC): | |||
| return 0 | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class ground_KB(KBBase): | |||
| def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| self.GKB_len_list = GKB_len_list | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| @@ -114,6 +138,7 @@ class ground_KB(KBBase): | |||
| def _get_GKB(self): | |||
| X, Y = [], [] | |||
| for length in self.GKB_len_list: | |||
| print("Generating GKB of length %d" % length) | |||
| arg_list = [] | |||
| for pre_x in self.pseudo_label_list: | |||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
| @@ -126,21 +151,24 @@ class ground_KB(KBBase): | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
if Y and isinstance(Y[0], (int, float)):
| X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||
| return X, Y | |||
| def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||
| return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision) | |||
| def _find_candidate_GKB(self, pred_res, y): | |||
| def abduce_candidates(self, data_sample, max_revision_num, require_more_revision=0): | |||
| return self._abduce_by_GKB( | |||
| data_sample, max_revision_num, require_more_revision=require_more_revision | |||
| ) | |||
| def _find_candidate_GKB(self, cache_key, data_sample): | |||
| y = data_sample["Y"][0] | |||
| if self.max_err == 0: | |||
| return self.base[len(pred_res)][y] | |||
| return self.base[cache_key][y] | |||
| else: | |||
| potential_candidates = self.base[len(pred_res)] | |||
| potential_candidates = self.base[cache_key] | |||
| key_list = list(potential_candidates.keys()) | |||
| key_idx = bisect.bisect_left(key_list, y) | |||
| all_candidates = [] | |||
| for idx in range(key_idx - 1, -1, -1): | |||
| k = key_list[idx] | |||
| @@ -148,7 +176,7 @@ class ground_KB(KBBase): | |||
| all_candidates.extend(potential_candidates[k]) | |||
| else: | |||
| break | |||
| for idx in range(key_idx, len(key_list)): | |||
| k = key_list[idx] | |||
| if abs(k - y) <= self.max_err: | |||
| @@ -156,19 +184,20 @@ class ground_KB(KBBase): | |||
| else: | |||
| break | |||
| return all_candidates | |||
| def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): | |||
| if self.base == {} or len(pred_res) not in self.GKB_len_list: | |||
| def _abduce_by_GKB(self, data_sample, max_revision_num, require_more_revision=0): | |||
| cache_key = len(data_sample["pred_pseudo_label"][0]) | |||
| if self.base == {} or cache_key not in self.GKB_len_list: | |||
| return [] | |||
| all_candidates = self._find_candidate_GKB(pred_res, y) | |||
| all_candidates = self._find_candidate_GKB(cache_key, data_sample) | |||
| if len(all_candidates) == 0: | |||
| return [] | |||
| cost_list = hamming_dist(pred_res, all_candidates) | |||
| min_revision_num = np.min(cost_list) | |||
| cost_array = hamming_dist(data_sample["pred_pseudo_label"][0], all_candidates) | |||
| min_revision_num = np.min(cost_array) | |||
| revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||
| idxs = np.where(cost_list <= revision_num)[0] | |||
| idxs = np.where(cost_array <= revision_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates | |||
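# The candidate filtering above relies on an element-wise Hamming distance:
# a sketch of the assumed behaviour of hamming_dist (one distance per
# candidate) and of how the revision budget prunes candidates:
import numpy as np

def hamming_dist(pred, candidates):
    """Number of positions where each candidate differs from pred."""
    pred = np.asarray(pred)
    return np.array([np.sum(pred != np.asarray(c)) for c in candidates])

costs = hamming_dist([1, 1], [[1, 7], [7, 1], [2, 6]])
assert costs.tolist() == [1, 1, 2]
assert np.where(costs <= 1)[0].tolist() == [0, 1]  # keep the 1-revision fixes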
| @@ -180,33 +209,38 @@ class prolog_KB(KBBase): | |||
| self.prolog.consult(pl_file) | |||
| def logic_forward(self, pseudo_labels): | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] | |||
| if result == 'true': | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0][ | |||
| "Res" | |||
| ] | |||
| if result == "true": | |||
| return True | |||
| elif result == 'false': | |||
| elif result == "false": | |||
| return False | |||
| return result | |||
| def _revision_pred_res(self, pred_res, revision_idx): | |||
| import re | |||
| revision_pred_res = pred_res.copy() | |||
| revision_pred_res = flatten(revision_pred_res) | |||
| for idx in revision_idx: | |||
| revision_pred_res[idx] = 'P' + str(idx) | |||
| revision_pred_res[idx] = "P" + str(idx) | |||
| revision_pred_res = reform_idx(revision_pred_res, pred_res) | |||
# TODO: not sure whether there is a more concise way to do this
| regex = r"'P\d+'" | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res)) | |||
| return re.sub( | |||
| regex, lambda x: x.group().replace("'", ""), str(revision_pred_res) | |||
| ) | |||
| def get_query_string(self, pred_res, y, revision_idx): | |||
| query_string = "logic_forward(" | |||
| query_string += self._revision_pred_res(pred_res, revision_idx) | |||
key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None)
| query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
| return query_string | |||
| def revise_by_idx(self, pred_res, y, revision_idx): | |||
| candidates = [] | |||
| query_string = self.get_query_string(pred_res, y, revision_idx) | |||
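# For intuition: with pred_res = [1, 1], y = 8 and revision_idx = [0],
# _revision_pred_res is expected to rewrite position 0 into a Prolog
# variable ("[P0, 1]"), so the query string built above becomes
#
#   "logic_forward([P0, 1],8)."
#
# (assuming a logic_forward/2 predicate as in add.pl); every binding of P0
# that pyswip returns is then a candidate revision.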
| @@ -70,7 +70,7 @@ class ReasonerBase: | |||
| candidates = [[self.remapping[x] for x in c] for c in candidates] | |||
| return confidence_dist(pred_prob, candidates) | |||
| def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): | |||
| def _get_one_candidate(self, data_sample, candidates): | |||
| """ | |||
| Get one candidate. If multiple candidates exist, return the one with minimum cost. | |||
| @@ -94,7 +94,9 @@ class ReasonerBase: | |||
| elif len(candidates) == 1: | |||
| return candidates[0] | |||
| else: | |||
| cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) | |||
| cost_array = self._get_cost_list( | |||
| data_sample["pred_pseudo_label"][0], data_sample["pred_prob"][0], candidates | |||
| ) | |||
| candidate = candidates[np.argmin(cost_array)] | |||
| return candidate | |||
| @@ -188,9 +190,7 @@ class ReasonerBase: | |||
| """ | |||
| return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
| def abduce( | |||
| self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| def abduce(self, data_sample, max_revision=-1, require_more_revision=0): | |||
| """ | |||
| Perform revision by abduction on the given data. | |||
| @@ -213,26 +213,24 @@ class ReasonerBase: | |||
| list | |||
| The abduced revisions. | |||
| """ | |||
| symbol_num = len(flatten(pred_pseudo_label)) | |||
| symbol_num = len(flatten(data_sample.pred_pseudo_label)) | |||
| max_revision_num = calculate_revision_num(max_revision, symbol_num) | |||
| if self.use_zoopt: | |||
| solution = self.zoopt_get_solution( | |||
| symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
| symbol_num, data_sample, max_revision_num | |||
| ) | |||
| revision_idx = np.where(solution != 0)[0] | |||
| candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
| candidates = self.revise_by_idx(data_sample, revision_idx) | |||
| else: | |||
| candidates = self.kb.abduce_candidates( | |||
| pred_pseudo_label, y, max_revision_num, require_more_revision | |||
| data_sample, max_revision_num, require_more_revision | |||
| ) | |||
| candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) | |||
| candidate = self._get_one_candidate(data_sample, candidates) | |||
| return candidate | |||
| def batch_abduce( | |||
| self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| def batch_abduce(self, data_samples, max_revision=-1, require_more_revision=0): | |||
| """ | |||
| Perform abduction on the given data in batches. | |||
| @@ -255,14 +253,15 @@ class ReasonerBase: | |||
| list | |||
| The abduced revisions in batches. | |||
| """ | |||
| return [ | |||
| abduced_pseudo_label = [ | |||
| self.abduce( | |||
| _pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision | |||
| ) | |||
| for _pred_prob, _pred_pseudo_label, _Y in zip( | |||
| pred_prob, pred_pseudo_label, Y | |||
| data_sample, | |||
| max_revision=max_revision, | |||
| require_more_revision=require_more_revision, | |||
| ) | |||
| for data_sample in data_samples | |||
| ] | |||
| data_samples.abduced_pseudo_label = abduced_pseudo_label | |||
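# Note: batch_abduce now operates in place -- it attaches the result to the
# batch instead of returning it. A hypothetical call site:
#
#   reasoner.batch_abduce(data_samples, max_revision=2)
#   revised = data_samples.abduced_pseudo_label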
| # def _batch_abduce_helper(self, args): | |||
| # z, prob, y, max_revision, require_more_revision = args | |||
| @@ -281,43 +280,57 @@ class ReasonerBase: | |||
| ) | |||
| if __name__ == "__main__": | |||
| from kb import KBBase, ground_KB, prolog_KB | |||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| prob1 = [ | |||
| [ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| ] | |||
| prob2 = [ | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| ] | |||
| class add_KB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| use_cache=True): | |||
| def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): | |||
| super().__init__(pseudo_label_list, use_cache=use_cache) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class add_ground_KB(ground_KB): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| GKB_len_list=[2]): | |||
| def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||
| super().__init__(pseudo_label_list, GKB_len_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def test_add(reasoner): | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce( | |||
| prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| @@ -337,8 +350,9 @@ if __name__ == "__main__": | |||
| test_add(reasoner) | |||
| print("prolog_KB with add.pl:") | |||
| kb = prolog_KB(pseudo_label_list=list(range(10)), | |||
| pl_file="examples/mnist_add/datasets/add.pl") | |||
| kb = prolog_KB( | |||
| pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| @@ -351,14 +365,16 @@ if __name__ == "__main__": | |||
| test_add(reasoner) | |||
| print("add_KB with multiple inputs at once:") | |||
| multiple_prob = [[ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ]] | |||
| multiple_prob = [ | |||
| [ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| ] | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| @@ -383,8 +399,21 @@ if __name__ == "__main__": | |||
| class HWF_KB(KBBase): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
| "+", "-", "times", "div"], | |||
| pseudo_label_list=[ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| "+", | |||
| "-", | |||
| "times", | |||
| "div", | |||
| ], | |||
| max_err=1e-3, | |||
| ): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| @@ -393,7 +422,17 @@ if __name__ == "__main__": | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| if i % 2 == 0 and formula[i] not in [ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| ]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
| return False | |||
| @@ -406,12 +445,25 @@ if __name__ == "__main__": | |||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| class HWF_ground_KB(ground_KB): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
| "+", "-", "times", "div"], | |||
| pseudo_label_list=[ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| "+", | |||
| "-", | |||
| "times", | |||
| "div", | |||
| ], | |||
| GKB_len_list=[1, 3, 5, 7], | |||
| max_err=1e-3, | |||
| ): | |||
| @@ -421,7 +473,17 @@ if __name__ == "__main__": | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| if i % 2 == 0 and formula[i] not in [ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| ]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
| return False | |||
| @@ -434,7 +496,7 @@ if __name__ == "__main__": | |||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| def test_hwf(reasoner): | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| @@ -461,7 +523,7 @@ if __name__ == "__main__": | |||
| ) | |||
| print(res) | |||
| print() | |||
| def test_hwf_multiple(reasoner, max_revisions): | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| @@ -512,10 +574,10 @@ if __name__ == "__main__": | |||
| print("HWF_KB with multiple inputs at once:") | |||
| kb = HWF_KB(max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf_multiple(reasoner, max_revisions=[1,3,3]) | |||
| test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) | |||
| print("max_revision is float") | |||
| test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) | |||
| test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| @@ -0,0 +1,2 @@ | |||
| from .base_data_element import BaseDataElement | |||
| from .list_data import ListData | |||
| @@ -0,0 +1,629 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import copy | |||
| from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
| import numpy as np | |||
| import torch | |||
| class BaseDataElement: | |||
| """A base data interface that supports Tensor-like and dict-like | |||
| operations. | |||
A typical data element refers to predicted results or ground truth labels
| on a task, such as predicted bboxes, instance masks, semantic | |||
| segmentation masks, etc. Because groundtruth labels and predicted results | |||
| often have similar properties (for example, the predicted bboxes and the | |||
| groundtruth bboxes), MMEngine uses the same abstract data interface to | |||
| encapsulate predicted results and groundtruth labels, and it is recommended | |||
| to use different name conventions to distinguish them, such as using | |||
| ``gt_instances`` and ``pred_instances`` to distinguish between labels and | |||
| predicted results. Additionally, we distinguish data elements at instance | |||
| level, pixel level, and label level. Each of these types has its own | |||
| characteristics. Therefore, MMEngine defines the base class | |||
``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and
| ``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||
| types of ground truth labels or predictions. | |||
Another common data element is sample data. A data sample consists of input
| data (such as an image) and its annotations and predictions. In general, | |||
| an image can have multiple types of annotations and/or predictions at the | |||
| same time (for example, both pixel-level semantic segmentation annotations | |||
| and instance-level detection bboxes annotations). All labels and | |||
| predictions of a training sample are often passed between Dataset, Model, | |||
| Visualizer, and Evaluator components. In order to simplify the interface | |||
| between components, we can treat them as a large data element and | |||
| encapsulate them. Such data elements are generally called XXDataSample in | |||
the OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement`
| allows `BaseDataElement` as its attribute. Such a class generally | |||
| encapsulates all the data of a sample in the algorithm library, and its | |||
| attributes generally are various types of data elements. For example, | |||
MMDetection uses such a ``BaseDataElement`` subclass to encapsulate all the
data elements of a sample's annotations and predictions in the
algorithm library.
| The attributes in ``BaseDataElement`` are divided into two parts, | |||
| the ``metainfo`` and the ``data`` respectively. | |||
| - ``metainfo``: Usually contains the | |||
| information about the image such as filename, | |||
| image_shape, pad_shape, etc. The attributes can be accessed or | |||
| modified by dict-like or object-like operations, such as | |||
| ``.`` (for data access and modification), ``in``, ``del``, | |||
| ``pop(str)``, ``get(str)``, ``metainfo_keys()``, | |||
| ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for | |||
| set or change key-value pairs in metainfo). | |||
| - ``data``: Annotations or model predictions are | |||
| stored. The attributes can be accessed or modified by | |||
| dict-like or object-like operations, such as | |||
| ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||
| ``values()``, ``items()``. Users can also apply tensor-like | |||
| methods to all :obj:`torch.Tensor` in the ``data_fields``, | |||
| such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, | |||
| ``to_tensor()``, ``.detach()``. | |||
| Args: | |||
| metainfo (dict, optional): A dict contains the meta information | |||
| of single image, such as ``dict(img_shape=(512, 512, 3), | |||
| scale_factor=(1, 1, 1, 1))``. Defaults to None. | |||
| kwargs (dict, optional): A dict contains annotations of single image or | |||
| model predictions. Defaults to None. | |||
| Examples: | |||
| >>> import torch | |||
| >>> from mmengine.structures import BaseDataElement | |||
| >>> gt_instances = BaseDataElement() | |||
| >>> bboxes = torch.rand((5, 4)) | |||
| >>> scores = torch.rand((5,)) | |||
| >>> img_id = 0 | |||
| >>> img_shape = (800, 1333) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=img_id, img_shape=img_shape), | |||
| ... bboxes=bboxes, scores=scores) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) | |||
| >>> # new | |||
| >>> gt_instances1 = gt_instances.new( | |||
| ... metainfo=dict(img_id=1, img_shape=(640, 640)), | |||
| ... bboxes=torch.rand((5, 4)), | |||
| ... scores=torch.rand((5,))) | |||
| >>> gt_instances2 = gt_instances1.new() | |||
| >>> # add and process property | |||
| >>> gt_instances = BaseDataElement() | |||
| >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) | |||
| >>> assert 'img_shape' in gt_instances.metainfo_keys() | |||
| >>> assert 'img_shape' in gt_instances | |||
| >>> assert 'img_shape' not in gt_instances.keys() | |||
| >>> assert 'img_shape' in gt_instances.all_keys() | |||
| >>> print(gt_instances.img_shape) | |||
| (100, 100) | |||
| >>> gt_instances.scores = torch.rand((5,)) | |||
| >>> assert 'scores' in gt_instances.keys() | |||
| >>> assert 'scores' in gt_instances | |||
| >>> assert 'scores' in gt_instances.all_keys() | |||
| >>> assert 'scores' not in gt_instances.metainfo_keys() | |||
| >>> print(gt_instances.scores) | |||
| tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) | |||
| >>> gt_instances.bboxes = torch.rand((5, 4)) | |||
| >>> assert 'bboxes' in gt_instances.keys() | |||
| >>> assert 'bboxes' in gt_instances | |||
| >>> assert 'bboxes' in gt_instances.all_keys() | |||
| >>> assert 'bboxes' not in gt_instances.metainfo_keys() | |||
| >>> print(gt_instances.bboxes) | |||
| tensor([[0.0900, 0.0424, 0.1755, 0.4469], | |||
| [0.8648, 0.0592, 0.3484, 0.0913], | |||
| [0.5808, 0.1909, 0.6165, 0.7088], | |||
| [0.5490, 0.4209, 0.9416, 0.2374], | |||
| [0.3652, 0.1218, 0.8805, 0.7523]]) | |||
| >>> # delete and change property | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=0, img_shape=(640, 640)), | |||
| ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) | |||
| >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) | |||
| >>> gt_instances.img_shape # (1280, 1280) | |||
| >>> gt_instances.bboxes = gt_instances.bboxes * 2 | |||
| >>> gt_instances.get('img_shape', None) # (1280, 1280) | |||
| >>> gt_instances.get('bboxes', None) # 6x4 tensor | |||
| >>> del gt_instances.img_shape | |||
| >>> del gt_instances.bboxes | |||
| >>> assert 'img_shape' not in gt_instances | |||
| >>> assert 'bboxes' not in gt_instances | |||
| >>> gt_instances.pop('img_shape', None) # None | |||
| >>> gt_instances.pop('bboxes', None) # None | |||
| >>> # Tensor-like | |||
| >>> cuda_instances = gt_instances.cuda() | |||
| >>> cuda_instances = gt_instances.to('cuda:0') | |||
| >>> cpu_instances = cuda_instances.cpu() | |||
| >>> cpu_instances = cuda_instances.to('cpu') | |||
| >>> fp16_instances = cuda_instances.to( | |||
| ... device=None, dtype=torch.float16, non_blocking=False, | |||
| ... copy=False, memory_format=torch.preserve_format) | |||
| >>> cpu_instances = cuda_instances.detach() | |||
| >>> np_instances = cpu_instances.numpy() | |||
| >>> metainfo = dict(img_shape=(800, 1196, 3)) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||
| >>> sample = BaseDataElement(metainfo=metainfo, | |||
| ... gt_instances=gt_instances) | |||
| >>> print(sample) | |||
| <BaseDataElement( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| DATA FIELDS | |||
| gt_instances: <BaseDataElement( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([0, 1, 2, 3]) | |||
| ) at 0x7f0ec5eadc70> | |||
| ) at 0x7f0fea49e130> | |||
| >>> # inheritance | |||
| >>> class DetDataSample(BaseDataElement): | |||
| ... @property | |||
| ... def proposals(self): | |||
| ... return self._proposals | |||
| ... @proposals.setter | |||
| ... def proposals(self, value): | |||
| ... self.set_field(value, '_proposals', dtype=BaseDataElement) | |||
| ... @proposals.deleter | |||
| ... def proposals(self): | |||
| ... del self._proposals | |||
| ... @property | |||
| ... def gt_instances(self): | |||
| ... return self._gt_instances | |||
| ... @gt_instances.setter | |||
| ... def gt_instances(self, value): | |||
| ... self.set_field(value, '_gt_instances', | |||
| ... dtype=BaseDataElement) | |||
| ... @gt_instances.deleter | |||
| ... def gt_instances(self): | |||
| ... del self._gt_instances | |||
| ... @property | |||
| ... def pred_instances(self): | |||
| ... return self._pred_instances | |||
| ... @pred_instances.setter | |||
| ... def pred_instances(self, value): | |||
| ... self.set_field(value, '_pred_instances', | |||
| ... dtype=BaseDataElement) | |||
| ... @pred_instances.deleter | |||
| ... def pred_instances(self): | |||
| ... del self._pred_instances | |||
| >>> det_sample = DetDataSample() | |||
| >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||
| >>> det_sample.proposals = proposals | |||
| >>> assert 'proposals' in det_sample | |||
| >>> assert det_sample.proposals == proposals | |||
| >>> del det_sample.proposals | |||
| >>> assert 'proposals' not in det_sample | |||
| >>> with self.assertRaises(AssertionError): | |||
| ... det_sample.proposals = torch.rand((5, 4)) | |||
| """ | |||
| def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: | |||
| self._metainfo_fields: set = set() | |||
| self._data_fields: set = set() | |||
| if metainfo is not None: | |||
| self.set_metainfo(metainfo=metainfo) | |||
| if kwargs: | |||
| self.set_data(kwargs) | |||
| def set_metainfo(self, metainfo: dict) -> None: | |||
| """Set or change key-value pairs in ``metainfo_field`` by parameter | |||
| ``metainfo``. | |||
| Args: | |||
| metainfo (dict): A dict contains the meta information | |||
| of image, such as ``img_shape``, ``scale_factor``, etc. | |||
| """ | |||
| assert isinstance( | |||
| metainfo, dict | |||
| ), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||
| # meta = copy.deepcopy(metainfo) | |||
| for k, v in metainfo.items(): | |||
| self.set_field(name=k, value=v, field_type="metainfo", dtype=None) | |||
| def set_data(self, data: dict) -> None: | |||
| """Set or change key-value pairs in ``data_field`` by parameter | |||
| ``data``. | |||
| Args: | |||
| data (dict): A dict contains annotations of image or | |||
| model predictions. | |||
| """ | |||
| assert isinstance(data, dict), f"data should be a `dict` but got {data}" | |||
| for k, v in data.items(): | |||
| # Use `setattr()` rather than `self.set_field` to allow `set_data` | |||
| # to set property method. | |||
| setattr(self, k, v) | |||
| def update(self, instance: "BaseDataElement") -> None: | |||
| """The update() method updates the BaseDataElement with the elements | |||
| from another BaseDataElement object. | |||
| Args: | |||
| instance (BaseDataElement): Another BaseDataElement object for | |||
| update the current object. | |||
| """ | |||
| assert isinstance( | |||
| instance, BaseDataElement | |||
| ), f"instance should be a `BaseDataElement` but got {type(instance)}" | |||
| self.set_metainfo(dict(instance.metainfo_items())) | |||
| self.set_data(dict(instance.items())) | |||
| def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": | |||
| """Return a new data element with same type. If ``metainfo`` and | |||
| ``data`` are None, the new data element will have same metainfo and | |||
| data. If metainfo or data is not None, the new result will overwrite it | |||
| with the input value. | |||
| Args: | |||
| metainfo (dict, optional): A dict contains the meta information | |||
| of image, such as ``img_shape``, ``scale_factor``, etc. | |||
| Defaults to None. | |||
| kwargs (dict): A dict contains annotations of image or | |||
| model predictions. | |||
| Returns: | |||
| BaseDataElement: A new data element with same type. | |||
| """ | |||
| new_data = self.__class__() | |||
| if metainfo is not None: | |||
| new_data.set_metainfo(metainfo) | |||
| else: | |||
| new_data.set_metainfo(dict(self.metainfo_items())) | |||
| if kwargs: | |||
| new_data.set_data(kwargs) | |||
| else: | |||
| new_data.set_data(dict(self.items())) | |||
| return new_data | |||
| def clone(self): | |||
| """Deep copy the current data element. | |||
| Returns: | |||
| BaseDataElement: The copy of current data element. | |||
| """ | |||
| clone_data = self.__class__() | |||
| clone_data.set_metainfo(dict(self.metainfo_items())) | |||
| clone_data.set_data(dict(self.items())) | |||
| return clone_data | |||
| def keys(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all keys in data_fields. | |||
| """ | |||
| # We assume that the name of the attribute related to property is | |||
| # '_' + the name of the property. We use this rule to filter out | |||
| # private keys. | |||
| # TODO: Use a more robust way to solve this problem | |||
| private_keys = { | |||
| "_" + key | |||
| for key in self._data_fields | |||
| if isinstance(getattr(type(self), key, None), property) | |||
| } | |||
| return list(self._data_fields - private_keys) | |||
| def metainfo_keys(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all keys in metainfo_fields. | |||
| """ | |||
| return list(self._metainfo_fields) | |||
| def values(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all values in data. | |||
| """ | |||
| return [getattr(self, k) for k in self.keys()] | |||
| def metainfo_values(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all values in metainfo. | |||
| """ | |||
| return [getattr(self, k) for k in self.metainfo_keys()] | |||
| def all_keys(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all keys in metainfo and data. | |||
| """ | |||
| return self.metainfo_keys() + self.keys() | |||
| def all_values(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all values in metainfo and data. | |||
| """ | |||
| return self.metainfo_values() + self.values() | |||
| def all_items(self) -> Iterator[Tuple[str, Any]]: | |||
| """ | |||
| Returns: | |||
| iterator: An iterator object whose element is (key, value) tuple | |||
| pairs for ``metainfo`` and ``data``. | |||
| """ | |||
| for k in self.all_keys(): | |||
| yield (k, getattr(self, k)) | |||
| def items(self) -> Iterator[Tuple[str, Any]]: | |||
| """ | |||
| Returns: | |||
| iterator: An iterator object whose element is (key, value) tuple | |||
| pairs for ``data``. | |||
| """ | |||
| for k in self.keys(): | |||
| yield (k, getattr(self, k)) | |||
| def metainfo_items(self) -> Iterator[Tuple[str, Any]]: | |||
| """ | |||
| Returns: | |||
| iterator: An iterator object whose element is (key, value) tuple | |||
| pairs for ``metainfo``. | |||
| """ | |||
| for k in self.metainfo_keys(): | |||
| yield (k, getattr(self, k)) | |||
| @property | |||
| def metainfo(self) -> dict: | |||
| """dict: A dict contains metainfo of current data element.""" | |||
| return dict(self.metainfo_items()) | |||
| def __setattr__(self, name: str, value: Any): | |||
| """setattr is only used to set data.""" | |||
| if name in ("_metainfo_fields", "_data_fields"): | |||
| if not hasattr(self, name): | |||
| super().__setattr__(name, value) | |||
| else: | |||
| raise AttributeError( | |||
| f"{name} has been used as a " | |||
| "private attribute, which is immutable." | |||
| ) | |||
| else: | |||
| self.set_field(name=name, value=value, field_type="data", dtype=None) | |||
| def __delattr__(self, item: str): | |||
| """Delete the item in dataelement. | |||
| Args: | |||
| item (str): The key to delete. | |||
| """ | |||
| if item in ("_metainfo_fields", "_data_fields"): | |||
| raise AttributeError( | |||
| f"{item} has been used as a " "private attribute, which is immutable." | |||
| ) | |||
| super().__delattr__(item) | |||
| if item in self._metainfo_fields: | |||
| self._metainfo_fields.remove(item) | |||
| elif item in self._data_fields: | |||
| self._data_fields.remove(item) | |||
| # dict-like methods | |||
| __delitem__ = __delattr__ | |||
| def get(self, key, default=None) -> Any: | |||
| """Get property in data and metainfo as the same as python.""" | |||
| # Use `getattr()` rather than `self.__dict__.get()` to allow getting | |||
| # properties. | |||
| return getattr(self, key, default) | |||
| def pop(self, *args) -> Any: | |||
| """Pop property in data and metainfo as the same as python.""" | |||
assert len(args) < 3, "``pop`` got more than 2 arguments"
| name = args[0] | |||
| if name in self._metainfo_fields: | |||
| self._metainfo_fields.remove(args[0]) | |||
| return self.__dict__.pop(*args) | |||
| elif name in self._data_fields: | |||
| self._data_fields.remove(args[0]) | |||
| return self.__dict__.pop(*args) | |||
| # with default value | |||
| elif len(args) == 2: | |||
| return args[1] | |||
| else: | |||
| # don't just use 'self.__dict__.pop(*args)' for only popping key in | |||
| # metainfo or data | |||
| raise KeyError(f"{args[0]} is not contained in metainfo or data") | |||
| def __contains__(self, item: str) -> bool: | |||
| """Whether the item is in dataelement. | |||
| Args: | |||
| item (str): The key to inquire. | |||
| """ | |||
| return item in self._data_fields or item in self._metainfo_fields | |||
| def set_field( | |||
| self, | |||
| value: Any, | |||
| name: str, | |||
| dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, | |||
| field_type: str = "data", | |||
| ) -> None: | |||
| """Special method for set union field, used as property.setter | |||
| functions.""" | |||
| assert field_type in ["metainfo", "data"] | |||
| if dtype is not None: | |||
| assert isinstance( | |||
| value, dtype | |||
| ), f"{value} should be a {dtype} but got {type(value)}" | |||
| if field_type == "metainfo": | |||
| if name in self._data_fields: | |||
| raise AttributeError( | |||
| f"Cannot set {name} to be a field of metainfo " | |||
| f"because {name} is already a data field" | |||
| ) | |||
| self._metainfo_fields.add(name) | |||
| else: | |||
| if name in self._metainfo_fields: | |||
| raise AttributeError( | |||
| f"Cannot set {name} to be a field of data " | |||
| f"because {name} is already a metainfo field" | |||
| ) | |||
| self._data_fields.add(name) | |||
| super().__setattr__(name, value) | |||
| # Tensor-like methods | |||
| def to(self, *args, **kwargs) -> "BaseDataElement": | |||
| """Apply same name function to all tensors in data_fields.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if hasattr(v, "to"): | |||
| v = v.to(*args, **kwargs) | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def cpu(self) -> "BaseDataElement": | |||
| """Convert all tensors to CPU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.cpu() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def cuda(self) -> "BaseDataElement": | |||
| """Convert all tensors to GPU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.cuda() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def npu(self) -> "BaseDataElement": | |||
| """Convert all tensors to NPU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.npu() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| def mlu(self) -> "BaseDataElement": | |||
| """Convert all tensors to MLU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.mlu() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def detach(self) -> "BaseDataElement": | |||
| """Detach all tensors in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.detach() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def numpy(self) -> "BaseDataElement": | |||
| """Convert all tensors to np.ndarray in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.detach().cpu().numpy() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| def to_tensor(self) -> "BaseDataElement": | |||
| """Convert all np.ndarray to tensor in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| data = {} | |||
| if isinstance(v, np.ndarray): | |||
| v = torch.from_numpy(v) | |||
| data[k] = v | |||
| elif isinstance(v, BaseDataElement): | |||
| v = v.to_tensor() | |||
| data[k] = v | |||
| new_data.set_data(data) | |||
| return new_data | |||
| def to_dict(self) -> dict: | |||
| """Convert BaseDataElement to dict.""" | |||
| return { | |||
| k: v.to_dict() if isinstance(v, BaseDataElement) else v | |||
| for k, v in self.all_items() | |||
| } | |||
| def __repr__(self) -> str: | |||
| """Represent the object.""" | |||
| def _addindent(s_: str, num_spaces: int) -> str: | |||
| """This func is modified from `pytorch` https://github.com/pytorch/ | |||
| pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu | |||
| les/module.py#L29. | |||
| Args: | |||
| s_ (str): The string to add spaces. | |||
| num_spaces (int): The num of space to add. | |||
| Returns: | |||
| str: The string after add indent. | |||
| """ | |||
| s = s_.split("\n") | |||
| # don't do anything for single-line stuff | |||
| if len(s) == 1: | |||
| return s_ | |||
| first = s.pop(0) | |||
| s = [(num_spaces * " ") + line for line in s] | |||
| s = "\n".join(s) # type: ignore | |||
| s = first + "\n" + s # type: ignore | |||
| return s # type: ignore | |||
| def dump(obj: Any) -> str: | |||
| """Represent the object. | |||
| Args: | |||
| obj (Any): The obj to represent. | |||
| Returns: | |||
| str: The represented str. | |||
| """ | |||
| _repr = "" | |||
| if isinstance(obj, dict): | |||
| for k, v in obj.items(): | |||
| _repr += f"\n{k}: {_addindent(dump(v), 4)}" | |||
| elif isinstance(obj, BaseDataElement): | |||
| _repr += "\n\n META INFORMATION" | |||
| metainfo_items = dict(obj.metainfo_items()) | |||
| _repr += _addindent(dump(metainfo_items), 4) | |||
| _repr += "\n\n DATA FIELDS" | |||
| items = dict(obj.items()) | |||
| _repr += _addindent(dump(items), 4) | |||
| classname = obj.__class__.__name__ | |||
| _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" | |||
| else: | |||
| _repr += repr(obj) | |||
| return _repr | |||
| return dump(self) | |||
| @@ -0,0 +1,301 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import itertools | |||
| from collections.abc import Sized | |||
| from typing import Any, List, Union | |||
| import numpy as np | |||
| import torch | |||
| from .base_data_element import BaseDataElement | |||
| BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] | |||
| LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] | |||
| IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
| class ListData(BaseDataElement): | |||
| """Data structure for instance-level annotations or predictions. | |||
Subclass of :class:`BaseDataElement`. All values in `data_fields`
should have the same length. This design refers to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
ListData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of a value
in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
| Examples: | |||
| >>> # custom data structure | |||
| >>> class TmpObject: | |||
| ... def __init__(self, tmp) -> None: | |||
| ... assert isinstance(tmp, list) | |||
| ... self.tmp = tmp | |||
| ... def __len__(self): | |||
| ... return len(self.tmp) | |||
| ... def __getitem__(self, item): | |||
| ... if isinstance(item, int): | |||
| ... if item >= len(self) or item < -len(self): # type:ignore | |||
| ... raise IndexError(f'Index {item} out of range!') | |||
| ... else: | |||
| ... # keep the dimension | |||
| ... item = slice(item, None, len(self)) | |||
| ... return TmpObject(self.tmp[item]) | |||
| ... @staticmethod | |||
| ... def cat(tmp_objs): | |||
| ... assert all(isinstance(results, TmpObject) for results in tmp_objs) | |||
| ... if len(tmp_objs) == 1: | |||
| ... return tmp_objs[0] | |||
| ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] | |||
| ... tmp_list = list(itertools.chain(*tmp_list)) | |||
| ... new_data = TmpObject(tmp_list) | |||
| ... return new_data | |||
| ... def __repr__(self): | |||
| ... return str(self.tmp) | |||
| >>> from mmengine.structures import ListData | |||
| >>> import numpy as np | |||
| >>> import torch | |||
| >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) | |||
| >>> instance_data = ListData(metainfo=img_meta) | |||
| >>> 'img_shape' in instance_data | |||
| True | |||
| >>> instance_data.det_labels = torch.LongTensor([2, 3]) | |||
| >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) | |||
| >>> instance_data.bboxes = torch.rand((2, 4)) | |||
| >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) | |||
| >>> len(instance_data) | |||
| 2 | |||
| >>> print(instance_data) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([2, 3]) | |||
| det_scores: tensor([0.8000, 0.7000]) | |||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||
| [0.8101, 0.3105, 0.5123, 0.6263]]) | |||
| polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] | |||
| ) at 0x7fb492de6280> | |||
| >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] | |||
| >>> sorted_results.det_scores | |||
| tensor([0.7000, 0.8000]) | |||
| >>> print(instance_data[instance_data.det_scores > 0.75]) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([2]) | |||
| det_scores: tensor([0.8000]) | |||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) | |||
| polygons: [[1, 2, 3, 4]] | |||
| ) at 0x7f64ecf0ec40> | |||
| >>> print(instance_data[instance_data.det_scores > 1]) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([], dtype=torch.int64) | |||
| det_scores: tensor([]) | |||
| bboxes: tensor([], size=(0, 4)) | |||
| polygons: [] | |||
| ) at 0x7f660a6a7f70> | |||
| >>> print(instance_data.cat([instance_data, instance_data])) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([2, 3, 2, 3]) | |||
| det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) | |||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||
| [0.8101, 0.3105, 0.5123, 0.6263], | |||
| [0.4997, 0.7707, 0.0595, 0.4188], | |||
| [0.8101, 0.3105, 0.5123, 0.6263]]) | |||
| polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]] | |||
| ) at 0x7f203542feb0> | |||
| """ | |||
| def __setattr__(self, name: str, value: Sized): | |||
| """setattr is only used to set data. | |||
| The value must have a `__len__` attribute and the same length as | |||
| this `ListData`. | |||
| """ | |||
| if name in ("_metainfo_fields", "_data_fields"): | |||
| if not hasattr(self, name): | |||
| super().__setattr__(name, value) | |||
| else: | |||
| raise AttributeError( | |||
| f"{name} has been used as a " | |||
| "private attribute, which is immutable." | |||
| ) | |||
| else: | |||
| assert isinstance(value, Sized), "value must have a `__len__` attribute" | |||
| if len(self) > 0: | |||
| assert len(value) == len(self), ( | |||
| "The length of " | |||
| f"values {len(value)} is " | |||
| "not consistent with " | |||
| "the length of this " | |||
| ":obj:`ListData` " | |||
| f"{len(self)}" | |||
| ) | |||
| super().__setattr__(name, value) | |||
| __setitem__ = __setattr__ | |||
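| # Sketch: the length check above means field assignments must stay mutually | |||
| # consistent; the first field fixes len(self) for all later ones. Names here | |||
| # are hypothetical. | |||
| # >>> data = ListData() | |||
| # >>> data.labels = torch.LongTensor([0, 1]) | |||
| # >>> data.scores = torch.rand(3)  # raises AssertionError (length 3 != 2) | |||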
| def __getitem__(self, item: IndexType) -> "ListData": | |||
| """ | |||
| Args: | |||
| item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, | |||
| :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): | |||
| Get the corresponding values according to item. | |||
| Returns: | |||
| :obj:`ListData`: Corresponding values. | |||
| """ | |||
| assert isinstance(item, IndexType.__args__) | |||
| if isinstance(item, list): | |||
| item = np.array(item) | |||
| if isinstance(item, np.ndarray): | |||
| # The default int type of numpy is platform dependent, int32 for | |||
| # windows and int64 for linux. `torch.Tensor` requires the index | |||
| # should be int64, therefore we simply convert it to int64 here. | |||
| # More details in https://github.com/numpy/numpy/issues/9464 | |||
| item = item.astype(np.int64) if item.dtype == np.int32 else item | |||
| item = torch.from_numpy(item) | |||
| if isinstance(item, str): | |||
| return getattr(self, item) | |||
| if isinstance(item, int): | |||
| if item >= len(self) or item < -len(self): # type:ignore | |||
| raise IndexError(f"Index {item} out of range!") | |||
| else: | |||
| # keep the dimension | |||
| item = slice(item, None, len(self)) | |||
| new_data = self.__class__(metainfo=self.metainfo) | |||
| if isinstance(item, torch.Tensor): | |||
| assert item.dim() == 1, ( | |||
| "Only support to get the" " values along the first dimension." | |||
| ) | |||
| if isinstance(item, BoolTypeTensor.__args__): | |||
| assert len(item) == len(self), ( | |||
| "The shape of the " | |||
| "input(BoolTensor) " | |||
| f"{len(item)} " | |||
| "does not match the shape " | |||
| "of the indexed tensor " | |||
| "in results_field " | |||
| f"{len(self)} at " | |||
| "first dimension." | |||
| ) | |||
| for k, v in self.items(): | |||
| if isinstance(v, torch.Tensor): | |||
| new_data[k] = v[item] | |||
| elif isinstance(v, np.ndarray): | |||
| new_data[k] = v[item.cpu().numpy()] | |||
| elif isinstance(v, (str, list, tuple)) or ( | |||
| hasattr(v, "__getitem__") and hasattr(v, "cat") | |||
| ): | |||
| # convert to indexes from BoolTensor | |||
| if isinstance(item, BoolTypeTensor.__args__): | |||
| indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() | |||
| else: | |||
| indexes = item.cpu().numpy().tolist() | |||
| slice_list = [] | |||
| if indexes: | |||
| for index in indexes: | |||
| slice_list.append(slice(index, None, len(v))) | |||
| else: | |||
| slice_list.append(slice(None, 0, None)) | |||
| r_list = [v[s] for s in slice_list] | |||
| if isinstance(v, (str, list, tuple)): | |||
| new_value = r_list[0] | |||
| for r in r_list[1:]: | |||
| new_value = new_value + r | |||
| else: | |||
| new_value = v.cat(r_list) | |||
| new_data[k] = new_value | |||
| else: | |||
| raise ValueError( | |||
| f"The type of `{k}` is `{type(v)}`, which has no " | |||
| "attribute of `cat`, so it does not " | |||
| "support slice with `bool`" | |||
| ) | |||
| else: | |||
| # item is a slice | |||
| for k, v in self.items(): | |||
| new_data[k] = v[item] | |||
| return new_data # type:ignore | |||
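| # Sketch: integer indexing is rewritten to slice(item, None, len(self)) | |||
| # above, so a single-item lookup keeps the first dimension instead of | |||
| # returning scalars. | |||
| # >>> data = ListData() | |||
| # >>> data.labels = torch.LongTensor([2, 3]) | |||
| # >>> data.tags = ["a", "b"] | |||
| # >>> data[0].labels | |||
| # tensor([2]) | |||
| # >>> data[0].tags | |||
| # ['a'] | |||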
| @staticmethod | |||
| def cat(instances_list: List["ListData"]) -> "ListData": | |||
| """Concat the instances of all :obj:`ListData` in the list. | |||
| Note: To ensure that cat returns as expected, make sure that | |||
| all elements in the list have exactly the same keys. | |||
| Args: | |||
| instances_list (list[:obj:`ListData`]): A list | |||
| of :obj:`ListData`. | |||
| Returns: | |||
| :obj:`ListData` | |||
| """ | |||
| assert all(isinstance(results, ListData) for results in instances_list) | |||
| assert len(instances_list) > 0 | |||
| if len(instances_list) == 1: | |||
| return instances_list[0] | |||
| # metainfo and data_fields must be exactly the | |||
| # same for each element to avoid exceptions. | |||
| field_keys_list = [instances.all_keys() for instances in instances_list] | |||
| assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( | |||
| set(itertools.chain(*field_keys_list)) | |||
| ) == len(field_keys_list[0]), ( | |||
| "There are different keys in " | |||
| "`instances_list`, which may " | |||
| "cause the cat operation " | |||
| "to fail. Please make sure all " | |||
| "elements in `instances_list` " | |||
| "have the exact same key." | |||
| ) | |||
| new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) | |||
| for k in instances_list[0].keys(): | |||
| values = [results[k] for results in instances_list] | |||
| v0 = values[0] | |||
| if isinstance(v0, torch.Tensor): | |||
| new_values = torch.cat(values, dim=0) | |||
| elif isinstance(v0, np.ndarray): | |||
| new_values = np.concatenate(values, axis=0) | |||
| elif isinstance(v0, (str, list, tuple)): | |||
| new_values = v0[:] | |||
| for v in values[1:]: | |||
| new_values += v | |||
| elif hasattr(v0, "cat"): | |||
| new_values = v0.cat(values) | |||
| else: | |||
| raise ValueError( | |||
| f"The type of `{k}` is `{type(v0)}` which has no " | |||
| "attribute of `cat`" | |||
| ) | |||
| new_data[k] = new_values | |||
| return new_data # type:ignore | |||
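| # Sketch: `cat` concatenates per field type (torch.cat for tensors, `+` for | |||
| # list/str/tuple, the custom `.cat` otherwise) and requires identical keys. | |||
| # >>> a = ListData() | |||
| # >>> a.labels = torch.LongTensor([0]) | |||
| # >>> a.tags = ["x"] | |||
| # >>> b = ListData() | |||
| # >>> b.labels = torch.LongTensor([1]) | |||
| # >>> b.tags = ["y"] | |||
| # >>> merged = ListData.cat([a, b]) | |||
| # >>> merged.labels | |||
| # tensor([0, 1]) | |||
| # >>> merged.tags | |||
| # ['x', 'y'] | |||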
| def __len__(self) -> int: | |||
| """int: The length of ListData.""" | |||
| if len(self._data_fields) > 0: | |||
| return len(self.values()[0]) | |||
| else: | |||
| return 0 | |||
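| # Sketch: the length is taken from the first data field, so an empty | |||
| # ListData reports 0 until a field is assigned. | |||
| # >>> data = ListData() | |||
| # >>> len(data) | |||
| # 0 | |||
| # >>> data.labels = torch.LongTensor([1, 2, 3]) | |||
| # >>> len(data) | |||
| # 3 | |||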