| @@ -1,52 +1,64 @@ | |||||
| from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
| from typing import Any, List, Tuple | |||||
| from typing import Any, List, Tuple, Optional, Union | |||||
| from ..learning import ABLModel | from ..learning import ABLModel | ||||
| from ..reasoning import ReasonerBase | from ..reasoning import ReasonerBase | ||||
| from ..structures import ListData | |||||
# Raw dataset triple accepted by train/valid/test as an alternative to ListData.
# NOTE(review): presumably ordered as (X, gt_pseudo_label, Y), with the
# ground-truth pseudo labels allowed to be None -- confirm against callers.
DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]


class BaseBridge(metaclass=ABCMeta):
    """Abstract interface that ties a learning model to a reasoner.

    The constructor only validates and stores the two components; every
    other method is abstract and must be provided by a concrete bridge.
    """

    def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None:
        """Validate and store the learning model and the reasoner.

        Parameters
        ----------
        model : ABLModel
            The learning component; stored as ``self.model``.
        abducer : ReasonerBase
            The reasoning component; stored as ``self.abducer``.

        Raises
        ------
        TypeError
            If ``model`` is not an ``ABLModel`` or ``abducer`` is not a
            ``ReasonerBase``.
        """
        if not isinstance(model, ABLModel):
            raise TypeError(
                f"Expected an instance of ABLModel, but received type: {type(model)}"
            )
        if not isinstance(abducer, ReasonerBase):
            raise TypeError(
                "Expected an instance of ReasonerBase, "
                f"but received type: {type(abducer)}"
            )

        self.model = model
        self.abducer = abducer

    @abstractmethod
    def predict(
        self, data_samples: ListData
    ) -> Tuple[List[List[Any]], List[List[Any]]]:
        """Placeholder for predicting labels from input."""

    @abstractmethod
    def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
        """Placeholder for abducing pseudo labels."""

    @abstractmethod
    def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
        """Placeholder for mapping label space to symbol space."""

    @abstractmethod
    def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
        """Placeholder for mapping symbol space to label space."""

    @abstractmethod
    def train(self, train_data: Union[ListData, DataSet]):
        """Placeholder for the training loop of ABductive Learning."""

    @abstractmethod
    def valid(self, valid_data: Union[ListData, DataSet]) -> None:
        """Placeholder for model validation."""

    @abstractmethod
    def test(self, test_data: Union[ListData, DataSet]) -> None:
        """Placeholder for model test."""
| @@ -1,12 +1,13 @@ | |||||
| from ..learning import ABLModel | |||||
| from ..reasoning import ReasonerBase | |||||
| from ..evaluation import BaseMetric | |||||
| from .base_bridge import BaseBridge | |||||
| from typing import List, Union, Any, Tuple, Dict, Optional | from typing import List, Union, Any, Tuple, Dict, Optional | ||||
| from numpy import ndarray | from numpy import ndarray | ||||
| from torch.utils.data import DataLoader | |||||
| from ..dataset import BridgeDataset | |||||
| from .base_bridge import BaseBridge, DataSet | |||||
| from ..learning import ABLModel | |||||
| from ..reasoning import ReasonerBase | |||||
| from ..evaluation import BaseMetric | |||||
| from ..structures import ListData | |||||
| from ..utils.logger import print_log | from ..utils.logger import print_log | ||||
| @@ -20,64 +21,77 @@ class SimpleBridge(BaseBridge): | |||||
| super().__init__(model, abducer) | super().__init__(model, abducer) | ||||
| self.metric_list = metric_list | self.metric_list = metric_list | ||||
def predict(self, data_samples: ListData) -> Tuple[List[ndarray], ndarray]:
    """Run the model on ``data_samples`` and record its predictions.

    The model output's ``"label"`` entry is stored on ``data_samples`` as
    ``pred_idx`` and its ``"prob"`` entry as ``pred_prob``.

    Parameters
    ----------
    data_samples : ListData
        Samples passed unchanged to ``self.model.predict``; updated in place.

    Returns
    -------
    tuple
        ``(pred_idx, pred_prob)`` as read back from ``data_samples``.
    """
    pred_res = self.model.predict(data_samples)
    data_samples.pred_idx = pred_res["label"]
    data_samples.pred_prob = pred_res["prob"]
    # Return the stored probabilities themselves, not a list holding the
    # literal text "data_samples.pred_prob".
    return data_samples["pred_idx"], data_samples["pred_prob"]
def abduce_pseudo_label(
    self,
    data_samples: ListData,
    max_revision: int = -1,
    require_more_revision: int = 0,
) -> List[List[Any]]:
    """Abduce pseudo labels for ``data_samples`` through the reasoner.

    Parameters
    ----------
    data_samples : ListData
        Samples handed to ``self.abducer.batch_abduce``.
    max_revision : int
        Forwarded unchanged to ``batch_abduce``.
    require_more_revision : int
        Forwarded unchanged to ``batch_abduce``.

    Returns
    -------
    list
        The ``"abduced_pseudo_label"`` entry of ``data_samples``, read after
        the reasoner has run (the reasoner is expected to populate it).
    """
    reasoner = self.abducer
    reasoner.batch_abduce(data_samples, max_revision, require_more_revision)
    abduced = data_samples["abduced_pseudo_label"]
    return abduced
def idx_to_pseudo_label(
    self, data_samples: ListData, mapping: Dict = None
) -> List[List[Any]]:
    """Translate predicted indices into pseudo labels.

    Parameters
    ----------
    data_samples : ListData
        Must expose ``pred_idx`` as nested sequences of indices; receives
        the translated result as ``pred_pseudo_label``.
    mapping : dict, optional
        Index-to-pseudo-label lookup. Falls back to ``self.abducer.mapping``
        when ``None``.

    Returns
    -------
    list
        The ``"pred_pseudo_label"`` entry just stored on ``data_samples``.
    """
    lookup = self.abducer.mapping if mapping is None else mapping
    translated = []
    for row in data_samples.pred_idx:
        translated.append([lookup[index] for index in row])
    data_samples.pred_pseudo_label = translated
    return data_samples["pred_pseudo_label"]
def pseudo_label_to_idx(
    self, data_samples: ListData, mapping: Dict = None
) -> List[List[Any]]:
    """Translate abduced pseudo labels back into indices.

    Parameters
    ----------
    data_samples : ListData
        Must expose ``abduced_pseudo_label`` as nested sequences; receives
        the translated result as ``abduced_idx``.
    mapping : dict, optional
        Pseudo-label-to-index lookup. Falls back to
        ``self.abducer.remapping`` when ``None``.

    Returns
    -------
    list
        The ``"abduced_idx"`` entry just stored on ``data_samples``.
    """
    lookup = self.abducer.remapping if mapping is None else mapping
    indices = []
    for row in data_samples.abduced_pseudo_label:
        indices.append([lookup[symbol] for symbol in row])
    data_samples.abduced_idx = indices
    return data_samples["abduced_idx"]
def data_preprocess(
    self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]
) -> ListData:
    """Wrap raw inputs in a fresh ``ListData`` container.

    Parameters
    ----------
    X : list
        Stored as the ``X`` field.
    gt_pseudo_label : list
        Stored as the ``gt_pseudo_label`` field.
    Y : list
        Stored as the ``Y`` field.

    Returns
    -------
    ListData
        A new container holding the three fields.
    """
    samples = ListData()
    # Fields are assigned in the order X, gt_pseudo_label, Y.
    for field_name, value in (
        ("X", X),
        ("gt_pseudo_label", gt_pseudo_label),
        ("Y", Y),
    ):
        setattr(samples, field_name, value)
    return samples
| def train( | def train( | ||||
| self, | self, | ||||
| train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], | |||||
| train_data: DataSet, | |||||
| epochs: int = 50, | epochs: int = 50, | ||||
| batch_size: Union[int, float] = -1, | batch_size: Union[int, float] = -1, | ||||
| eval_interval: int = 1, | eval_interval: int = 1, | ||||
| ): | ): | ||||
| dataset = BridgeDataset(*train_data) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size=batch_size, | |||||
| collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||||
| ) | |||||
| data_samples = self.data_preprocess(*train_data) | |||||
| for epoch in range(epochs): | for epoch in range(epochs): | ||||
| for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||||
| pred_idx, pred_prob = self.predict(X) | |||||
| pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||||
| abduced_pseudo_label = self.abduce_pseudo_label( | |||||
| pred_prob, pred_pseudo_label, Y | |||||
| ) | |||||
| abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) | |||||
| loss = self.model.train(X, abduced_label) | |||||
| for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||||
| sub_data_samples = data_samples[ | |||||
| seg_idx * batch_size : (seg_idx + 1) * batch_size | |||||
| ] | |||||
| self.predict(sub_data_samples) | |||||
| self.idx_to_pseudo_label(sub_data_samples) | |||||
| self.abduce_pseudo_label(sub_data_samples) | |||||
| self.pseudo_label_to_idx(sub_data_samples) | |||||
| loss = self.model.train(sub_data_samples) | |||||
| print_log( | print_log( | ||||
| f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", | |||||
| f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{(len(data_samples) - 1) // batch_size + 1}] model loss is {loss:.5f}", | |||||
| logger="current", | logger="current", | ||||
| ) | ) | ||||
| @@ -85,20 +99,19 @@ class SimpleBridge(BaseBridge): | |||||
| print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | ||||
| self.valid(train_data) | self.valid(train_data) | ||||
| def _valid(self, data_loader): | |||||
| for X, Z, Y in data_loader: | |||||
| pred_idx, pred_prob = self.predict(X) | |||||
| pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||||
| data_samples = dict( | |||||
| pred_idx=pred_idx, | |||||
| pred_prob=pred_prob, | |||||
| pred_pseudo_label=pred_pseudo_label, | |||||
| gt_pseudo_label=Z, | |||||
| Y=Y, | |||||
| logic_forward=self.abducer.kb.logic_forward, | |||||
| def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: | |||||
| for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||||
| sub_data_samples = data_samples[ | |||||
| seg_idx * batch_size : (seg_idx + 1) * batch_size | |||||
| ] | |||||
| self.predict(sub_data_samples) | |||||
| self.idx_to_pseudo_label(sub_data_samples) | |||||
| sub_data_samples.set_metainfo( | |||||
| dict(logic_forward=self.abducer.kb.logic_forward) | |||||
| ) | ) | ||||
| for metric in self.metric_list: | for metric in self.metric_list: | ||||
| metric.process(data_samples) | |||||
| metric.process(sub_data_samples) | |||||
| res = dict() | res = dict() | ||||
| for metric in self.metric_list: | for metric in self.metric_list: | ||||
| @@ -108,14 +121,12 @@ class SimpleBridge(BaseBridge): | |||||
| msg += k + f": {v:.3f} " | msg += k + f": {v:.3f} " | ||||
| print_log(msg, logger="current") | print_log(msg, logger="current") | ||||
def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None:
    """Evaluate the bridge on validation data.

    Parameters
    ----------
    valid_data : ListData or DataSet
        A ready ``ListData`` is used as is; anything else is unpacked into
        ``self.data_preprocess`` first.
    batch_size : int
        Forwarded to ``self._valid``.
    """
    already_wrapped = isinstance(valid_data, ListData)
    data_samples = valid_data if already_wrapped else self.data_preprocess(*valid_data)
    self._valid(data_samples, batch_size=batch_size)


def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None:
    """Evaluate the bridge on test data by delegating to ``valid``.

    Parameters
    ----------
    test_data : ListData or DataSet
        Passed unchanged to ``self.valid``.
    batch_size : int
        Passed unchanged to ``self.valid``.
    """
    self.valid(test_data, batch_size=batch_size)
| @@ -1,8 +1,6 @@ | |||||
| from typing import Optional, Sequence | from typing import Optional, Sequence | ||||
| from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
| class ABLMetric(): | |||||
| pass | |||||
| class SemanticsMetric(BaseMetric): | class SemanticsMetric(BaseMetric): | ||||
| def __init__(self, prefix: Optional[str] = None) -> None: | def __init__(self, prefix: Optional[str] = None) -> None: | ||||
| @@ -9,9 +9,13 @@ | |||||
| # Description : | # Description : | ||||
| # | # | ||||
| # ================================================================# | # ================================================================# | ||||
| from typing import List, Any, Optional | |||||
| import pickle | import pickle | ||||
| from ..structures import ListData | |||||
| from ..utils import flatten, reform_idx | from ..utils import flatten, reform_idx | ||||
| from typing import List, Any, Optional | |||||
| class ABLModel: | class ABLModel: | ||||
| @@ -55,7 +59,7 @@ class ABLModel: | |||||
| "base_model should have fit, predict and score methods." | "base_model should have fit, predict and score methods." | ||||
| ) | ) | ||||
| def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: | |||||
| def predict(self, data_samples: ListData, mapping: Optional[dict] = None) -> dict: | |||||
| """ | """ | ||||
| Predict the labels and probabilities for the given data. | Predict the labels and probabilities for the given data. | ||||
| @@ -72,11 +76,11 @@ class ABLModel: | |||||
| A dictionary containing the predicted labels and probabilities. | A dictionary containing the predicted labels and probabilities. | ||||
| """ | """ | ||||
| model = self.classifier_list[0] | model = self.classifier_list[0] | ||||
| data_X = flatten(X) | |||||
| data_X = flatten(data_samples["X"]) | |||||
| if hasattr(model, "predict_proba"): | if hasattr(model, "predict_proba"): | ||||
| prob = model.predict_proba(X=data_X) | prob = model.predict_proba(X=data_X) | ||||
| label = prob.argmax(axis=1) | label = prob.argmax(axis=1) | ||||
| prob = reform_idx(prob, X) | |||||
| prob = reform_idx(prob, data_samples["X"]) | |||||
| else: | else: | ||||
| prob = None | prob = None | ||||
| label = model.predict(X=data_X) | label = model.predict(X=data_X) | ||||
| @@ -84,7 +88,7 @@ class ABLModel: | |||||
| if mapping is not None: | if mapping is not None: | ||||
| label = [mapping[y] for y in label] | label = [mapping[y] for y in label] | ||||
| label = reform_idx(label, X) | |||||
| label = reform_idx(label, data_samples["X"]) | |||||
| return {"label": label, "prob": prob} | return {"label": label, "prob": prob} | ||||
| @@ -109,7 +113,7 @@ class ABLModel: | |||||
| score = self.classifier_list[0].score(X=data_X, y=data_Y) | score = self.classifier_list[0].score(X=data_X, y=data_Y) | ||||
| return score | return score | ||||
| def train(self, X: List[List[Any]], Y: List[Any]) -> float: | |||||
| def train(self, data_samples: ListData) -> float: | |||||
| """ | """ | ||||
| Train the model on the given data. | Train the model on the given data. | ||||
| @@ -125,9 +129,9 @@ class ABLModel: | |||||
| float | float | ||||
| The loss value of the trained model. | The loss value of the trained model. | ||||
| """ | """ | ||||
| data_X = flatten(X) | |||||
| data_Y = flatten(Y) | |||||
| return self.classifier_list[0].fit(X=data_X, y=data_Y) | |||||
| data_X = flatten(data_samples["X"]) | |||||
| data_y = flatten(data_samples["abduced_idx"]) | |||||
| return self.classifier_list[0].fit(X=data_X, y=data_y) | |||||
| def _model_operation(self, operation: str, *args, **kwargs): | def _model_operation(self, operation: str, *args, **kwargs): | ||||
| model = self.classifier_list[0] | model = self.classifier_list[0] | ||||
| @@ -5,13 +5,21 @@ import numpy as np | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from itertools import product, combinations | from itertools import product, combinations | ||||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||||
| from ..utils.utils import ( | |||||
| flatten, | |||||
| reform_idx, | |||||
| hamming_dist, | |||||
| check_equal, | |||||
| to_hashable, | |||||
| hashable_to_list, | |||||
| ) | |||||
| from multiprocessing import Pool | from multiprocessing import Pool | ||||
| from functools import lru_cache | from functools import lru_cache | ||||
| import pyswip | import pyswip | ||||
| class KBBase(ABC): | class KBBase(ABC): | ||||
| def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | ||||
| # TODO:添加一下类型检查,比如 | # TODO:添加一下类型检查,比如 | ||||
| @@ -20,7 +28,7 @@ class KBBase(ABC): | |||||
| self.pseudo_label_list = pseudo_label_list | self.pseudo_label_list = pseudo_label_list | ||||
| self.max_err = max_err | self.max_err = max_err | ||||
| self.use_cache = use_cache | |||||
| self.use_cache = use_cache | |||||
| @abstractmethod | @abstractmethod | ||||
| def logic_forward(self, pseudo_labels): | def logic_forward(self, pseudo_labels): | ||||
| @@ -28,10 +36,17 @@ class KBBase(ABC): | |||||
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
    """Return abduction candidates, going through the cache when enabled.

    Parameters
    ----------
    pred_res
        Predicted pseudo labels to revise.
    y
        Target value handed to the search.
    max_revision_num
        Forwarded to the search routine.
    require_more_revision : int
        Forwarded to the search routine.

    Returns
    -------
    The result of ``_abduce_by_search_cache`` when ``self.use_cache`` is
    truthy, otherwise the result of ``_abduce_by_search``.
    """
    if self.use_cache:
        # The cached variant receives hashable copies of both arguments.
        return self._abduce_by_search_cache(
            to_hashable(pred_res),
            to_hashable(y),
            max_revision_num,
            require_more_revision,
        )
    return self._abduce_by_search(
        pred_res, y, max_revision_num, require_more_revision
    )
| def revise_by_idx(self, pred_res, y, revision_idx): | def revise_by_idx(self, pred_res, y, revision_idx): | ||||
| candidates = [] | candidates = [] | ||||
| abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | ||||
| @@ -52,10 +67,12 @@ class KBBase(ABC): | |||||
| new_candidates.extend(candidates) | new_candidates.extend(candidates) | ||||
| return new_candidates | return new_candidates | ||||
| def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): | |||||
| def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): | |||||
| candidates = [] | candidates = [] | ||||
| for revision_num in range(len(pred_res) + 1): | for revision_num in range(len(pred_res) + 1): | ||||
| if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): | |||||
| if revision_num == 0 and check_equal( | |||||
| self.logic_forward(pred_res), y, self.max_err | |||||
| ): | |||||
| candidates.append(pred_res) | candidates.append(pred_res) | ||||
| elif revision_num > 0: | elif revision_num > 0: | ||||
| candidates.extend(self._revision(revision_num, pred_res, y)) | candidates.extend(self._revision(revision_num, pred_res, y)) | ||||
| @@ -65,18 +82,24 @@ class KBBase(ABC): | |||||
| if revision_num >= max_revision_num: | if revision_num >= max_revision_num: | ||||
| return [] | return [] | ||||
| for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): | |||||
| for revision_num in range( | |||||
| min_revision_num + 1, min_revision_num + require_more_revision + 1 | |||||
| ): | |||||
| if revision_num > max_revision_num: | if revision_num > max_revision_num: | ||||
| return candidates | return candidates | ||||
| candidates.extend(self._revision(revision_num, pred_res, y)) | candidates.extend(self._revision(revision_num, pred_res, y)) | ||||
| return candidates | return candidates | ||||
@lru_cache(maxsize=None)
def _abduce_by_search_cache(
    self, pred_res, y, max_revision_num, require_more_revision
):
    """Memoised front end for ``_abduce_by_search``.

    ``pred_res`` and ``y`` arrive in hashable form so they can serve as cache
    keys; both go through ``hashable_to_list`` before the search runs.

    Returns
    -------
    Whatever ``_abduce_by_search`` returns for the converted arguments.
    """
    # NOTE(review): lru_cache on a method also keys on ``self`` and the cache
    # is unbounded (maxsize=None), so instances stay referenced for the
    # lifetime of the cache -- confirm this is acceptable.
    return self._abduce_by_search(
        hashable_to_list(pred_res),
        hashable_to_list(y),
        max_revision_num,
        require_more_revision,
    )
| def _dict_len(self, dic): | def _dict_len(self, dic): | ||||
| if not self.GKB_flag: | if not self.GKB_flag: | ||||
| return 0 | return 0 | ||||
| @@ -88,17 +111,18 @@ class KBBase(ABC): | |||||
| return 0 | return 0 | ||||
| else: | else: | ||||
| return sum(self._dict_len(v) for v in self.base.values()) | return sum(self._dict_len(v) for v in self.base.values()) | ||||
| class ground_KB(KBBase): | class ground_KB(KBBase): | ||||
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0):
    """Build the ground knowledge base.

    Parameters
    ----------
    pseudo_label_list
        Passed to the base-class constructor.
    GKB_len_list : optional
        Lengths for which ``_get_GKB`` generates entries. ``None`` (the
        default) is treated as "no lengths", which leaves ``self.base`` empty.
    max_err
        Passed to the base-class constructor.
    """
    super().__init__(pseudo_label_list, max_err)
    # The default None can neither be iterated by _get_GKB nor used with
    # ``in`` by _abduce_by_GKB, so normalise it to an empty list.
    self.GKB_len_list = [] if GKB_len_list is None else GKB_len_list
    self.base = {}
    X, Y = self._get_GKB()
    # Index entries as base[len(x)][y] -> list of x.
    for x, y in zip(X, Y):
        self.base.setdefault(len(x), defaultdict(list))[y].append(x)
| # For parallel version of _get_GKB | # For parallel version of _get_GKB | ||||
| def _get_XY_list(self, args): | def _get_XY_list(self, args): | ||||
| pre_x, post_x_it = args[0], args[1] | pre_x, post_x_it = args[0], args[1] | ||||
| @@ -114,6 +138,7 @@ class ground_KB(KBBase): | |||||
| def _get_GKB(self): | def _get_GKB(self): | ||||
| X, Y = [], [] | X, Y = [], [] | ||||
| for length in self.GKB_len_list: | for length in self.GKB_len_list: | ||||
| print("Generating GKB of length %d" % length) | |||||
| arg_list = [] | arg_list = [] | ||||
| for pre_x in self.pseudo_label_list: | for pre_x in self.pseudo_label_list: | ||||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | post_x_it = product(self.pseudo_label_list, repeat=length - 1) | ||||
| @@ -126,21 +151,24 @@ class ground_KB(KBBase): | |||||
| part_X, part_Y = zip(*XY_list) | part_X, part_Y = zip(*XY_list) | ||||
| X.extend(part_X) | X.extend(part_X) | ||||
| Y.extend(part_Y) | Y.extend(part_Y) | ||||
| if Y and isinstance(Y[0], (int, float)): | |||||
| if Y and isinstance(Y[0], (int, float)): | |||||
| X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | ||||
| return X, Y | return X, Y | ||||
| def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||||
| return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision) | |||||
| def _find_candidate_GKB(self, pred_res, y): | |||||
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision=0):
    """Return abduction candidates for one sample from the ground KB.

    All arguments are forwarded to ``_abduce_by_GKB`` and its result is
    returned unchanged.
    """
    candidates = self._abduce_by_GKB(
        data_sample,
        max_revision_num,
        require_more_revision=require_more_revision,
    )
    return candidates
| def _find_candidate_GKB(self, cache_key, data_sample): | |||||
| y = data_sample["Y"][0] | |||||
| if self.max_err == 0: | if self.max_err == 0: | ||||
| return self.base[len(pred_res)][y] | |||||
| return self.base[cache_key][y] | |||||
| else: | else: | ||||
| potential_candidates = self.base[len(pred_res)] | |||||
| potential_candidates = self.base[cache_key] | |||||
| key_list = list(potential_candidates.keys()) | key_list = list(potential_candidates.keys()) | ||||
| key_idx = bisect.bisect_left(key_list, y) | key_idx = bisect.bisect_left(key_list, y) | ||||
| all_candidates = [] | all_candidates = [] | ||||
| for idx in range(key_idx - 1, -1, -1): | for idx in range(key_idx - 1, -1, -1): | ||||
| k = key_list[idx] | k = key_list[idx] | ||||
| @@ -148,7 +176,7 @@ class ground_KB(KBBase): | |||||
| all_candidates.extend(potential_candidates[k]) | all_candidates.extend(potential_candidates[k]) | ||||
| else: | else: | ||||
| break | break | ||||
| for idx in range(key_idx, len(key_list)): | for idx in range(key_idx, len(key_list)): | ||||
| k = key_list[idx] | k = key_list[idx] | ||||
| if abs(k - y) <= self.max_err: | if abs(k - y) <= self.max_err: | ||||
| @@ -156,19 +184,20 @@ class ground_KB(KBBase): | |||||
| else: | else: | ||||
| break | break | ||||
| return all_candidates | return all_candidates | ||||
def _abduce_by_GKB(self, data_sample, max_revision_num, require_more_revision=0):
    """Select ground-KB candidates close enough to the prediction.

    Parameters
    ----------
    data_sample
        Indexed with ``"pred_pseudo_label"``; its first element is the
        prediction compared against the candidates.
    max_revision_num
        Upper bound on the accepted cost.
    require_more_revision : int
        Slack added on top of the smallest observed cost.

    Returns
    -------
    list
        Candidates whose cost is at most
        ``min(max_revision_num, min_cost + require_more_revision)``; empty
        when the KB is empty, the length is not covered by
        ``self.GKB_len_list``, or no candidate is found.
    """
    pred_pseudo_label = data_sample["pred_pseudo_label"][0]
    # Length of the prediction; checked against GKB_len_list and passed to
    # the candidate lookup as its key.
    cache_key = len(pred_pseudo_label)
    if not self.base or cache_key not in self.GKB_len_list:
        return []

    all_candidates = self._find_candidate_GKB(cache_key, data_sample)
    if len(all_candidates) == 0:
        return []

    cost_array = hamming_dist(pred_pseudo_label, all_candidates)
    revision_num = min(
        max_revision_num, np.min(cost_array) + require_more_revision
    )
    kept_positions = np.where(cost_array <= revision_num)[0]
    return [all_candidates[position] for position in kept_positions]
| @@ -180,33 +209,38 @@ class prolog_KB(KBBase): | |||||
| self.prolog.consult(pl_file) | self.prolog.consult(pl_file) | ||||
def logic_forward(self, pseudo_labels):
    """Evaluate ``logic_forward`` in Prolog for the given pseudo labels.

    Parameters
    ----------
    pseudo_labels
        Interpolated into the query text with ``%s``.

    Returns
    -------
    The ``Res`` binding of the first solution, with the strings ``"true"``
    and ``"false"`` converted to the corresponding booleans.
    """
    solutions = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))
    # Only the first solution is consulted.
    result = solutions[0]["Res"]
    for text, value in (("true", True), ("false", False)):
        if result == text:
            return value
    return result
def _revision_pred_res(self, pred_res, revision_idx):
    """Render ``pred_res`` as text with revised positions as Prolog variables.

    Each flattened position listed in ``revision_idx`` is replaced by the
    name ``P<idx>``; the result is reshaped like ``pred_res``, converted with
    ``str`` and the quotes around the ``P<idx>`` names are removed.

    Returns
    -------
    str
        The textual form used inside the Prolog query.
    """
    import re

    def _strip_quotes(match):
        return match.group().replace("'", "")

    flat = flatten(pred_res.copy())
    for position in revision_idx:
        flat[position] = "P" + str(position)
    reshaped = reform_idx(flat, pred_res)

    # TODO: not sure whether there is a more concise way to do this.
    return re.sub(r"'P\d+'", _strip_quotes, str(reshaped))
def get_query_string(self, pred_res, y, revision_idx):
    """Build the Prolog query text for a revision attempt.

    Parameters
    ----------
    pred_res, revision_idx
        Passed to ``self._revision_pred_res`` to produce the first argument.
    y
        Appended as the second argument unless it is ``None`` or a list
        whose first element is ``None``.

    Returns
    -------
    str
        ``"logic_forward(<revised>,<y>)."`` or ``"logic_forward(<revised>)."``.
    """
    head = "logic_forward(" + self._revision_pred_res(pred_res, revision_idx)
    key_is_none = y is None or (type(y) == list and y[0] is None)
    if key_is_none:
        tail = ")."
    else:
        tail = ",%s)." % y
    return head + tail
| def revise_by_idx(self, pred_res, y, revision_idx): | def revise_by_idx(self, pred_res, y, revision_idx): | ||||
| candidates = [] | candidates = [] | ||||
| query_string = self.get_query_string(pred_res, y, revision_idx) | query_string = self.get_query_string(pred_res, y, revision_idx) | ||||
| @@ -70,7 +70,7 @@ class ReasonerBase: | |||||
| candidates = [[self.remapping[x] for x in c] for c in candidates] | candidates = [[self.remapping[x] for x in c] for c in candidates] | ||||
| return confidence_dist(pred_prob, candidates) | return confidence_dist(pred_prob, candidates) | ||||
| def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): | |||||
| def _get_one_candidate(self, data_sample, candidates): | |||||
| """ | """ | ||||
| Get one candidate. If multiple candidates exist, return the one with minimum cost. | Get one candidate. If multiple candidates exist, return the one with minimum cost. | ||||
| @@ -94,7 +94,9 @@ class ReasonerBase: | |||||
| elif len(candidates) == 1: | elif len(candidates) == 1: | ||||
| return candidates[0] | return candidates[0] | ||||
| else: | else: | ||||
| cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) | |||||
| cost_array = self._get_cost_list( | |||||
| data_sample["pred_pseudo_label"][0], data_sample["pred_prob"][0], candidates | |||||
| ) | |||||
| candidate = candidates[np.argmin(cost_array)] | candidate = candidates[np.argmin(cost_array)] | ||||
| return candidate | return candidate | ||||
| @@ -188,9 +190,7 @@ class ReasonerBase: | |||||
| """ | """ | ||||
| return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) | return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) | ||||
| def abduce( | |||||
| self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 | |||||
| ): | |||||
| def abduce(self, data_sample, max_revision=-1, require_more_revision=0): | |||||
| """ | """ | ||||
| Perform revision by abduction on the given data. | Perform revision by abduction on the given data. | ||||
| @@ -213,26 +213,24 @@ class ReasonerBase: | |||||
| list | list | ||||
| The abduced revisions. | The abduced revisions. | ||||
| """ | """ | ||||
| symbol_num = len(flatten(pred_pseudo_label)) | |||||
| symbol_num = len(flatten(data_sample.pred_pseudo_label)) | |||||
| max_revision_num = calculate_revision_num(max_revision, symbol_num) | max_revision_num = calculate_revision_num(max_revision, symbol_num) | ||||
| if self.use_zoopt: | if self.use_zoopt: | ||||
| solution = self.zoopt_get_solution( | solution = self.zoopt_get_solution( | ||||
| symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||||
| symbol_num, data_sample, max_revision_num | |||||
| ) | ) | ||||
| revision_idx = np.where(solution != 0)[0] | revision_idx = np.where(solution != 0)[0] | ||||
| candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||||
| candidates = self.revise_by_idx(data_sample, revision_idx) | |||||
| else: | else: | ||||
| candidates = self.kb.abduce_candidates( | candidates = self.kb.abduce_candidates( | ||||
| pred_pseudo_label, y, max_revision_num, require_more_revision | |||||
| data_sample, max_revision_num, require_more_revision | |||||
| ) | ) | ||||
| candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) | |||||
| candidate = self._get_one_candidate(data_sample, candidates) | |||||
| return candidate | return candidate | ||||
def batch_abduce(self, data_samples, max_revision=-1, require_more_revision=0):
    """
    Perform abduction on the given data in batches.

    Every element obtained by iterating ``data_samples`` is passed to
    :meth:`abduce`. The collected results are stored on
    ``data_samples.abduced_pseudo_label`` and also returned.

    Parameters
    ----------
    data_samples
        Iterable container of data samples. It must also accept attribute
        assignment, because the results are attached to it.
    max_revision : optional
        Forwarded unchanged to :meth:`abduce` for every sample.
        Defaults to -1.
    require_more_revision : optional
        Forwarded unchanged to :meth:`abduce` for every sample.
        Defaults to 0.

    Returns
    -------
    list
        The abduced revisions in batches.
    """
    abduced_pseudo_label = [
        self.abduce(
            data_sample,
            max_revision=max_revision,
            require_more_revision=require_more_revision,
        )
        for data_sample in data_samples
    ]
    data_samples.abduced_pseudo_label = abduced_pseudo_label
    # Return the list as well as storing it: the documented contract is a
    # list, and without this statement callers would receive ``None``.
    return abduced_pseudo_label
| # def _batch_abduce_helper(self, args): | # def _batch_abduce_helper(self, args): | ||||
| # z, prob, y, max_revision, require_more_revision = args | # z, prob, y, max_revision, require_more_revision = args | ||||
| @@ -281,43 +280,57 @@ class ReasonerBase: | |||||
| ) | ) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| from kb import KBBase, ground_KB, prolog_KB | from kb import KBBase, ground_KB, prolog_KB | ||||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||||
| prob1 = [ | |||||
| [ | |||||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ] | |||||
| ] | |||||
| prob2 = [ | |||||
| [ | |||||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ] | |||||
| ] | |||||
| class add_KB(KBBase): | class add_KB(KBBase): | ||||
| def __init__(self, pseudo_label_list=list(range(10)), | |||||
| use_cache=True): | |||||
| def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): | |||||
| super().__init__(pseudo_label_list, use_cache=use_cache) | super().__init__(pseudo_label_list, use_cache=use_cache) | ||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| class add_ground_KB(ground_KB): | class add_ground_KB(ground_KB): | ||||
| def __init__(self, pseudo_label_list=list(range(10)), | |||||
| GKB_len_list=[2]): | |||||
| def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||||
| super().__init__(pseudo_label_list, GKB_len_list) | super().__init__(pseudo_label_list, GKB_len_list) | ||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| def test_add(reasoner): | def test_add(reasoner): | ||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
| res = reasoner.batch_abduce( | |||||
| prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
| res = reasoner.batch_abduce( | |||||
| prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||||
| res = reasoner.batch_abduce( | |||||
| prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||||
| res = reasoner.batch_abduce( | |||||
| prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||||
| res = reasoner.batch_abduce( | |||||
| prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| print() | print() | ||||
| @@ -337,8 +350,9 @@ if __name__ == "__main__": | |||||
| test_add(reasoner) | test_add(reasoner) | ||||
| print("prolog_KB with add.pl:") | print("prolog_KB with add.pl:") | ||||
| kb = prolog_KB(pseudo_label_list=list(range(10)), | |||||
| pl_file="examples/mnist_add/datasets/add.pl") | |||||
| kb = prolog_KB( | |||||
| pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" | |||||
| ) | |||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| test_add(reasoner) | test_add(reasoner) | ||||
| @@ -351,14 +365,16 @@ if __name__ == "__main__": | |||||
| test_add(reasoner) | test_add(reasoner) | ||||
| print("add_KB with multiple inputs at once:") | print("add_KB with multiple inputs at once:") | ||||
| multiple_prob = [[ | |||||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ], | |||||
| [ | |||||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ]] | |||||
| multiple_prob = [ | |||||
| [ | |||||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ], | |||||
| [ | |||||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ], | |||||
| ] | |||||
| kb = add_KB() | kb = add_KB() | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| @@ -383,8 +399,21 @@ if __name__ == "__main__": | |||||
| class HWF_KB(KBBase): | class HWF_KB(KBBase): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||||
| "+", "-", "times", "div"], | |||||
| pseudo_label_list=[ | |||||
| "1", | |||||
| "2", | |||||
| "3", | |||||
| "4", | |||||
| "5", | |||||
| "6", | |||||
| "7", | |||||
| "8", | |||||
| "9", | |||||
| "+", | |||||
| "-", | |||||
| "times", | |||||
| "div", | |||||
| ], | |||||
| max_err=1e-3, | max_err=1e-3, | ||||
| ): | ): | ||||
| super().__init__(pseudo_label_list, max_err) | super().__init__(pseudo_label_list, max_err) | ||||
| @@ -393,7 +422,17 @@ if __name__ == "__main__": | |||||
| if len(formula) % 2 == 0: | if len(formula) % 2 == 0: | ||||
| return False | return False | ||||
| for i in range(len(formula)): | for i in range(len(formula)): | ||||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
| if i % 2 == 0 and formula[i] not in [ | |||||
| "1", | |||||
| "2", | |||||
| "3", | |||||
| "4", | |||||
| "5", | |||||
| "6", | |||||
| "7", | |||||
| "8", | |||||
| "9", | |||||
| ]: | |||||
| return False | return False | ||||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | ||||
| return False | return False | ||||
| @@ -406,12 +445,25 @@ if __name__ == "__main__": | |||||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | ||||
| formula = [mapping[f] for f in formula] | formula = [mapping[f] for f in formula] | ||||
| return eval("".join(formula)) | return eval("".join(formula)) | ||||
| class HWF_ground_KB(ground_KB): | class HWF_ground_KB(ground_KB): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||||
| "+", "-", "times", "div"], | |||||
| pseudo_label_list=[ | |||||
| "1", | |||||
| "2", | |||||
| "3", | |||||
| "4", | |||||
| "5", | |||||
| "6", | |||||
| "7", | |||||
| "8", | |||||
| "9", | |||||
| "+", | |||||
| "-", | |||||
| "times", | |||||
| "div", | |||||
| ], | |||||
| GKB_len_list=[1, 3, 5, 7], | GKB_len_list=[1, 3, 5, 7], | ||||
| max_err=1e-3, | max_err=1e-3, | ||||
| ): | ): | ||||
| @@ -421,7 +473,17 @@ if __name__ == "__main__": | |||||
| if len(formula) % 2 == 0: | if len(formula) % 2 == 0: | ||||
| return False | return False | ||||
| for i in range(len(formula)): | for i in range(len(formula)): | ||||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
| if i % 2 == 0 and formula[i] not in [ | |||||
| "1", | |||||
| "2", | |||||
| "3", | |||||
| "4", | |||||
| "5", | |||||
| "6", | |||||
| "7", | |||||
| "8", | |||||
| "9", | |||||
| ]: | |||||
| return False | return False | ||||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | ||||
| return False | return False | ||||
| @@ -434,7 +496,7 @@ if __name__ == "__main__": | |||||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | ||||
| formula = [mapping[f] for f in formula] | formula = [mapping[f] for f in formula] | ||||
| return eval("".join(formula)) | return eval("".join(formula)) | ||||
| def test_hwf(reasoner): | def test_hwf(reasoner): | ||||
| res = reasoner.batch_abduce( | res = reasoner.batch_abduce( | ||||
| [None], | [None], | ||||
| @@ -461,7 +523,7 @@ if __name__ == "__main__": | |||||
| ) | ) | ||||
| print(res) | print(res) | ||||
| print() | print() | ||||
| def test_hwf_multiple(reasoner, max_revisions): | def test_hwf_multiple(reasoner, max_revisions): | ||||
| res = reasoner.batch_abduce( | res = reasoner.batch_abduce( | ||||
| [None, None], | [None, None], | ||||
| @@ -512,10 +574,10 @@ if __name__ == "__main__": | |||||
| print("HWF_KB with multiple inputs at once:") | print("HWF_KB with multiple inputs at once:") | ||||
| kb = HWF_KB(max_err=0.1) | kb = HWF_KB(max_err=0.1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| test_hwf_multiple(reasoner, max_revisions=[1,3,3]) | |||||
| test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) | |||||
| print("max_revision is float") | print("max_revision is float") | ||||
| test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) | |||||
| test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) | |||||
| class HED_prolog_KB(prolog_KB): | class HED_prolog_KB(prolog_KB): | ||||
| def __init__(self, pseudo_label_list, pl_file): | def __init__(self, pseudo_label_list, pl_file): | ||||
| @@ -0,0 +1,2 @@ | |||||
| from .base_data_element import BaseDataElement | |||||
| from .list_data import ListData | |||||
| @@ -0,0 +1,629 @@ | |||||
| # Copyright (c) OpenMMLab. All rights reserved. | |||||
| import copy | |||||
| from typing import Any, Iterator, Optional, Tuple, Type, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| class BaseDataElement: | |||||
| """A base data interface that supports Tensor-like and dict-like | |||||
| operations. | |||||
| A typical data elements refer to predicted results or ground truth labels | |||||
| on a task, such as predicted bboxes, instance masks, semantic | |||||
| segmentation masks, etc. Because groundtruth labels and predicted results | |||||
| often have similar properties (for example, the predicted bboxes and the | |||||
| groundtruth bboxes), MMEngine uses the same abstract data interface to | |||||
| encapsulate predicted results and groundtruth labels, and it is recommended | |||||
| to use different name conventions to distinguish them, such as using | |||||
| ``gt_instances`` and ``pred_instances`` to distinguish between labels and | |||||
| predicted results. Additionally, we distinguish data elements at instance | |||||
| level, pixel level, and label level. Each of these types has its own | |||||
| characteristics. Therefore, MMEngine defines the base class | |||||
| ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and | |||||
| ``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||||
| types of ground truth labels or predictions. | |||||
| Another common data element is sample data. A sample data consists of input | |||||
| data (such as an image) and its annotations and predictions. In general, | |||||
| an image can have multiple types of annotations and/or predictions at the | |||||
| same time (for example, both pixel-level semantic segmentation annotations | |||||
| and instance-level detection bboxes annotations). All labels and | |||||
| predictions of a training sample are often passed between Dataset, Model, | |||||
| Visualizer, and Evaluator components. In order to simplify the interface | |||||
| between components, we can treat them as a large data element and | |||||
| encapsulate them. Such data elements are generally called XXDataSample in | |||||
| the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` | |||||
| allows `BaseDataElement` as its attribute. Such a class generally | |||||
| encapsulates all the data of a sample in the algorithm library, and its | |||||
| attributes generally are various types of data elements. For example, | |||||
| MMDetection is assigned by the BaseDataElement to encapsulate all the data | |||||
| elements of the sample labeling and prediction of a sample in the | |||||
| algorithm library. | |||||
| The attributes in ``BaseDataElement`` are divided into two parts, | |||||
| the ``metainfo`` and the ``data`` respectively. | |||||
| - ``metainfo``: Usually contains the | |||||
| information about the image such as filename, | |||||
| image_shape, pad_shape, etc. The attributes can be accessed or | |||||
| modified by dict-like or object-like operations, such as | |||||
| ``.`` (for data access and modification), ``in``, ``del``, | |||||
| ``pop(str)``, ``get(str)``, ``metainfo_keys()``, | |||||
| ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for | |||||
| set or change key-value pairs in metainfo). | |||||
| - ``data``: Annotations or model predictions are | |||||
| stored. The attributes can be accessed or modified by | |||||
| dict-like or object-like operations, such as | |||||
| ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||||
| ``values()``, ``items()``. Users can also apply tensor-like | |||||
| methods to all :obj:`torch.Tensor` in the ``data_fields``, | |||||
| such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, | |||||
| ``to_tensor()``, ``.detach()``. | |||||
| Args: | |||||
| metainfo (dict, optional): A dict contains the meta information | |||||
| of single image, such as ``dict(img_shape=(512, 512, 3), | |||||
| scale_factor=(1, 1, 1, 1))``. Defaults to None. | |||||
| kwargs (dict, optional): A dict contains annotations of single image or | |||||
| model predictions. Defaults to None. | |||||
| Examples: | |||||
| >>> import torch | |||||
| >>> from mmengine.structures import BaseDataElement | |||||
| >>> gt_instances = BaseDataElement() | |||||
| >>> bboxes = torch.rand((5, 4)) | |||||
| >>> scores = torch.rand((5,)) | |||||
| >>> img_id = 0 | |||||
| >>> img_shape = (800, 1333) | |||||
| >>> gt_instances = BaseDataElement( | |||||
| ... metainfo=dict(img_id=img_id, img_shape=img_shape), | |||||
| ... bboxes=bboxes, scores=scores) | |||||
| >>> gt_instances = BaseDataElement( | |||||
| ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) | |||||
| >>> # new | |||||
| >>> gt_instances1 = gt_instances.new( | |||||
| ... metainfo=dict(img_id=1, img_shape=(640, 640)), | |||||
| ... bboxes=torch.rand((5, 4)), | |||||
| ... scores=torch.rand((5,))) | |||||
| >>> gt_instances2 = gt_instances1.new() | |||||
| >>> # add and process property | |||||
| >>> gt_instances = BaseDataElement() | |||||
| >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) | |||||
| >>> assert 'img_shape' in gt_instances.metainfo_keys() | |||||
| >>> assert 'img_shape' in gt_instances | |||||
| >>> assert 'img_shape' not in gt_instances.keys() | |||||
| >>> assert 'img_shape' in gt_instances.all_keys() | |||||
| >>> print(gt_instances.img_shape) | |||||
| (100, 100) | |||||
| >>> gt_instances.scores = torch.rand((5,)) | |||||
| >>> assert 'scores' in gt_instances.keys() | |||||
| >>> assert 'scores' in gt_instances | |||||
| >>> assert 'scores' in gt_instances.all_keys() | |||||
| >>> assert 'scores' not in gt_instances.metainfo_keys() | |||||
| >>> print(gt_instances.scores) | |||||
| tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) | |||||
| >>> gt_instances.bboxes = torch.rand((5, 4)) | |||||
| >>> assert 'bboxes' in gt_instances.keys() | |||||
| >>> assert 'bboxes' in gt_instances | |||||
| >>> assert 'bboxes' in gt_instances.all_keys() | |||||
| >>> assert 'bboxes' not in gt_instances.metainfo_keys() | |||||
| >>> print(gt_instances.bboxes) | |||||
| tensor([[0.0900, 0.0424, 0.1755, 0.4469], | |||||
| [0.8648, 0.0592, 0.3484, 0.0913], | |||||
| [0.5808, 0.1909, 0.6165, 0.7088], | |||||
| [0.5490, 0.4209, 0.9416, 0.2374], | |||||
| [0.3652, 0.1218, 0.8805, 0.7523]]) | |||||
| >>> # delete and change property | |||||
| >>> gt_instances = BaseDataElement( | |||||
| ... metainfo=dict(img_id=0, img_shape=(640, 640)), | |||||
| ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) | |||||
| >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) | |||||
| >>> gt_instances.img_shape # (1280, 1280) | |||||
| >>> gt_instances.bboxes = gt_instances.bboxes * 2 | |||||
| >>> gt_instances.get('img_shape', None) # (1280, 1280) | |||||
| >>> gt_instances.get('bboxes', None) # 6x4 tensor | |||||
| >>> del gt_instances.img_shape | |||||
| >>> del gt_instances.bboxes | |||||
| >>> assert 'img_shape' not in gt_instances | |||||
| >>> assert 'bboxes' not in gt_instances | |||||
| >>> gt_instances.pop('img_shape', None) # None | |||||
| >>> gt_instances.pop('bboxes', None) # None | |||||
| >>> # Tensor-like | |||||
| >>> cuda_instances = gt_instances.cuda() | |||||
| >>> cuda_instances = gt_instances.to('cuda:0') | |||||
| >>> cpu_instances = cuda_instances.cpu() | |||||
| >>> cpu_instances = cuda_instances.to('cpu') | |||||
| >>> fp16_instances = cuda_instances.to( | |||||
| ... device=None, dtype=torch.float16, non_blocking=False, | |||||
| ... copy=False, memory_format=torch.preserve_format) | |||||
| >>> cpu_instances = cuda_instances.detach() | |||||
| >>> np_instances = cpu_instances.numpy() | |||||
| >>> metainfo = dict(img_shape=(800, 1196, 3)) | |||||
| >>> gt_instances = BaseDataElement( | |||||
| ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||||
| >>> sample = BaseDataElement(metainfo=metainfo, | |||||
| ... gt_instances=gt_instances) | |||||
| >>> print(sample) | |||||
| <BaseDataElement( | |||||
| META INFORMATION | |||||
| img_shape: (800, 1196, 3) | |||||
| DATA FIELDS | |||||
| gt_instances: <BaseDataElement( | |||||
| META INFORMATION | |||||
| img_shape: (800, 1196, 3) | |||||
| DATA FIELDS | |||||
| det_labels: tensor([0, 1, 2, 3]) | |||||
| ) at 0x7f0ec5eadc70> | |||||
| ) at 0x7f0fea49e130> | |||||
| >>> # inheritance | |||||
| >>> class DetDataSample(BaseDataElement): | |||||
| ... @property | |||||
| ... def proposals(self): | |||||
| ... return self._proposals | |||||
| ... @proposals.setter | |||||
| ... def proposals(self, value): | |||||
| ... self.set_field(value, '_proposals', dtype=BaseDataElement) | |||||
| ... @proposals.deleter | |||||
| ... def proposals(self): | |||||
| ... del self._proposals | |||||
| ... @property | |||||
| ... def gt_instances(self): | |||||
| ... return self._gt_instances | |||||
| ... @gt_instances.setter | |||||
| ... def gt_instances(self, value): | |||||
| ... self.set_field(value, '_gt_instances', | |||||
| ... dtype=BaseDataElement) | |||||
| ... @gt_instances.deleter | |||||
| ... def gt_instances(self): | |||||
| ... del self._gt_instances | |||||
| ... @property | |||||
| ... def pred_instances(self): | |||||
| ... return self._pred_instances | |||||
| ... @pred_instances.setter | |||||
| ... def pred_instances(self, value): | |||||
| ... self.set_field(value, '_pred_instances', | |||||
| ... dtype=BaseDataElement) | |||||
| ... @pred_instances.deleter | |||||
| ... def pred_instances(self): | |||||
| ... del self._pred_instances | |||||
| >>> det_sample = DetDataSample() | |||||
| >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||||
| >>> det_sample.proposals = proposals | |||||
| >>> assert 'proposals' in det_sample | |||||
| >>> assert det_sample.proposals == proposals | |||||
| >>> del det_sample.proposals | |||||
| >>> assert 'proposals' not in det_sample | |||||
| >>> with self.assertRaises(AssertionError): | |||||
| ... det_sample.proposals = torch.rand((5, 4)) | |||||
| """ | |||||
| def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: | |||||
| self._metainfo_fields: set = set() | |||||
| self._data_fields: set = set() | |||||
| if metainfo is not None: | |||||
| self.set_metainfo(metainfo=metainfo) | |||||
| if kwargs: | |||||
| self.set_data(kwargs) | |||||
| def set_metainfo(self, metainfo: dict) -> None: | |||||
| """Set or change key-value pairs in ``metainfo_field`` by parameter | |||||
| ``metainfo``. | |||||
| Args: | |||||
| metainfo (dict): A dict contains the meta information | |||||
| of image, such as ``img_shape``, ``scale_factor``, etc. | |||||
| """ | |||||
| assert isinstance( | |||||
| metainfo, dict | |||||
| ), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||||
| # meta = copy.deepcopy(metainfo) | |||||
| for k, v in metainfo.items(): | |||||
| self.set_field(name=k, value=v, field_type="metainfo", dtype=None) | |||||
| def set_data(self, data: dict) -> None: | |||||
| """Set or change key-value pairs in ``data_field`` by parameter | |||||
| ``data``. | |||||
| Args: | |||||
| data (dict): A dict contains annotations of image or | |||||
| model predictions. | |||||
| """ | |||||
| assert isinstance(data, dict), f"data should be a `dict` but got {data}" | |||||
| for k, v in data.items(): | |||||
| # Use `setattr()` rather than `self.set_field` to allow `set_data` | |||||
| # to set property method. | |||||
| setattr(self, k, v) | |||||
| def update(self, instance: "BaseDataElement") -> None: | |||||
| """The update() method updates the BaseDataElement with the elements | |||||
| from another BaseDataElement object. | |||||
| Args: | |||||
| instance (BaseDataElement): Another BaseDataElement object for | |||||
| update the current object. | |||||
| """ | |||||
| assert isinstance( | |||||
| instance, BaseDataElement | |||||
| ), f"instance should be a `BaseDataElement` but got {type(instance)}" | |||||
| self.set_metainfo(dict(instance.metainfo_items())) | |||||
| self.set_data(dict(instance.items())) | |||||
| def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": | |||||
| """Return a new data element with same type. If ``metainfo`` and | |||||
| ``data`` are None, the new data element will have same metainfo and | |||||
| data. If metainfo or data is not None, the new result will overwrite it | |||||
| with the input value. | |||||
| Args: | |||||
| metainfo (dict, optional): A dict contains the meta information | |||||
| of image, such as ``img_shape``, ``scale_factor``, etc. | |||||
| Defaults to None. | |||||
| kwargs (dict): A dict contains annotations of image or | |||||
| model predictions. | |||||
| Returns: | |||||
| BaseDataElement: A new data element with same type. | |||||
| """ | |||||
| new_data = self.__class__() | |||||
| if metainfo is not None: | |||||
| new_data.set_metainfo(metainfo) | |||||
| else: | |||||
| new_data.set_metainfo(dict(self.metainfo_items())) | |||||
| if kwargs: | |||||
| new_data.set_data(kwargs) | |||||
| else: | |||||
| new_data.set_data(dict(self.items())) | |||||
| return new_data | |||||
| def clone(self): | |||||
| """Deep copy the current data element. | |||||
| Returns: | |||||
| BaseDataElement: The copy of current data element. | |||||
| """ | |||||
| clone_data = self.__class__() | |||||
| clone_data.set_metainfo(dict(self.metainfo_items())) | |||||
| clone_data.set_data(dict(self.items())) | |||||
| return clone_data | |||||
| def keys(self) -> list: | |||||
| """ | |||||
| Returns: | |||||
| list: Contains all keys in data_fields. | |||||
| """ | |||||
| # We assume that the name of the attribute related to property is | |||||
| # '_' + the name of the property. We use this rule to filter out | |||||
| # private keys. | |||||
| # TODO: Use a more robust way to solve this problem | |||||
| private_keys = { | |||||
| "_" + key | |||||
| for key in self._data_fields | |||||
| if isinstance(getattr(type(self), key, None), property) | |||||
| } | |||||
| return list(self._data_fields - private_keys) | |||||
| def metainfo_keys(self) -> list: | |||||
| """ | |||||
| Returns: | |||||
| list: Contains all keys in metainfo_fields. | |||||
| """ | |||||
| return list(self._metainfo_fields) | |||||
| def values(self) -> list: | |||||
| """ | |||||
| Returns: | |||||
| list: Contains all values in data. | |||||
| """ | |||||
| return [getattr(self, k) for k in self.keys()] | |||||
| def metainfo_values(self) -> list: | |||||
| """ | |||||
| Returns: | |||||
| list: Contains all values in metainfo. | |||||
| """ | |||||
| return [getattr(self, k) for k in self.metainfo_keys()] | |||||
| def all_keys(self) -> list: | |||||
| """ | |||||
| Returns: | |||||
| list: Contains all keys in metainfo and data. | |||||
| """ | |||||
| return self.metainfo_keys() + self.keys() | |||||
| def all_values(self) -> list: | |||||
| """ | |||||
| Returns: | |||||
| list: Contains all values in metainfo and data. | |||||
| """ | |||||
| return self.metainfo_values() + self.values() | |||||
| def all_items(self) -> Iterator[Tuple[str, Any]]: | |||||
| """ | |||||
| Returns: | |||||
| iterator: An iterator object whose element is (key, value) tuple | |||||
| pairs for ``metainfo`` and ``data``. | |||||
| """ | |||||
| for k in self.all_keys(): | |||||
| yield (k, getattr(self, k)) | |||||
| def items(self) -> Iterator[Tuple[str, Any]]: | |||||
| """ | |||||
| Returns: | |||||
| iterator: An iterator object whose element is (key, value) tuple | |||||
| pairs for ``data``. | |||||
| """ | |||||
| for k in self.keys(): | |||||
| yield (k, getattr(self, k)) | |||||
| def metainfo_items(self) -> Iterator[Tuple[str, Any]]: | |||||
| """ | |||||
| Returns: | |||||
| iterator: An iterator object whose element is (key, value) tuple | |||||
| pairs for ``metainfo``. | |||||
| """ | |||||
| for k in self.metainfo_keys(): | |||||
| yield (k, getattr(self, k)) | |||||
| @property | |||||
| def metainfo(self) -> dict: | |||||
| """dict: A dict contains metainfo of current data element.""" | |||||
| return dict(self.metainfo_items()) | |||||
| def __setattr__(self, name: str, value: Any): | |||||
| """setattr is only used to set data.""" | |||||
| if name in ("_metainfo_fields", "_data_fields"): | |||||
| if not hasattr(self, name): | |||||
| super().__setattr__(name, value) | |||||
| else: | |||||
| raise AttributeError( | |||||
| f"{name} has been used as a " | |||||
| "private attribute, which is immutable." | |||||
| ) | |||||
| else: | |||||
| self.set_field(name=name, value=value, field_type="data", dtype=None) | |||||
| def __delattr__(self, item: str): | |||||
| """Delete the item in dataelement. | |||||
| Args: | |||||
| item (str): The key to delete. | |||||
| """ | |||||
| if item in ("_metainfo_fields", "_data_fields"): | |||||
| raise AttributeError( | |||||
| f"{item} has been used as a " "private attribute, which is immutable." | |||||
| ) | |||||
| super().__delattr__(item) | |||||
| if item in self._metainfo_fields: | |||||
| self._metainfo_fields.remove(item) | |||||
| elif item in self._data_fields: | |||||
| self._data_fields.remove(item) | |||||
# dict-like methods
# ``del element[key]`` is served by the same function as ``del element.key``.
__delitem__ = __delattr__
| def get(self, key, default=None) -> Any: | |||||
| """Get property in data and metainfo as the same as python.""" | |||||
| # Use `getattr()` rather than `self.__dict__.get()` to allow getting | |||||
| # properties. | |||||
| return getattr(self, key, default) | |||||
| def pop(self, *args) -> Any: | |||||
| """Pop property in data and metainfo as the same as python.""" | |||||
| assert len(args) < 3, "``pop`` get more than 2 arguments" | |||||
| name = args[0] | |||||
| if name in self._metainfo_fields: | |||||
| self._metainfo_fields.remove(args[0]) | |||||
| return self.__dict__.pop(*args) | |||||
| elif name in self._data_fields: | |||||
| self._data_fields.remove(args[0]) | |||||
| return self.__dict__.pop(*args) | |||||
| # with default value | |||||
| elif len(args) == 2: | |||||
| return args[1] | |||||
| else: | |||||
| # don't just use 'self.__dict__.pop(*args)' for only popping key in | |||||
| # metainfo or data | |||||
| raise KeyError(f"{args[0]} is not contained in metainfo or data") | |||||
| def __contains__(self, item: str) -> bool: | |||||
| """Whether the item is in dataelement. | |||||
| Args: | |||||
| item (str): The key to inquire. | |||||
| """ | |||||
| return item in self._data_fields or item in self._metainfo_fields | |||||
| def set_field( | |||||
| self, | |||||
| value: Any, | |||||
| name: str, | |||||
| dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, | |||||
| field_type: str = "data", | |||||
| ) -> None: | |||||
| """Special method for set union field, used as property.setter | |||||
| functions.""" | |||||
| assert field_type in ["metainfo", "data"] | |||||
| if dtype is not None: | |||||
| assert isinstance( | |||||
| value, dtype | |||||
| ), f"{value} should be a {dtype} but got {type(value)}" | |||||
| if field_type == "metainfo": | |||||
| if name in self._data_fields: | |||||
| raise AttributeError( | |||||
| f"Cannot set {name} to be a field of metainfo " | |||||
| f"because {name} is already a data field" | |||||
| ) | |||||
| self._metainfo_fields.add(name) | |||||
| else: | |||||
| if name in self._metainfo_fields: | |||||
| raise AttributeError( | |||||
| f"Cannot set {name} to be a field of data " | |||||
| f"because {name} is already a metainfo field" | |||||
| ) | |||||
| self._data_fields.add(name) | |||||
| super().__setattr__(name, value) | |||||
| # Tensor-like methods | |||||
| def to(self, *args, **kwargs) -> "BaseDataElement": | |||||
| """Apply same name function to all tensors in data_fields.""" | |||||
| new_data = self.new() | |||||
| for k, v in self.items(): | |||||
| if hasattr(v, "to"): | |||||
| v = v.to(*args, **kwargs) | |||||
| data = {k: v} | |||||
| new_data.set_data(data) | |||||
| return new_data | |||||
| # Tensor-like methods | |||||
| def cpu(self) -> "BaseDataElement": | |||||
| """Convert all tensors to CPU in data.""" | |||||
| new_data = self.new() | |||||
| for k, v in self.items(): | |||||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
| v = v.cpu() | |||||
| data = {k: v} | |||||
| new_data.set_data(data) | |||||
| return new_data | |||||
| # Tensor-like methods | |||||
| def cuda(self) -> "BaseDataElement": | |||||
| """Convert all tensors to GPU in data.""" | |||||
| new_data = self.new() | |||||
| for k, v in self.items(): | |||||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
| v = v.cuda() | |||||
| data = {k: v} | |||||
| new_data.set_data(data) | |||||
| return new_data | |||||
| # Tensor-like methods | |||||
| def npu(self) -> "BaseDataElement": | |||||
| """Convert all tensors to NPU in data.""" | |||||
| new_data = self.new() | |||||
| for k, v in self.items(): | |||||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
| v = v.npu() | |||||
| data = {k: v} | |||||
| new_data.set_data(data) | |||||
| return new_data | |||||
def mlu(self) -> "BaseDataElement":
    """Return a new element whose tensor-like fields were moved with ``.mlu()``.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` instances
    are converted; each converted value is stored on the element returned
    by ``self.new()`` through ``set_data``.

    NOTE(review): ``Tensor.mlu`` presumably requires a torch build with MLU
    support -- confirm for the target environment.

    Returns:
        BaseDataElement: The element created by ``self.new()`` with the
        converted values written into it.
    """
    moved = self.new()
    for key, value in self.items():
        if not isinstance(value, (torch.Tensor, BaseDataElement)):
            continue
        moved.set_data({key: value.mlu()})
    return moved
| # Tensor-like methods | |||||
def detach(self) -> "BaseDataElement":
    """Return a new element whose tensor-like fields were detached.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` instances
    are processed; the result of ``value.detach()`` is stored on the
    element returned by ``self.new()`` through ``set_data``.

    Returns:
        BaseDataElement: The element created by ``self.new()`` with the
        detached values written into it.
    """
    detached = self.new()
    for key, value in self.items():
        if not isinstance(value, (torch.Tensor, BaseDataElement)):
            continue
        detached.set_data({key: value.detach()})
    return detached
| # Tensor-like methods | |||||
def numpy(self) -> "BaseDataElement":
    """Return a new element whose tensor-like fields were converted to numpy.

    Each ``torch.Tensor`` or ``BaseDataElement`` value is passed through
    ``detach()``, then ``cpu()``, then ``numpy()``; the result is stored on
    the element returned by ``self.new()`` through ``set_data``.

    Returns:
        BaseDataElement: The element created by ``self.new()`` with the
        converted values written into it.
    """
    as_numpy = self.new()
    for key, value in self.items():
        if not isinstance(value, (torch.Tensor, BaseDataElement)):
            continue
        # Detach first, then move to CPU, then convert -- same order as before.
        on_host = value.detach().cpu()
        as_numpy.set_data({key: on_host.numpy()})
    return as_numpy
def to_tensor(self) -> "BaseDataElement":
    """Return a new element whose ``np.ndarray`` fields became tensors.

    ``np.ndarray`` values are converted with ``torch.from_numpy`` and
    nested ``BaseDataElement`` values with their own ``to_tensor()``.
    ``set_data`` is called once per item on the element returned by
    ``self.new()``; for values of any other type it receives an empty dict.

    Returns:
        BaseDataElement: The element created by ``self.new()`` with the
        converted values written into it.
    """
    converted = self.new()
    for key, value in self.items():
        update = {}
        if isinstance(value, np.ndarray):
            update[key] = torch.from_numpy(value)
        elif isinstance(value, BaseDataElement):
            update[key] = value.to_tensor()
        converted.set_data(update)
    return converted
def to_dict(self) -> dict:
    """Build a plain dict from every pair yielded by ``self.all_items()``.

    Nested ``BaseDataElement`` values are converted recursively through
    their own ``to_dict()``; every other value is stored as-is.

    Returns:
        dict: Mapping from each key of ``all_items()`` to its (possibly
        recursively converted) value.
    """
    result = {}
    for key, value in self.all_items():
        if isinstance(value, BaseDataElement):
            result[key] = value.to_dict()
        else:
            result[key] = value
    return result
def __repr__(self) -> str:
    """Return a multi-line, human-readable description of this element."""

    def _addindent(s_: str, num_spaces: int) -> str:
        """Indent every line of ``s_`` except the first.

        Adapted from the ``_addindent`` helper in PyTorch's
        ``torch/nn/modules/module.py``.

        Args:
            s_ (str): The string to indent.
            num_spaces (int): Number of spaces prepended to each
                continuation line.

        Returns:
            str: ``s_`` unchanged when it has a single line, otherwise the
            first line followed by the indented remaining lines.
        """
        head, *tail = s_.split("\n")
        # Single-line strings are returned untouched.
        if not tail:
            return s_
        pad = num_spaces * " "
        return "\n".join([head] + [pad + line for line in tail])

    def dump(obj: Any) -> str:
        """Render ``obj`` as text.

        Args:
            obj (Any): A dict, a ``BaseDataElement`` or any other object.

        Returns:
            str: One ``key: value`` line per entry for dicts, a bracketed
            two-section block for ``BaseDataElement``, ``repr(obj)``
            otherwise.
        """
        if isinstance(obj, dict):
            return "".join(
                f"\n{k}: {_addindent(dump(v), 4)}" for k, v in obj.items()
            )
        if isinstance(obj, BaseDataElement):
            meta_text = _addindent(dump(dict(obj.metainfo_items())), 4)
            data_text = _addindent(dump(dict(obj.items())), 4)
            body = (
                "\n\n META INFORMATION"
                + meta_text
                + "\n\n DATA FIELDS"
                + data_text
            )
            classname = obj.__class__.__name__
            return f"<{classname}({body}\n) at {hex(id(obj))}>"
        return repr(obj)

    return dump(self)
| @@ -0,0 +1,301 @@ | |||||
| # Copyright (c) OpenMMLab. All rights reserved. | |||||
| import itertools | |||||
| from collections.abc import Sized | |||||
| from typing import Any, List, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| from .base_data_element import BaseDataElement | |||||
# Boolean tensor types, in both the ``torch`` and ``torch.cuda`` variants.
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
# Long (integer) tensor types, in both the ``torch`` and ``torch.cuda`` variants.
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
# Every index type that ``ListData.__getitem__`` accepts.
IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]
| # Modified from | |||||
| # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||||
| class ListData(BaseDataElement): | |||||
| """Data structure for instance-level annotations or predictions. | |||||
| Subclass of :class:`BaseDataElement`. All values in `data_fields` | |||||
| should have the same length. This design refers to | |||||
| https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 | |||||
| ListData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of a value | |||||
| in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, | |||||
| or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. | |||||
| Examples: | |||||
| >>> # custom data structure | |||||
| >>> class TmpObject: | |||||
| ... def __init__(self, tmp) -> None: | |||||
| ... assert isinstance(tmp, list) | |||||
| ... self.tmp = tmp | |||||
| ... def __len__(self): | |||||
| ... return len(self.tmp) | |||||
| ... def __getitem__(self, item): | |||||
| ... if isinstance(item, int): | |||||
| ... if item >= len(self) or item < -len(self): # type:ignore | |||||
| ... raise IndexError(f'Index {item} out of range!') | |||||
| ... else: | |||||
| ... # keep the dimension | |||||
| ... item = slice(item, None, len(self)) | |||||
| ... return TmpObject(self.tmp[item]) | |||||
| ... @staticmethod | |||||
| ... def cat(tmp_objs): | |||||
| ... assert all(isinstance(results, TmpObject) for results in tmp_objs) | |||||
| ... if len(tmp_objs) == 1: | |||||
| ... return tmp_objs[0] | |||||
| ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] | |||||
| ... tmp_list = list(itertools.chain(*tmp_list)) | |||||
| ... new_data = TmpObject(tmp_list) | |||||
| ... return new_data | |||||
| ... def __repr__(self): | |||||
| ... return str(self.tmp) | |||||
| >>> from mmengine.structures import ListData | |||||
| >>> import numpy as np | |||||
| >>> import torch | |||||
| >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) | |||||
| >>> instance_data = ListData(metainfo=img_meta) | |||||
| >>> 'img_shape' in instance_data | |||||
| True | |||||
| >>> instance_data.det_labels = torch.LongTensor([2, 3]) | |||||
| >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) | |||||
| >>> instance_data.bboxes = torch.rand((2, 4)) | |||||
| >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) | |||||
| >>> len(instance_data) | |||||
| 2 | |||||
| >>> print(instance_data) | |||||
| <ListData( | |||||
| META INFORMATION | |||||
| img_shape: (800, 1196, 3) | |||||
| pad_shape: (800, 1216, 3) | |||||
| DATA FIELDS | |||||
| det_labels: tensor([2, 3]) | |||||
| det_scores: tensor([0.8000, 0.7000]) | |||||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||||
| [0.8101, 0.3105, 0.5123, 0.6263]]) | |||||
| polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] | |||||
| ) at 0x7fb492de6280> | |||||
| >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] | |||||
| >>> sorted_results.det_scores | |||||
| tensor([0.7000, 0.8000]) | |||||
| >>> print(instance_data[instance_data.det_scores > 0.75]) | |||||
| <ListData( | |||||
| META INFORMATION | |||||
| img_shape: (800, 1196, 3) | |||||
| pad_shape: (800, 1216, 3) | |||||
| DATA FIELDS | |||||
| det_labels: tensor([2]) | |||||
| det_scores: tensor([0.8000]) | |||||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) | |||||
| polygons: [[1, 2, 3, 4]] | |||||
| ) at 0x7f64ecf0ec40> | |||||
| >>> print(instance_data[instance_data.det_scores > 1]) | |||||
| <ListData( | |||||
| META INFORMATION | |||||
| img_shape: (800, 1196, 3) | |||||
| pad_shape: (800, 1216, 3) | |||||
| DATA FIELDS | |||||
| det_labels: tensor([], dtype=torch.int64) | |||||
| det_scores: tensor([]) | |||||
| bboxes: tensor([], size=(0, 4)) | |||||
| polygons: [] | |||||
| ) at 0x7f660a6a7f70> | |||||
| >>> print(instance_data.cat([instance_data, instance_data])) | |||||
| <ListData( | |||||
| META INFORMATION | |||||
| img_shape: (800, 1196, 3) | |||||
| pad_shape: (800, 1216, 3) | |||||
| DATA FIELDS | |||||
| det_labels: tensor([2, 3, 2, 3]) | |||||
| det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) | |||||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||||
| [0.8101, 0.3105, 0.5123, 0.6263], | |||||
| [0.4997, 0.7707, 0.0595, 0.4188], | |||||
| [0.8101, 0.3105, 0.5123, 0.6263]]) | |||||
| polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]] | |||||
| ) at 0x7f203542feb0> | |||||
| """ | |||||
def __setattr__(self, name: str, value: Sized):
    """Set a data field after checking that its length is consistent.

    The two bookkeeping attributes ``_metainfo_fields`` and
    ``_data_fields`` may be assigned only while they do not exist yet.
    Any other value must be ``Sized`` and, when this object already has a
    non-zero length, must have exactly that length.

    Args:
        name (str): Attribute name to set.
        value (Sized): Value to store.

    Raises:
        AttributeError: If ``name`` is one of the bookkeeping attributes
            and it has already been set.
        AssertionError: If ``value`` is not ``Sized`` or its length differs
            from the current non-zero length of this object.
    """
    if name not in ("_metainfo_fields", "_data_fields"):
        assert isinstance(value, Sized), "value must contain `__len__` attribute"
        if len(self) > 0:
            assert len(value) == len(self), (
                f"The length of values {len(value)} is not consistent with "
                f"the length of this :obj:`ListData` {len(self)}"
            )
        super().__setattr__(name, value)
        return

    # Bookkeeping attributes: write-once only.
    if hasattr(self, name):
        raise AttributeError(
            f"{name} has been used as a private attribute, which is immutable."
        )
    super().__setattr__(name, value)


# Item assignment (``obj[name] = value``) shares the attribute-setting logic.
__setitem__ = __setattr__
def __getitem__(self, item: IndexType) -> "ListData":
    """Select a field by name or a subset of entries along the first dimension.

    Args:
        item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
            :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
            A ``str`` returns the attribute of that name. Any other index
            selects entries from every data field. Lists and arrays are
            converted to tensors first; an empty list selects nothing.

    Returns:
        :obj:`ListData`: A new object holding the selected entries, or the
        stored attribute itself when ``item`` is a ``str``.

    Raises:
        IndexError: If an ``int`` index is outside ``[-len(self), len(self))``.
        AssertionError: If ``item`` has an unsupported type, a tensor index
            is not 1-D, or a boolean mask length differs from ``len(self)``.
        ValueError: If a field cannot be indexed by a tensor, i.e. it is
            not a tensor, ndarray, str, list or tuple and lacks
            ``__getitem__`` / ``cat``.
    """
    assert isinstance(item, IndexType.__args__)
    if isinstance(item, list):
        # ``np.array([])`` defaults to float64, which is not a valid index
        # dtype for tensors or arrays; build empty indices as int64 so that
        # an empty selection works for every field type.
        item = np.array(item, dtype=np.int64) if len(item) == 0 else np.array(item)
    if isinstance(item, np.ndarray):
        # The default int type of numpy is platform dependent, int32 for
        # windows and int64 for linux. `torch.Tensor` requires the index
        # should be int64, therefore we simply convert it to int64 here.
        # More details in https://github.com/numpy/numpy/issues/9464
        item = item.astype(np.int64) if item.dtype == np.int32 else item
        item = torch.from_numpy(item)

    if isinstance(item, str):
        return getattr(self, item)

    if isinstance(item, int):
        if item >= len(self) or item < -len(self):  # type:ignore
            raise IndexError(f"Index {item} out of range!")
        else:
            # keep the dimension
            item = slice(item, None, len(self))

    new_data = self.__class__(metainfo=self.metainfo)
    if isinstance(item, torch.Tensor):
        assert item.dim() == 1, (
            "Only support to get the" " values along the first dimension."
        )
        if isinstance(item, BoolTypeTensor.__args__):
            assert len(item) == len(self), (
                "The shape of the "
                "input(BoolTensor) "
                f"{len(item)} "
                "does not match the shape "
                "of the indexed tensor "
                "in results_field "
                f"{len(self)} at "
                "first dimension."
            )

        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                new_data[k] = v[item]
            elif isinstance(v, np.ndarray):
                new_data[k] = v[item.cpu().numpy()]
            elif isinstance(v, (str, list, tuple)) or (
                hasattr(v, "__getitem__") and hasattr(v, "cat")
            ):
                # convert to indexes from BoolTensor
                if isinstance(item, BoolTypeTensor.__args__):
                    indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                else:
                    indexes = item.cpu().numpy().tolist()
                slice_list = []
                if indexes:
                    # One single-element slice per index keeps the container type.
                    for index in indexes:
                        slice_list.append(slice(index, None, len(v)))
                else:
                    # Nothing selected: take an empty slice of the same type.
                    slice_list.append(slice(None, 0, None))
                r_list = [v[s] for s in slice_list]
                if isinstance(v, (str, list, tuple)):
                    new_value = r_list[0]
                    for r in r_list[1:]:
                        new_value = new_value + r
                else:
                    new_value = v.cat(r_list)
                new_data[k] = new_value
            else:
                raise ValueError(
                    f"The type of `{k}` is `{type(v)}`, which has no "
                    "attribute of `cat`, so it does not "
                    "support slice with `bool`"
                )
    else:
        # item is a slice
        for k, v in self.items():
            new_data[k] = v[item]
    return new_data  # type:ignore
@staticmethod
def cat(instances_list: List["ListData"]) -> "ListData":
    """Concatenate several :obj:`ListData` objects field by field.

    Note: every element of the list must expose exactly the same keys.

    Args:
        instances_list (list[:obj:`ListData`]): Non-empty list of objects
            to concatenate.

    Returns:
        :obj:`ListData`: The single input object itself when the list has
        one element, otherwise a new object built with the class and
        metainfo of the first element.

    Raises:
        AssertionError: If the list is empty, contains a non-``ListData``
            element, or the elements do not share the same keys.
        ValueError: If a field is not a tensor, ndarray, str, list or
            tuple and has no ``cat`` attribute.
    """
    assert all(isinstance(results, ListData) for results in instances_list)
    assert len(instances_list) > 0
    first = instances_list[0]
    if len(instances_list) == 1:
        return first

    # metainfo and data_fields must be exactly the
    # same for each element to avoid exceptions.
    keys_per_instance = [inst.all_keys() for inst in instances_list]
    key_counts = {len(keys) for keys in keys_per_instance}
    distinct_keys = set(itertools.chain.from_iterable(keys_per_instance))
    assert len(key_counts) == 1 and len(distinct_keys) == len(
        keys_per_instance[0]
    ), (
        "There are different keys in `instances_list`, which may "
        "cause the cat operation to fail. Please make sure all "
        "elements in `instances_list` have the exact same key."
    )

    merged = first.__class__(metainfo=first.metainfo)
    for key in first.keys():
        parts = [inst[key] for inst in instances_list]
        head = parts[0]
        if isinstance(head, torch.Tensor):
            joined = torch.cat(parts, dim=0)
        elif isinstance(head, np.ndarray):
            joined = np.concatenate(parts, axis=0)
        elif isinstance(head, (str, list, tuple)):
            # Start from a copy so the first input is never mutated.
            joined = head[:]
            for part in parts[1:]:
                joined += part
        elif hasattr(head, "cat"):
            joined = head.cat(parts)
        else:
            raise ValueError(
                f"The type of `{key}` is `{type(head)}` which has no "
                "attribute of `cat`"
            )
        merged[key] = joined
    return merged  # type:ignore
def __len__(self) -> int:
    """int: Length of the first data field, or 0 when there are none."""
    if not self._data_fields:
        return 0
    # The first value returned by ``values()`` defines the length.
    return len(self.values()[0])