Add abstract data interface to bridge, dataset, evaluation and learning. (pull/1/head)
| @@ -1,52 +1,64 @@ | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, List, Tuple | |||
| from typing import Any, List, Optional, Tuple, Union | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..structures import ListData | |||
| DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] | |||
class BaseBridge(metaclass=ABCMeta):
    """
    Abstract interface connecting a learning model with a reasoner.

    Concrete bridges implement the individual steps declared below
    (prediction, label/symbol mapping, abduction) as well as the
    train / valid / test entry points.

    Parameters
    ----------
    model : ABLModel
        The learning part. Stored as ``self.model``.
    reasoner : ReasonerBase
        The reasoning part. Stored as ``self.reasoner``.

    Raises
    ------
    TypeError
        If ``model`` is not an ``ABLModel`` or ``reasoner`` is not a
        ``ReasonerBase``.
    """

    def __init__(self, model: ABLModel, reasoner: ReasonerBase) -> None:
        if not isinstance(model, ABLModel):
            raise TypeError(
                f"Expected an instance of ABLModel, but received type: {type(model)}"
            )
        if not isinstance(reasoner, ReasonerBase):
            raise TypeError(
                f"Expected an instance of ReasonerBase, but received type: {type(reasoner)}"
            )

        self.model = model
        self.reasoner = reasoner

    @abstractmethod
    def predict(
        self, data_samples: ListData
    ) -> Tuple[List[List[Any]], List[List[Any]]]:
        """Placeholder for predicting labels from input."""

    @abstractmethod
    def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
        """Placeholder for abducing pseudo labels."""

    @abstractmethod
    def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
        """Placeholder for mapping label space to symbol space."""

    @abstractmethod
    def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
        """Placeholder for mapping symbol space to label space."""

    @abstractmethod
    def train(self, train_data: Union[ListData, DataSet]):
        """Placeholder for the train loop of ABductive Learning."""

    @abstractmethod
    def valid(self, valid_data: Union[ListData, DataSet]) -> None:
        """Placeholder for model validation."""

    @abstractmethod
    def test(self, test_data: Union[ListData, DataSet]) -> None:
        """Placeholder for model test."""
| @@ -1,104 +1,119 @@ | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..evaluation import BaseMetric | |||
| from .base_bridge import BaseBridge | |||
| from typing import List, Union, Any, Tuple, Dict, Optional | |||
| import os.path as osp | |||
| from typing import Any, Dict, List, Optional, Tuple, Union | |||
| from numpy import ndarray | |||
| from torch.utils.data import DataLoader | |||
| from ..dataset import BridgeDataset | |||
| from ..utils.logger import print_log | |||
| from ..evaluation import BaseMetric | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..structures import ListData | |||
| from ..utils import print_log | |||
| from .base_bridge import BaseBridge, DataSet | |||
| class SimpleBridge(BaseBridge): | |||
| def __init__( | |||
| self, | |||
| model: ABLModel, | |||
| abducer: ReasonerBase, | |||
| reasoner: ReasonerBase, | |||
| metric_list: List[BaseMetric], | |||
| ) -> None: | |||
| super().__init__(model, abducer) | |||
| super().__init__(model, reasoner) | |||
| self.metric_list = metric_list | |||
| def predict(self, X) -> Tuple[List[List[Any]], ndarray]: | |||
| pred_res = self.model.predict(X) | |||
| pred_idx, pred_prob = pred_res["label"], pred_res["prob"] | |||
| return pred_idx, pred_prob | |||
| # TODO: add reasoner.mapping to the property of SimpleBridge | |||
| def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
| self.model.predict(data_samples) | |||
| return data_samples.pred_idx, data_samples.pred_prob | |||
| def abduce_pseudo_label( | |||
| self, | |||
| pred_prob: ndarray, | |||
| pred_pseudo_label: List[List[Any]], | |||
| Y: List[Any], | |||
| data_samples: ListData, | |||
| max_revision: int = -1, | |||
| require_more_revision: int = 0, | |||
| ) -> List[List[Any]]: | |||
| return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) | |||
| self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision) | |||
| return data_samples.abduced_pseudo_label | |||
| def idx_to_pseudo_label( | |||
| self, idx: List[List[Any]], mapping: Dict = None | |||
| self, data_samples: ListData, mapping: Optional[Dict] = None | |||
| ) -> List[List[Any]]: | |||
| if mapping is None: | |||
| mapping = self.abducer.mapping | |||
| return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] | |||
| mapping = self.reasoner.mapping | |||
| pred_idx = data_samples.pred_idx | |||
| data_samples.pred_pseudo_label = [ | |||
| [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx | |||
| ] | |||
| return data_samples.pred_pseudo_label | |||
| def pseudo_label_to_idx( | |||
| self, pseudo_label: List[List[Any]], mapping: Dict = None | |||
| self, data_samples: ListData, mapping: Optional[Dict] = None | |||
| ) -> List[List[Any]]: | |||
| if mapping is None: | |||
| mapping = self.abducer.remapping | |||
| return [ | |||
| [mapping[_pseudo_label] for _pseudo_label in sub_list] | |||
| for sub_list in pseudo_label | |||
| mapping = self.reasoner.remapping | |||
| abduced_idx = [ | |||
| [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] | |||
| for sub_list in data_samples.abduced_pseudo_label | |||
| ] | |||
| data_samples.abduced_idx = abduced_idx | |||
| return data_samples.abduced_idx | |||
| def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: | |||
| data_samples = ListData() | |||
| data_samples.X = X | |||
| data_samples.gt_pseudo_label = gt_pseudo_label | |||
| data_samples.Y = Y | |||
| return data_samples | |||
| def train( | |||
| self, | |||
| train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], | |||
| epochs: int = 50, | |||
| batch_size: Union[int, float] = -1, | |||
| train_data: Union[ListData, DataSet], | |||
| loops: int = 50, | |||
| segment_size: Union[int, float] = -1, | |||
| eval_interval: int = 1, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| ): | |||
| dataset = BridgeDataset(*train_data) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=batch_size, | |||
| collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
| ) | |||
| for epoch in range(epochs): | |||
| for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||
| pred_idx, pred_prob = self.predict(X) | |||
| pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||
| abduced_pseudo_label = self.abduce_pseudo_label( | |||
| pred_prob, pred_pseudo_label, Y | |||
| ) | |||
| abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) | |||
| loss = self.model.train(X, abduced_label) | |||
| if isinstance(train_data, ListData): | |||
| data_samples = train_data | |||
| else: | |||
| data_samples = self.data_preprocess(*train_data) | |||
| for loop in range(loops): | |||
| for seg_idx in range((len(data_samples) - 1) // segment_size + 1): | |||
| sub_data_samples = data_samples[ | |||
| seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
| ] | |||
| self.predict(sub_data_samples) | |||
| self.idx_to_pseudo_label(sub_data_samples) | |||
| self.abduce_pseudo_label(sub_data_samples) | |||
| self.pseudo_label_to_idx(sub_data_samples) | |||
| loss = self.model.train(sub_data_samples) | |||
| print_log( | |||
| f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", | |||
| f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", | |||
| logger="current", | |||
| ) | |||
| if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: | |||
| print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | |||
| if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||
| print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") | |||
| self.valid(train_data) | |||
| def _valid(self, data_loader): | |||
| for X, Z, Y in data_loader: | |||
| pred_idx, pred_prob = self.predict(X) | |||
| pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||
| data_samples = dict( | |||
| pred_idx=pred_idx, | |||
| pred_prob=pred_prob, | |||
| pred_pseudo_label=pred_pseudo_label, | |||
| gt_pseudo_label=Z, | |||
| Y=Y, | |||
| logic_forward=self.abducer.kb.logic_forward, | |||
| ) | |||
| for metric in self.metric_list: | |||
| metric.process(data_samples) | |||
| if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||
| print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||
| self.model.save( | |||
| save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") | |||
| ) | |||
| def _valid(self, data_samples: ListData) -> None: | |||
| self.predict(data_samples) | |||
| self.idx_to_pseudo_label(data_samples) | |||
| for metric in self.metric_list: | |||
| metric.process(data_samples) | |||
| res = dict() | |||
| for metric in self.metric_list: | |||
| @@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge): | |||
| msg += k + f": {v:.3f} " | |||
| print_log(msg, logger="current") | |||
| def valid(self, valid_data, batch_size=1000): | |||
| dataset = BridgeDataset(*valid_data) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=batch_size, | |||
| collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
| ) | |||
| self._valid(data_loader) | |||
| def test(self, test_data, batch_size=1000): | |||
| self.valid(test_data, batch_size) | |||
| def valid(self, valid_data: Union[ListData, DataSet]) -> None: | |||
| if not isinstance(valid_data, ListData): | |||
| data_samples = self.data_preprocess(*valid_data) | |||
| else: | |||
| data_samples = valid_data | |||
| self._valid(data_samples) | |||
| def test(self, test_data: Union[ListData, DataSet]) -> None: | |||
| self.valid(test_data) | |||
| @@ -1,3 +1,4 @@ | |||
| from .bridge_dataset import BridgeDataset | |||
| from .classification_dataset import ClassificationDataset | |||
| from .regression_dataset import RegressionDataset | |||
| from .prediction_dataset import PredictionDataset | |||
| from .regression_dataset import RegressionDataset | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import Any, List, Tuple | |||
| from torch.utils.data import Dataset | |||
| from typing import List, Any, Tuple | |||
| class BridgeDataset(Dataset): | |||
| @@ -1,6 +1,7 @@ | |||
| from typing import Any, Callable, List, Tuple | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| from typing import List, Any, Tuple, Callable | |||
| class ClassificationDataset(Dataset): | |||
| @@ -0,0 +1,56 @@ | |||
| from typing import Any, Callable, List, Tuple | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
class PredictionDataset(Dataset):
    """
    Label-free dataset that serves input samples for prediction.

    Each item is a single (optionally transformed) element of ``X``;
    no target is attached.
    """

    def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
        """
        Initialize the dataset used for prediction.

        Parameters
        ----------
        X : List[Any]
            The input data.
        transform : Callable[..., Any], optional
            A function/transform that takes in an object and returns a
            transformed version. Defaults to None (samples are returned as is).

        Raises
        ------
        ValueError
            If ``X`` is not a list.
        """
        if not isinstance(X, list):
            raise ValueError("X should be of type list.")

        self.X = X
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The number of samples in ``X``.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Any:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get.

        Returns
        -------
        Any
            The sample at ``index``, passed through ``transform`` when one
            was supplied. No label is returned.

        Raises
        ------
        ValueError
            If ``index`` is greater than or equal to the dataset length.
        """
        if index >= len(self):
            raise ValueError("index range error")

        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)
        return x
| @@ -1,6 +1,7 @@ | |||
| from typing import Any, List, Tuple | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| from typing import List, Any, Tuple | |||
| class RegressionDataset(Dataset): | |||
| @@ -1,3 +1,3 @@ | |||
| from .base_metric import BaseMetric | |||
| from .symbol_metric import SymbolMetric | |||
| from .semantics_metric import SemanticsMetric | |||
| from .symbol_metric import SymbolMetric | |||
| @@ -1,8 +1,8 @@ | |||
| import logging | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, List, Optional, Sequence | |||
| from ..utils import print_log | |||
| import logging | |||
| from ..utils import print_log | |||
| class BaseMetric(metaclass=ABCMeta): | |||
| @@ -1,25 +1,24 @@ | |||
| from typing import Optional, Sequence | |||
| from ..reasoning import KBBase | |||
| from .base_metric import BaseMetric | |||
| class ABLMetric(): | |||
| pass | |||
class SemanticsMetric(BaseMetric):
    """
    Metric measuring how often the predicted pseudo labels, once passed
    through the knowledge base's ``logic_forward``, agree with the target ``Y``.

    Parameters
    ----------
    kb : KBBase, optional
        Knowledge base providing ``logic_forward`` and ``_check_equal``.
        It may be omitted at construction time but must be set before
        ``process`` is called.
    prefix : Optional[str], optional
        Forwarded unchanged to ``BaseMetric.__init__``.
    """

    def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)
        self.kb = kb

    def process(self, data_samples: Sequence[dict]) -> None:
        """
        Append one 0/1 entry to ``self.results`` per sample.

        Parameters
        ----------
        data_samples
            Object exposing ``pred_pseudo_label`` and ``Y`` as attributes.
            NOTE(review): the ``Sequence[dict]`` annotation predates the
            attribute-style access used here — confirm and update.

        Raises
        ------
        ValueError
            If no knowledge base was supplied.
        """
        kb = self.kb
        if kb is None:
            # Fail with an explicit message instead of an AttributeError on None.
            raise ValueError("SemanticsMetric requires a knowledge base (kb) to process data.")

        for pred_pseudo_label, y in zip(data_samples.pred_pseudo_label, data_samples.Y):
            consistent = kb._check_equal(kb.logic_forward(pred_pseudo_label), y)
            self.results.append(1 if consistent else 0)

    def compute_metrics(self, results: list) -> dict:
        """
        Compute the fraction of samples whose logical result matched.

        Parameters
        ----------
        results : list
            The 0/1 entries gathered by ``process``.

        Returns
        -------
        dict
            ``{"semantics_accuracy": value}``; the value is 0.0 when
            ``results`` is empty (avoids a ZeroDivisionError).
        """
        accuracy = sum(results) / len(results) if results else 0.0
        return {"semantics_accuracy": accuracy}
| @@ -1,4 +1,5 @@ | |||
| from typing import Optional, Sequence, Callable | |||
| from typing import Optional, Sequence | |||
| from .base_metric import BaseMetric | |||
| @@ -7,9 +8,9 @@ class SymbolMetric(BaseMetric): | |||
| super().__init__(prefix) | |||
| def process(self, data_samples: Sequence[dict]) -> None: | |||
| pred_pseudo_label = data_samples["pred_pseudo_label"] | |||
| pred_pseudo_label = data_samples.pred_pseudo_label | |||
| gt_pseudo_label = data_samples["gt_pseudo_label"] | |||
| gt_pseudo_label = data_samples.gt_pseudo_label | |||
| if not len(pred_pseudo_label) == len(gt_pseudo_label): | |||
| raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal") | |||
| @@ -10,8 +10,10 @@ | |||
| # | |||
| # ================================================================# | |||
| import pickle | |||
| from utils import flatten, reform_idx | |||
| from typing import List, Any, Optional | |||
| from typing import Any, Dict | |||
| from ..structures import ListData | |||
| from ..utils import reform_list | |||
| class ABLModel: | |||
| @@ -30,7 +32,7 @@ class ABLModel: | |||
| Methods | |||
| ------- | |||
| predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict | |||
| predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict | |||
| Predict the labels and probabilities for the given data. | |||
| valid(X: List[List[Any]], Y: List[Any]) -> float | |||
| Calculate the accuracy score for the given data. | |||
| @@ -42,20 +44,13 @@ class ABLModel: | |||
| Load the model from a file. | |||
| """ | |||
| def __init__(self, base_model) -> None: | |||
| self.classifier_list = [] | |||
| self.classifier_list.append(base_model) | |||
| def __init__(self, base_model: Any) -> None: | |||
| if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): | |||
| raise NotImplementedError("The base_model should implement fit and predict methods.") | |||
| if not ( | |||
| hasattr(base_model, "fit") | |||
| and hasattr(base_model, "predict") | |||
| and hasattr(base_model, "score") | |||
| ): | |||
| raise NotImplementedError( | |||
| "base_model should have fit, predict and score methods." | |||
| ) | |||
| self.base_model = base_model | |||
| def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: | |||
| def predict(self, data_samples: ListData) -> Dict: | |||
| """ | |||
| Predict the labels and probabilities for the given data. | |||
| @@ -63,53 +58,29 @@ class ABLModel: | |||
| ---------- | |||
| X : List[List[Any]] | |||
| The data to predict on. | |||
| mapping : Optional[dict], optional | |||
| A mapping dictionary to map labels to their original values, by default None. | |||
| Returns | |||
| ------- | |||
| dict | |||
| A dictionary containing the predicted labels and probabilities. | |||
| """ | |||
| model = self.classifier_list[0] | |||
| data_X = flatten(X) | |||
| model = self.base_model | |||
| data_X = data_samples.flatten("X") | |||
| if hasattr(model, "predict_proba"): | |||
| prob = model.predict_proba(X=data_X) | |||
| label = prob.argmax(axis=1) | |||
| prob = reform_idx(prob, X) | |||
| prob = reform_list(prob, data_samples.X) | |||
| else: | |||
| prob = None | |||
| label = model.predict(X=data_X) | |||
| label = reform_list(label, data_samples.X) | |||
| if mapping is not None: | |||
| label = [mapping[y] for y in label] | |||
| label = reform_idx(label, X) | |||
| data_samples.pred_idx = label | |||
| data_samples.pred_prob = prob | |||
| return {"label": label, "prob": prob} | |||
| def valid(self, X: List[List[Any]], Y: List[Any]) -> float: | |||
| """ | |||
| Calculate the accuracy for the given data. | |||
| Parameters | |||
| ---------- | |||
| X : List[List[Any]] | |||
| The data to calculate the accuracy on. | |||
| Y : List[Any] | |||
| The true labels for the given data. | |||
| Returns | |||
| ------- | |||
| float | |||
| The accuracy score for the given data. | |||
| """ | |||
| data_X = flatten(X) | |||
| data_Y = flatten(Y) | |||
| score = self.classifier_list[0].score(X=data_X, y=data_Y) | |||
| return score | |||
| def train(self, X: List[List[Any]], Y: List[Any]) -> float: | |||
| def train(self, data_samples: ListData) -> float: | |||
| """ | |||
| Train the model on the given data. | |||
| @@ -125,29 +96,30 @@ class ABLModel: | |||
| float | |||
| The loss value of the trained model. | |||
| """ | |||
| data_X = flatten(X) | |||
| data_Y = flatten(Y) | |||
| return self.classifier_list[0].fit(X=data_X, y=data_Y) | |||
| data_X = data_samples.flatten("X") | |||
| data_y = data_samples.flatten("abduced_idx") | |||
| return self.base_model.fit(X=data_X, y=data_y) | |||
def _model_operation(self, operation: str, *args, **kwargs):
    """
    Run ``operation`` on the base model, falling back to pickle.

    If the base model has an attribute named ``operation`` it is called with
    the given arguments. Otherwise ``"save"`` pickles the base model to
    ``kwargs["save_path"]`` and ``"load"`` replaces ``self.base_model`` with
    the object unpickled from ``kwargs["load_path"]``.

    Parameters
    ----------
    operation : str
        Name of the operation; the fallback handles "save" and "load".
    *args, **kwargs
        Forwarded to the base model's own method when it exists. For the
        fallback, ``kwargs`` must contain ``f"{operation}_path"``.

    Raises
    ------
    ValueError
        If the fallback is needed and ``f"{operation}_path"`` is not given.
    NotImplementedError
        If the pickle-based fallback fails; the underlying error is chained
        as ``__cause__``.
    """
    model = self.base_model
    if hasattr(model, operation):
        getattr(model, operation)(*args, **kwargs)
        return

    if f"{operation}_path" not in kwargs:
        raise ValueError(f"'{operation}_path' should not be None")

    try:
        if operation == "save":
            with open(kwargs["save_path"], "wb") as file:
                pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
        elif operation == "load":
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # checkpoints from trusted sources.
            with open(kwargs["load_path"], "rb") as file:
                self.base_model = pickle.load(file)
    except Exception as err:
        # Catch Exception only, so KeyboardInterrupt/SystemExit still propagate,
        # and chain the cause so the real failure stays visible.
        raise NotImplementedError(
            f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
        ) from err
| def save(self, *args, **kwargs) -> None: | |||
| """ | |||
| @@ -10,14 +10,16 @@ | |||
| # | |||
| # ================================================================# | |||
| import torch | |||
| import os | |||
| import logging | |||
| from typing import Any, Callable, List, Optional, T, Tuple | |||
| import numpy | |||
| import torch | |||
| from torch.utils.data import DataLoader | |||
| from ..utils.logger import print_log | |||
| from ..dataset import ClassificationDataset | |||
| import os | |||
| from typing import List, Any, T, Optional, Callable, Tuple | |||
| from ..dataset import ClassificationDataset, PredictionDataset | |||
| from ..utils.logger import print_log | |||
| class BasicNN: | |||
| @@ -64,7 +66,8 @@ class BasicNN: | |||
| num_workers: int = 0, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| transform: Callable[..., Any] = None, | |||
| train_transform: Callable[..., Any] = None, | |||
| test_transform: Callable[..., Any] = None, | |||
| collate_fn: Callable[[List[T]], Any] = None, | |||
| ) -> None: | |||
| self.model = model.to(device) | |||
| @@ -77,10 +80,19 @@ class BasicNN: | |||
| self.num_workers = num_workers | |||
| self.save_interval = save_interval | |||
| self.save_dir = save_dir | |||
| self.transform = transform | |||
| self.train_transform = train_transform | |||
| self.test_transform = test_transform | |||
| self.collate_fn = collate_fn | |||
| def _fit(self, data_loader) -> float: | |||
| if self.train_transform is not None and self.test_transform is None: | |||
| print_log( | |||
| "Transform used in the training phase will be used in prediction.", | |||
| "current", | |||
| level=logging.WARNING, | |||
| ) | |||
| self.test_transform = self.train_transform | |||
| def _fit(self, data_loader: DataLoader) -> float: | |||
| """ | |||
| Internal method to fit the model on data for n epochs, with early stopping. | |||
| @@ -99,9 +111,7 @@ class BasicNN: | |||
| loss_value = self.train_epoch(data_loader) | |||
| if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: | |||
| if self.save_dir is None: | |||
| raise ValueError( | |||
| "save_dir should not be None if save_interval is not None." | |||
| ) | |||
| raise ValueError("save_dir should not be None if save_interval is not None.") | |||
| self.save(epoch + 1) | |||
| if self.stop_loss is not None and loss_value < self.stop_loss: | |||
| break | |||
| @@ -170,7 +180,7 @@ class BasicNN: | |||
| return total_loss / total_num | |||
| def _predict(self, data_loader) -> torch.Tensor: | |||
| def _predict(self, data_loader: DataLoader) -> torch.Tensor: | |||
| """ | |||
| Internal method to predict the outputs given a DataLoader. | |||
| @@ -191,16 +201,14 @@ class BasicNN: | |||
| with torch.no_grad(): | |||
| results = [] | |||
| for data, _ in data_loader: | |||
| for data in data_loader: | |||
| data = data.to(device) | |||
| out = model(data) | |||
| results.append(out) | |||
| return torch.cat(results, axis=0) | |||
| def predict( | |||
| self, data_loader: DataLoader = None, X: List[Any] = None | |||
| ) -> numpy.ndarray: | |||
| def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: | |||
| """ | |||
| Predict the class of the input data. | |||
| @@ -218,12 +226,16 @@ class BasicNN: | |||
| """ | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X) | |||
| dataset = PredictionDataset(X, self.test_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return self._predict(data_loader).argmax(axis=1).cpu().numpy() | |||
| def predict_proba( | |||
| self, data_loader: DataLoader = None, X: List[Any] = None | |||
| ) -> numpy.ndarray: | |||
| def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: | |||
| """ | |||
| Predict the probability of each class for the input data. | |||
| @@ -241,10 +253,16 @@ class BasicNN: | |||
| """ | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X) | |||
| dataset = PredictionDataset(X, self.test_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return self._predict(data_loader).softmax(axis=1).cpu().numpy() | |||
| def _score(self, data_loader) -> Tuple[float, float]: | |||
| def _score(self, data_loader: DataLoader) -> Tuple[float, float]: | |||
| """ | |||
| Internal method to compute loss and accuracy for the data provided through a DataLoader. | |||
| @@ -313,16 +331,10 @@ class BasicNN: | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X, y) | |||
| mean_loss, accuracy = self._score(data_loader) | |||
| print_log( | |||
| f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" | |||
| ) | |||
| print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") | |||
| return accuracy | |||
| def _data_loader( | |||
| self, | |||
| X: List[Any], | |||
| y: List[int] = None, | |||
| ) -> DataLoader: | |||
| def _data_loader(self, X: List[Any], y: List[int] = None, shuffle: bool = True) -> DataLoader: | |||
| """ | |||
| Generate a DataLoader for user-provided input and target data. | |||
| @@ -346,11 +358,11 @@ class BasicNN: | |||
| if not (len(y) == len(X)): | |||
| raise ValueError("X and y should have equal length.") | |||
| dataset = ClassificationDataset(X, y, transform=self.transform) | |||
| dataset = ClassificationDataset(X, y, transform=self.train_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| shuffle=True, | |||
| shuffle=shuffle, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| @@ -368,14 +380,13 @@ class BasicNN: | |||
| The path to save the model, by default None. | |||
| """ | |||
| if self.save_dir is None and save_path is None: | |||
| raise ValueError( | |||
| "'save_dir' and 'save_path' should not be None simultaneously." | |||
| ) | |||
| raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.") | |||
| if save_path is None: | |||
| save_path = os.path.join( | |||
| self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth" | |||
| ) | |||
| if save_path is not None: | |||
| if not os.path.exists(os.path.dirname(save_path)): | |||
| os.makedirs(os.path.dirname(save_path)) | |||
| else: | |||
| save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth") | |||
| if not os.path.exists(self.save_dir): | |||
| os.makedirs(self.save_dir) | |||
| @@ -1,2 +1,2 @@ | |||
| from .reasoner import ReasonerBase | |||
| from .kb import KBBase, prolog_KB | |||
| from .kb import KBBase, GroundKB, PrologKB | |||
| from .reasoner import ReasonerBase | |||
| @@ -9,7 +9,8 @@ from functools import lru_cache | |||
| import numpy as np | |||
| import pyswip | |||
| from abl.utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable | |||
| from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable, restore_from_hashable | |||
| from ..utils.cache import abl_cache | |||
| class KBBase(ABC): | |||
| @@ -21,35 +22,46 @@ class KBBase(ABC): | |||
| pseudo_label_list : list | |||
| List of possible pseudo labels. | |||
| max_err : float, optional | |||
| The upper tolerance limit when comparing the similarity between a candidate's logical | |||
| result. This is only applicable when the logical result is of a numerical type. | |||
| This is particularly relevant for regression problems where exact matches might not be | |||
| feasible. Defaults to 1e-10. | |||
| The upper tolerance limit when comparing the similarity between a candidate's logical | |||
| result. This is only applicable when the logical result is of a numerical type. | |||
| This is particularly relevant for regression problems where exact matches might not be | |||
| feasible. Defaults to 1e-10. | |||
| use_cache : bool, optional | |||
| Whether to use a cache for previously abduced candidates to speed up subsequent | |||
| Whether to use a cache for previously abduced candidates to speed up subsequent | |||
| operations. Defaults to True. | |||
| Notes | |||
| ----- | |||
| Users should inherit from this base class to build their own knowledge base. For the | |||
| user-build KB (an inherited subclass), it's only required for the user to provide the | |||
| `pseudo_label_list` and override the `logic_forward` function (specifying how to | |||
| perform logical reasoning). After that, other operations (e.g. how to perform abductive | |||
| reasoning) will be automatically set up. | |||
| Users should inherit from this base class to build their own knowledge base. For the | |||
| user-build KB (an inherited subclass), it's only required for the user to provide the | |||
| `pseudo_label_list` and override the `logic_forward` function (specifying how to | |||
| perform logical reasoning). After that, other operations (e.g. how to perform abductive | |||
| reasoning) will be automatically set up. | |||
| """ | |||
| def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list, | |||
| max_err=1e-10, | |||
| use_cache=True, | |||
| key_func=to_hashable, | |||
| max_cache_size=4096, | |||
| ): | |||
| if not isinstance(pseudo_label_list, list): | |||
| raise TypeError("pseudo_label_list should be list") | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.max_err = max_err | |||
| self.use_cache = use_cache | |||
| self.use_cache = use_cache | |||
| self.key_func = key_func | |||
| self.max_cache_size = max_cache_size | |||
| @abstractmethod | |||
| def logic_forward(self, pseudo_label): | |||
| """ | |||
| How to perform (deductive) logical reasoning, i.e. matching each pseudo label to | |||
| How to perform (deductive) logical reasoning, i.e. matching each pseudo label to | |||
| their logical result. Users are required to provide this. | |||
| Parameters | |||
| ---------- | |||
| pred_pseudo_label : List[Any] | |||
| @@ -70,23 +82,22 @@ class KBBase(ABC): | |||
| max_revision_num : int | |||
| The upper limit on the number of revisions. | |||
| require_more_revision : int, optional | |||
| Specifies additional number of revisions permitted beyond the minimum required. | |||
| Specifies additional number of revisions permitted beyond the minimum required. | |||
| Defaults to 0. | |||
| Returns | |||
| ------- | |||
| List[List[Any]] | |||
| A list of candidates, i.e. revised pseudo labels that are consistent with the | |||
| A list of candidates, i.e. revised pseudo labels that are consistent with the | |||
| knowledge base. | |||
| """ | |||
| if self.use_cache: | |||
| return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), | |||
| to_hashable(y), | |||
| max_revision_num, require_more_revision) | |||
| else: | |||
| return self._abduce_by_search(pred_pseudo_label, y, | |||
| max_revision_num, require_more_revision) | |||
| # if self.use_cache: | |||
| # return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), | |||
| # to_hashable(y), | |||
| # max_revision_num, require_more_revision) | |||
| # else: | |||
| return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) | |||
| def _check_equal(self, logic_result, y): | |||
| """ | |||
| Check whether the logical result of a candidate is equal to the ground truth | |||
| @@ -94,12 +105,12 @@ class KBBase(ABC): | |||
| """ | |||
| if logic_result == None: | |||
| return False | |||
| if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): | |||
| return abs(logic_result - y) <= self.max_err | |||
| else: | |||
| return logic_result == y | |||
| def revise_at_idx(self, pred_pseudo_label, y, revision_idx): | |||
| """ | |||
| Revise the predicted pseudo label at specified index positions. | |||
| @@ -125,7 +136,7 @@ class KBBase(ABC): | |||
| def _revision(self, revision_num, pred_pseudo_label, y): | |||
| """ | |||
| For a specified number of pseudo labels to revise, iterate through all possible | |||
| For a specified number of pseudo labels to revise, iterate through all possible | |||
| indices to find any candidates that are consistent with the knowledge base. | |||
| """ | |||
| new_candidates = [] | |||
| @@ -136,12 +147,13 @@ class KBBase(ABC): | |||
| new_candidates.extend(candidates) | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): | |||
| @abl_cache() | |||
| def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): | |||
| """ | |||
| Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and | |||
| continuously increase the number of pseudo labels to revise, until candidates | |||
| Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and | |||
| continuously increase the number of pseudo labels to revise, until candidates | |||
| that are consistent with the knowledge base are found. | |||
| Parameters | |||
| ---------- | |||
| pred_pseudo_label : List[Any] | |||
| @@ -151,16 +163,16 @@ class KBBase(ABC): | |||
| max_revision_num : int | |||
| The upper limit on the number of revisions. | |||
| require_more_revision : int | |||
| If larger than 0, then after having found any candidates consistent with the | |||
| knowledge base, continue to increase the number of pseudo labels to revise to | |||
| If larger than 0, then after having found any candidates consistent with the | |||
| knowledge base, continue to increase the number of pseudo labels to revise to | |||
| get more possible consistent candidates. | |||
| Returns | |||
| ------- | |||
| List[List[Any]] | |||
| A list of candidates, i.e. revised pseudo labels that are consistent with the | |||
| A list of candidates, i.e. revised pseudo labels that are consistent with the | |||
| knowledge base. | |||
| """ | |||
| """ | |||
| candidates = [] | |||
| for revision_num in range(len(pred_pseudo_label) + 1): | |||
| if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y): | |||
| @@ -173,20 +185,22 @@ class KBBase(ABC): | |||
| if revision_num >= max_revision_num: | |||
| return [] | |||
| for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): | |||
| for revision_num in range( | |||
| min_revision_num + 1, min_revision_num + require_more_revision + 1 | |||
| ): | |||
| if revision_num > max_revision_num: | |||
| return candidates | |||
| candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) | |||
| return candidates | |||
| @lru_cache(maxsize=4096) | |||
| def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): | |||
| """ | |||
| `_abduce_by_search` with cache. | |||
| """ | |||
| pred_pseudo_label = restore_from_hashable(pred_pseudo_label) | |||
| y = restore_from_hashable(y) | |||
| return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) | |||
| # @abl_cache(max_size=4096) | |||
| # def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): | |||
| # """ | |||
| # `_abduce_by_search` with cache. | |||
| # """ | |||
| # pred_pseudo_label = restore_from_hashable(pred_pseudo_label) | |||
| # y = restore_from_hashable(y) | |||
| # return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) | |||
| def __repr__(self): | |||
| return ( | |||
| @@ -195,13 +209,13 @@ class KBBase(ABC): | |||
| f"max_err={self.max_err!r}, " | |||
| f"use_cache={self.use_cache!r}." | |||
| ) | |||
| class GroundKB(KBBase): | |||
| """ | |||
| Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon | |||
| class initialization, storing all potential candidates along with their respective | |||
| logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. | |||
| Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon | |||
| class initialization, storing all potential candidates along with their respective | |||
| logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. | |||
| Parameters | |||
| ---------- | |||
| @@ -211,15 +225,16 @@ class GroundKB(KBBase): | |||
| List of possible lengths of pseudo label. | |||
| max_err : float, optional | |||
| Refer to class `KBBase`. | |||
| Notes | |||
| ----- | |||
| Users can also inherit from this class to build their own knowledge base. Similar | |||
| to `KBBase`, users are only required to provide the `pseudo_label_list` and override | |||
| Users can also inherit from this class to build their own knowledge base. Similar | |||
| to `KBBase`, users are only required to provide the `pseudo_label_list` and override | |||
| the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. | |||
| After that, other operations (e.g. auto-construction of GKB, and how to perform | |||
| After that, other operations (e.g. auto-construction of GKB, and how to perform | |||
| abductive reasoning) will be automatically set up. | |||
| """ | |||
| def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| if not isinstance(GKB_len_list, list): | |||
| @@ -229,7 +244,6 @@ class GroundKB(KBBase): | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| @@ -259,21 +273,21 @@ class GroundKB(KBBase): | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| if Y and isinstance(Y[0], (int, float)): | |||
| if Y and isinstance(Y[0], (int, float)): | |||
| X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||
| return X, Y | |||
| def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): | |||
| """ | |||
| Perform abductive reasoning by directly retrieving consistent candidates from | |||
| the prebuilt GKB. In this way, the time-consuming exhaustive search can be | |||
| Perform abductive reasoning by directly retrieving consistent candidates from | |||
| the prebuilt GKB. In this way, the time-consuming exhaustive search can be | |||
| avoided. | |||
| This is an overridden function. For more information about the parameters and | |||
| This is an overridden function. For more information about the parameters and | |||
| returns, refer to the function of the same name in class `KBBase`. | |||
| """ | |||
| if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: | |||
| return [] | |||
| all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) | |||
| if len(all_candidates) == 0: | |||
| return [] | |||
| @@ -284,29 +298,30 @@ class GroundKB(KBBase): | |||
| idxs = np.where(cost_list <= revision_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates | |||
| def _find_candidate_GKB(self, pred_pseudo_label, y): | |||
| """ | |||
| Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, | |||
| return all candidates whose logical results fall within the | |||
| Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, | |||
| return all candidates whose logical results fall within the | |||
| [y - max_err, y + max_err] range. | |||
| """ | |||
| if isinstance(y, (int, float)): | |||
| potential_candidates = self.GKB[len(pred_pseudo_label)] | |||
| key_list = list(potential_candidates.keys()) | |||
| low_key = bisect.bisect_left(key_list, y - self.max_err) | |||
| high_key = bisect.bisect_right(key_list, y + self.max_err) | |||
| all_candidates = [candidate | |||
| for key in key_list[low_key:high_key] | |||
| for candidate in potential_candidates[key]] | |||
| all_candidates = [ | |||
| candidate | |||
| for key in key_list[low_key:high_key] | |||
| for candidate in potential_candidates[key] | |||
| ] | |||
| return all_candidates | |||
| else: | |||
| return self.GKB[len(pred_pseudo_label)][y] | |||
| def __repr__(self): | |||
| return ( | |||
| f"{self.__class__.__name__} is a KB with " | |||
| @@ -321,78 +336,80 @@ class GroundKB(KBBase): | |||
| class PrologKB(KBBase): | |||
| """ | |||
| Knowledge base provided by a Prolog (.pl) file. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label_list : list | |||
| Refer to class `KBBase`. | |||
| pl_file : | |||
| Prolog file containing the KB. | |||
| pl_file : | |||
| Prolog file containing the KB. | |||
| max_err : float, optional | |||
| Refer to class `KBBase`. | |||
| Notes | |||
| ----- | |||
| Users can instantiate this class to build their own knowledge base. During the | |||
| Users can instantiate this class to build their own knowledge base. During the | |||
| instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`. | |||
| To use the default logic forward and abductive reasoning methods in this class, in the | |||
| Prolog (.pl) file, there needs to be a rule which is strictly formatted as | |||
| To use the default logic forward and abductive reasoning methods in this class, in the | |||
| Prolog (.pl) file, there needs to be a rule which is strictly formatted as | |||
| `logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`. | |||
| For specifics, refer to the `logic_forward` and `get_query_string` functions in this | |||
| For specifics, refer to the `logic_forward` and `get_query_string` functions in this | |||
| class. Users are also welcome to override related functions for more flexible support. | |||
| """ | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| super().__init__(pseudo_label_list) | |||
| self.pl_file = pl_file | |||
| self.prolog = pyswip.Prolog() | |||
| if not os.path.exists(self.pl_file): | |||
| raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") | |||
| self.prolog.consult(self.pl_file) | |||
| def logic_forward(self, pseudo_labels): | |||
| """ | |||
| Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the | |||
| returned `Res` as the logical results. To use this default function, there must be | |||
| a Prolog `logic_forward` method in the pl file to perform logical reasoning. | |||
| Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the | |||
| returned `Res` as the logical results. To use this default function, there must be | |||
| a Prolog `logic_forward` method in the pl file to perform logical reasoning. | |||
| Otherwise, users should override this function. | |||
| """ | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] | |||
| if result == 'true': | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"] | |||
| if result == "true": | |||
| return True | |||
| elif result == 'false': | |||
| elif result == "false": | |||
| return False | |||
| return result | |||
| def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx): | |||
| import re | |||
| revision_pred_pseudo_label = pred_pseudo_label.copy() | |||
| revision_pred_pseudo_label = flatten(revision_pred_pseudo_label) | |||
| for idx in revision_idx: | |||
| revision_pred_pseudo_label[idx] = 'P' + str(idx) | |||
| revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) | |||
| revision_pred_pseudo_label[idx] = "P" + str(idx) | |||
| revision_pred_pseudo_label = reform_list(revision_pred_pseudo_label, pred_pseudo_label) | |||
| regex = r"'P\d+'" | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) | |||
| def get_query_string(self, pred_pseudo_label, y, revision_idx): | |||
| """ | |||
| Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set | |||
| the returned `Revise_labels` together with the kept labels as the candidates. This is | |||
| Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set | |||
| the returned `Revise_labels` together with the kept labels as the candidates. This is | |||
| a default function for demo; users should override this function to adapt to their own | |||
| Prolog file. | |||
| Prolog file. | |||
| """ | |||
| query_string = "logic_forward(" | |||
| query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) | |||
| key_is_none_flag = y is None or (type(y) == list and y[0] is None) | |||
| query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
| return query_string | |||
| def revise_at_idx(self, pred_pseudo_label, y, revision_idx): | |||
| """ | |||
| Revise the predicted pseudo label at specified index positions by querying Prolog. | |||
| This is an overridden function. For more information about the parameters, refer to | |||
| This is an overridden function. For more information about the parameters, refer to | |||
| the function of the same name in class `KBBase`. | |||
| """ | |||
| candidates = [] | |||
| @@ -404,7 +421,7 @@ class PrologKB(KBBase): | |||
| candidate = pred_pseudo_label.copy() | |||
| for i, idx in enumerate(revision_idx): | |||
| candidate[idx] = c[i] | |||
| candidate = reform_idx(candidate, save_pred_pseudo_label) | |||
| candidate = reform_list(candidate, save_pred_pseudo_label) | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -414,4 +431,4 @@ class PrologKB(KBBase): | |||
| f"pseudo_label_list={self.pseudo_label_list!r}, " | |||
| f"defined by " | |||
| f"Prolog file {self.pl_file!r}." | |||
| ) | |||
| ) | |||
| @@ -1,9 +1,9 @@ | |||
| import numpy as np | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from abl.utils.utils import ( | |||
| from ..utils.utils import ( | |||
| confidence_dist, | |||
| flatten, | |||
| reform_idx, | |||
| reform_list, | |||
| hamming_dist, | |||
| ) | |||
| @@ -191,7 +191,7 @@ class ReasonerBase: | |||
| return max_revision | |||
| def abduce( | |||
| self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 | |||
| self, data_sample, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| """ | |||
| Perform abductive reasoning on the given prediction data. | |||
| @@ -219,9 +219,13 @@ class ReasonerBase: | |||
| A revised pseudo label through abductive reasoning, which is consistent with the | |||
| knowledge base. | |||
| """ | |||
| symbol_num = len(flatten(pred_pseudo_label)) | |||
| symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
| max_revision_num = self._get_max_revision_num(max_revision, symbol_num) | |||
| pred_pseudo_label = data_sample.pred_pseudo_label | |||
| pred_prob = data_sample.pred_prob | |||
| y = data_sample.Y | |||
| if self.use_zoopt: | |||
| solution = self.zoopt_get_solution( | |||
| symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
| @@ -237,20 +241,18 @@ class ReasonerBase: | |||
| return candidate | |||
| def batch_abduce( | |||
| self, pred_probs, pred_pseudo_labels, Ys, max_revision=-1, require_more_revision=0 | |||
| self, data_samples, max_revision=-1, require_more_revision=0 | |||
| ): | |||
| """ | |||
| Perform abductive reasoning on the given prediction data in batches. | |||
| For detailed information, refer to `abduce`. | |||
| """ | |||
| return [ | |||
| self.abduce( | |||
| pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision | |||
| ) | |||
| for pred_prob, pred_pseudo_label, Y in zip( | |||
| pred_probs, pred_pseudo_labels, Ys | |||
| ) | |||
| abduced_pseudo_label = [ | |||
| self.abduce(data_sample, max_revision, require_more_revision) | |||
| for data_sample in data_samples | |||
| ] | |||
| data_samples.abduced_pseudo_label = abduced_pseudo_label | |||
| return abduced_pseudo_label | |||
| # def _batch_abduce_helper(self, args): | |||
| # z, prob, y, max_revision, require_more_revision = args | |||
| @@ -273,12 +275,11 @@ class ReasonerBase: | |||
| if __name__ == "__main__": | |||
| from kb import KBBase, GroundKB, PrologKB | |||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| from abl.structures import ListData | |||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| ################################ | |||
| # Test for MNIST Add reasoning # | |||
| ################################ | |||
| class AddKB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| @@ -288,38 +289,54 @@ if __name__ == "__main__": | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class AddGroundKB(GroundKB): | |||
| class AddGroundKB(GroundKB, AddKB): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| GKB_len_list=[2]): | |||
| super().__init__(pseudo_label_list, GKB_len_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def test_add(reasoner): | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| # favor 1 in first one | |||
| prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| # favor 7 in first one | |||
| prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| data_samples_add = ListData() | |||
| data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | |||
| data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | |||
| data_samples_add.Y = [8, 8, 17, 10] | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) | |||
| print(res) | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) | |||
| print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] | |||
| print() | |||
| print("AddKB with GKB:") | |||
| print("AddGroundKB:") | |||
| kb = AddGroundKB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| print("AddKB without GKB:") | |||
| print("AddKB:") | |||
| kb = AddKB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| print("AddKB without GKB, no cache") | |||
| print("AddKB, no cache") | |||
| kb = AddKB(use_cache=False) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| @@ -337,45 +354,20 @@ if __name__ == "__main__": | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||
| test_add(reasoner) | |||
| print("AddKB with multiple inputs at once:") | |||
| multiple_prob = [[ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ]] | |||
| kb = AddKB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=1, | |||
| ) | |||
| print(res) | |||
| print() | |||
| ################################ | |||
| #### Test for HWF reasoning #### | |||
| ################################ | |||
| class HwfKB(KBBase): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
| "+", "-", "times", "div"], | |||
| max_err=1e-3, | |||
| use_cache=False, | |||
| ): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| super().__init__(pseudo_label_list, max_err, use_cache) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| @@ -395,7 +387,7 @@ if __name__ == "__main__": | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| class HwfGroundKB(GroundKB): | |||
| class HwfGroundKB(GroundKB, HwfKB): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
| @@ -405,6 +397,17 @@ if __name__ == "__main__": | |||
| ): | |||
| super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
| return False | |||
| return True | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| @@ -415,6 +418,16 @@ if __name__ == "__main__": | |||
| return False | |||
| return True | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return None | |||
| mapping = {str(i): str(i) for i in range(1, 10)} | |||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return None | |||
| @@ -424,87 +437,46 @@ if __name__ == "__main__": | |||
| return eval("".join(formula)) | |||
| def test_hwf(reasoner): | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "+", "2"]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| data_samples_hwf = ListData() | |||
| data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] | |||
| data_samples_hwf.pred_prob = [None, None, None, None] | |||
| data_samples_hwf.Y = [3, 64, 65, 3.17] | |||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "+", "9"]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) | |||
| print(res) | |||
| print() | |||
| def test_hwf_multiple(reasoner, max_revisions): | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 64], | |||
| max_revision=max_revisions[0], | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 64], | |||
| max_revision=max_revisions[1], | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 65], | |||
| max_revision=max_revisions[2], | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print("HwfKB with GKB, max_err=0.1") | |||
| print("HwfGroundKB, max_err=0.1:") | |||
| kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB without GKB, max_err=0.1") | |||
| print("HwfKB, max_err=0.1:") | |||
| kb = HwfKB(max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB with GKB, max_err=1") | |||
| print("HwfGroundKB, max_err=1:") | |||
| kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB without GKB, max_err=1") | |||
| print("HwfKB, max_err=1:") | |||
| kb = HwfKB(max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB with multiple inputs at once:") | |||
| kb = HwfKB(max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf_multiple(reasoner, max_revisions=[1,3,3]) | |||
| print("max_revision is float") | |||
| test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) | |||
| ################################ | |||
| #### Test for HED reasoning #### | |||
| ################################ | |||
| class HedKB(PrologKB): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| @@ -540,7 +512,7 @@ if __name__ == "__main__": | |||
| return candidate | |||
| def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): | |||
| all_revision_flag = reform_idx(sol.get_x(), pred_res) | |||
| all_revision_flag = reform_list(sol.get_x(), pred_res) | |||
| lefted_idxs = [i for i in range(len(pred_res))] | |||
| candidate_size = [] | |||
| while lefted_idxs: | |||
| @@ -597,28 +569,24 @@ if __name__ == "__main__": | |||
| inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
| rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||
| print("HedKB logic forward") | |||
| print(kb.logic_forward(consist_exs)) | |||
| print("HedKB logic forward:") | |||
| print(kb.logic_forward(consist_exs), end=" ") | |||
| print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||
| print() | |||
| print("HedKB consist rule") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||
| print("HedKB consist rule:") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||
| print() | |||
| data_sample_hed = ListData() | |||
| data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2] | |||
| data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] | |||
| data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] | |||
| print("HedReasoner abduce") | |||
| res = reasoner.abduce( | |||
| [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce( | |||
| [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce( | |||
| [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce(data_sample_hed) | |||
| for r in res: | |||
| print(r) | |||
| print() | |||
| print("HedReasoner abduce rules") | |||
| @@ -0,0 +1,2 @@ | |||
| from .base_data_element import BaseDataElement | |||
| from .list_data import ListData | |||
| @@ -0,0 +1,629 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import copy | |||
| from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
| import numpy as np | |||
| import torch | |||
| class BaseDataElement: | |||
| """A base data interface that supports Tensor-like and dict-like | |||
| operations. | |||
| A typical data element refers to predicted results or ground truth labels | |||
| on a task, such as predicted bboxes, instance masks, semantic | |||
| segmentation masks, etc. Because groundtruth labels and predicted results | |||
| often have similar properties (for example, the predicted bboxes and the | |||
| groundtruth bboxes), MMEngine uses the same abstract data interface to | |||
| encapsulate predicted results and groundtruth labels, and it is recommended | |||
| to use different name conventions to distinguish them, such as using | |||
| ``gt_instances`` and ``pred_instances`` to distinguish between labels and | |||
| predicted results. Additionally, we distinguish data elements at instance | |||
| level, pixel level, and label level. Each of these types has its own | |||
| characteristics. Therefore, MMEngine defines the base class | |||
| ``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and | |||
| ``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||
| types of ground truth labels or predictions. | |||
| Another common data element is sample data. A sample data consists of input | |||
| data (such as an image) and its annotations and predictions. In general, | |||
| an image can have multiple types of annotations and/or predictions at the | |||
| same time (for example, both pixel-level semantic segmentation annotations | |||
| and instance-level detection bboxes annotations). All labels and | |||
| predictions of a training sample are often passed between Dataset, Model, | |||
| Visualizer, and Evaluator components. In order to simplify the interface | |||
| between components, we can treat them as a large data element and | |||
| encapsulate them. Such data elements are generally called XXDataSample in | |||
| the OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement` | |||
| allows `BaseDataElement` as its attribute. Such a class generally | |||
| encapsulates all the data of a sample in the algorithm library, and its | |||
| attributes generally are various types of data elements. For example, | |||
| MMDetection is assigned by the BaseDataElement to encapsulate all the data | |||
| elements of the sample labeling and prediction of a sample in the | |||
| algorithm library. | |||
| The attributes in ``BaseDataElement`` are divided into two parts, | |||
| the ``metainfo`` and the ``data`` respectively. | |||
| - ``metainfo``: Usually contains the | |||
| information about the image such as filename, | |||
| image_shape, pad_shape, etc. The attributes can be accessed or | |||
| modified by dict-like or object-like operations, such as | |||
| ``.`` (for data access and modification), ``in``, ``del``, | |||
| ``pop(str)``, ``get(str)``, ``metainfo_keys()``, | |||
| ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for | |||
| set or change key-value pairs in metainfo). | |||
| - ``data``: Annotations or model predictions are | |||
| stored. The attributes can be accessed or modified by | |||
| dict-like or object-like operations, such as | |||
| ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||
| ``values()``, ``items()``. Users can also apply tensor-like | |||
| methods to all :obj:`torch.Tensor` in the ``data_fields``, | |||
| such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, | |||
| ``to_tensor()``, ``.detach()``. | |||
| Args: | |||
| metainfo (dict, optional): A dict contains the meta information | |||
| of single image, such as ``dict(img_shape=(512, 512, 3), | |||
| scale_factor=(1, 1, 1, 1))``. Defaults to None. | |||
| kwargs (dict, optional): A dict contains annotations of single image or | |||
| model predictions. Defaults to None. | |||
| Examples: | |||
| >>> import torch | |||
| >>> from mmengine.structures import BaseDataElement | |||
| >>> gt_instances = BaseDataElement() | |||
| >>> bboxes = torch.rand((5, 4)) | |||
| >>> scores = torch.rand((5,)) | |||
| >>> img_id = 0 | |||
| >>> img_shape = (800, 1333) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=img_id, img_shape=img_shape), | |||
| ... bboxes=bboxes, scores=scores) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) | |||
| >>> # new | |||
| >>> gt_instances1 = gt_instances.new( | |||
| ... metainfo=dict(img_id=1, img_shape=(640, 640)), | |||
| ... bboxes=torch.rand((5, 4)), | |||
| ... scores=torch.rand((5,))) | |||
| >>> gt_instances2 = gt_instances1.new() | |||
| >>> # add and process property | |||
| >>> gt_instances = BaseDataElement() | |||
| >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) | |||
| >>> assert 'img_shape' in gt_instances.metainfo_keys() | |||
| >>> assert 'img_shape' in gt_instances | |||
| >>> assert 'img_shape' not in gt_instances.keys() | |||
| >>> assert 'img_shape' in gt_instances.all_keys() | |||
| >>> print(gt_instances.img_shape) | |||
| (100, 100) | |||
| >>> gt_instances.scores = torch.rand((5,)) | |||
| >>> assert 'scores' in gt_instances.keys() | |||
| >>> assert 'scores' in gt_instances | |||
| >>> assert 'scores' in gt_instances.all_keys() | |||
| >>> assert 'scores' not in gt_instances.metainfo_keys() | |||
| >>> print(gt_instances.scores) | |||
| tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) | |||
| >>> gt_instances.bboxes = torch.rand((5, 4)) | |||
| >>> assert 'bboxes' in gt_instances.keys() | |||
| >>> assert 'bboxes' in gt_instances | |||
| >>> assert 'bboxes' in gt_instances.all_keys() | |||
| >>> assert 'bboxes' not in gt_instances.metainfo_keys() | |||
| >>> print(gt_instances.bboxes) | |||
| tensor([[0.0900, 0.0424, 0.1755, 0.4469], | |||
| [0.8648, 0.0592, 0.3484, 0.0913], | |||
| [0.5808, 0.1909, 0.6165, 0.7088], | |||
| [0.5490, 0.4209, 0.9416, 0.2374], | |||
| [0.3652, 0.1218, 0.8805, 0.7523]]) | |||
| >>> # delete and change property | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=0, img_shape=(640, 640)), | |||
| ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) | |||
| >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) | |||
| >>> gt_instances.img_shape # (1280, 1280) | |||
| >>> gt_instances.bboxes = gt_instances.bboxes * 2 | |||
| >>> gt_instances.get('img_shape', None) # (1280, 1280) | |||
| >>> gt_instances.get('bboxes', None) # 6x4 tensor | |||
| >>> del gt_instances.img_shape | |||
| >>> del gt_instances.bboxes | |||
| >>> assert 'img_shape' not in gt_instances | |||
| >>> assert 'bboxes' not in gt_instances | |||
| >>> gt_instances.pop('img_shape', None) # None | |||
| >>> gt_instances.pop('bboxes', None) # None | |||
| >>> # Tensor-like | |||
| >>> cuda_instances = gt_instances.cuda() | |||
| >>> cuda_instances = gt_instances.to('cuda:0') | |||
| >>> cpu_instances = cuda_instances.cpu() | |||
| >>> cpu_instances = cuda_instances.to('cpu') | |||
| >>> fp16_instances = cuda_instances.to( | |||
| ... device=None, dtype=torch.float16, non_blocking=False, | |||
| ... copy=False, memory_format=torch.preserve_format) | |||
| >>> cpu_instances = cuda_instances.detach() | |||
| >>> np_instances = cpu_instances.numpy() | |||
| >>> metainfo = dict(img_shape=(800, 1196, 3)) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||
| >>> sample = BaseDataElement(metainfo=metainfo, | |||
| ... gt_instances=gt_instances) | |||
| >>> print(sample) | |||
| <BaseDataElement( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| DATA FIELDS | |||
| gt_instances: <BaseDataElement( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([0, 1, 2, 3]) | |||
| ) at 0x7f0ec5eadc70> | |||
| ) at 0x7f0fea49e130> | |||
| >>> # inheritance | |||
| >>> class DetDataSample(BaseDataElement): | |||
| ... @property | |||
| ... def proposals(self): | |||
| ... return self._proposals | |||
| ... @proposals.setter | |||
| ... def proposals(self, value): | |||
| ... self.set_field(value, '_proposals', dtype=BaseDataElement) | |||
| ... @proposals.deleter | |||
| ... def proposals(self): | |||
| ... del self._proposals | |||
| ... @property | |||
| ... def gt_instances(self): | |||
| ... return self._gt_instances | |||
| ... @gt_instances.setter | |||
| ... def gt_instances(self, value): | |||
| ... self.set_field(value, '_gt_instances', | |||
| ... dtype=BaseDataElement) | |||
| ... @gt_instances.deleter | |||
| ... def gt_instances(self): | |||
| ... del self._gt_instances | |||
| ... @property | |||
| ... def pred_instances(self): | |||
| ... return self._pred_instances | |||
| ... @pred_instances.setter | |||
| ... def pred_instances(self, value): | |||
| ... self.set_field(value, '_pred_instances', | |||
| ... dtype=BaseDataElement) | |||
| ... @pred_instances.deleter | |||
| ... def pred_instances(self): | |||
| ... del self._pred_instances | |||
| >>> det_sample = DetDataSample() | |||
| >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||
| >>> det_sample.proposals = proposals | |||
| >>> assert 'proposals' in det_sample | |||
| >>> assert det_sample.proposals == proposals | |||
| >>> del det_sample.proposals | |||
| >>> assert 'proposals' not in det_sample | |||
| >>> with self.assertRaises(AssertionError): | |||
| ... det_sample.proposals = torch.rand((5, 4)) | |||
| """ | |||
def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
    """Create a data element, optionally pre-populated.

    Args:
        metainfo (dict, optional): Meta information forwarded to
            ``set_metainfo``. Skipped when ``None``. Defaults to None.
        **kwargs: Data fields forwarded to ``set_data`` as one dict.
            Skipped when no keyword is given.
    """
    # Both name registries are created before any field is stored.
    self._metainfo_fields: set = set()
    self._data_fields: set = set()

    if metainfo is not None:
        self.set_metainfo(metainfo=metainfo)

    if kwargs:
        self.set_data(kwargs)
def set_metainfo(self, metainfo: dict) -> None:
    """Register every entry of ``metainfo`` as a metainfo field.

    The mapping is deep-copied first, so the stored values are independent
    of the caller's dict and of any mutable objects inside it.

    Args:
        metainfo (dict): Meta information to store, such as ``img_shape``
            or ``scale_factor``.

    Raises:
        AssertionError: If ``metainfo`` is not a ``dict``.
    """
    assert isinstance(
        metainfo, dict
    ), f"metainfo should be a ``dict`` but got {type(metainfo)}"

    for field_name, field_value in copy.deepcopy(metainfo).items():
        self.set_field(
            name=field_name,
            value=field_value,
            field_type="metainfo",
            dtype=None,
        )
def set_data(self, data: dict) -> None:
    """Set or change key-value pairs in the data fields.

    Args:
        data (dict): A dict containing annotations or model predictions,
            keyed by field name.

    Raises:
        AssertionError: If ``data`` is not a ``dict``.
    """
    # Report the offending *type* (as ``set_metainfo`` and ``update`` do)
    # instead of dumping the whole object into the message.
    assert isinstance(data, dict), f"data should be a `dict` but got {type(data)}"
    for k, v in data.items():
        # Use `setattr()` rather than `self.set_field` to allow `set_data`
        # to set property method.
        setattr(self, k, v)
def update(self, instance: "BaseDataElement") -> None:
    """Merge the fields of another data element into this one.

    The other element's metainfo is applied through ``set_metainfo`` and
    its data through ``set_data``, in that order.

    Args:
        instance (BaseDataElement): The element whose fields are copied
            onto the current object.

    Raises:
        AssertionError: If ``instance`` is not a ``BaseDataElement``.
    """
    assert isinstance(
        instance, BaseDataElement
    ), f"instance should be a `BaseDataElement` but got {type(instance)}"

    incoming_meta = dict(instance.metainfo_items())
    self.set_metainfo(incoming_meta)

    incoming_data = dict(instance.items())
    self.set_data(incoming_data)
def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement":
    """Build a new element of the same class as ``self``.

    Metainfo and data are chosen independently: each one comes from the
    arguments when supplied, otherwise from the current element.

    Args:
        metainfo (dict, optional): Meta information for the new element.
            When ``None``, the current metainfo is used. Defaults to None.
        **kwargs: Data fields for the new element. When no keyword is
            given, the current data fields are used.

    Returns:
        BaseDataElement: A new data element with same type.
    """
    fresh = self.__class__()

    # ``metainfo={}`` is "supplied": only ``None`` falls back to self.
    fresh.set_metainfo(
        metainfo if metainfo is not None else dict(self.metainfo_items())
    )
    fresh.set_data(kwargs if kwargs else dict(self.items()))
    return fresh
def clone(self):
    """Copy the current data element into a new object of the same class.

    Metainfo is re-applied through ``set_metainfo`` and data through
    ``set_data``.
    NOTE(review): the original summary said "deep copy", but ``set_data``
    stores the very same value objects, so data values are shared with
    the source element - confirm whether a true deep copy is intended.

    Returns:
        BaseDataElement: The copy of current data element.
    """
    duplicate = self.__class__()

    meta_snapshot = dict(self.metainfo_items())
    duplicate.set_metainfo(meta_snapshot)

    data_snapshot = dict(self.items())
    duplicate.set_data(data_snapshot)
    return duplicate
def keys(self) -> list:
    """List the public data field names.

    Returns:
        list: Contains all keys in data_fields, minus the private backing
        names of property-based fields.
    """
    # Convention: a field exposed through a property ``foo`` keeps its
    # value under ``_foo``. Both names end up in ``_data_fields``, so the
    # underscored twin is hidden here.
    # TODO: Use a more robust way to solve this problem
    owner = type(self)
    hidden = set()
    for field in self._data_fields:
        if isinstance(getattr(owner, field, None), property):
            hidden.add("_" + field)
    return list(self._data_fields - hidden)
def metainfo_keys(self) -> list:
    """List the metainfo field names.

    Returns:
        list: Contains all keys in metainfo_fields.
    """
    return [*self._metainfo_fields]
def values(self) -> list:
    """List the data values, in the order given by ``keys()``.

    Returns:
        list: Contains all values in data.
    """
    collected = []
    for field in self.keys():
        collected.append(getattr(self, field))
    return collected
def metainfo_values(self) -> list:
    """List the metainfo values, in the order given by ``metainfo_keys()``.

    Returns:
        list: Contains all values in metainfo.
    """
    collected = []
    for field in self.metainfo_keys():
        collected.append(getattr(self, field))
    return collected
def all_keys(self) -> list:
    """List every field name: metainfo names first, then data names.

    Returns:
        list: Contains all keys in metainfo and data.
    """
    return [*self.metainfo_keys(), *self.keys()]
def all_values(self) -> list:
    """List every field value: metainfo values first, then data values.

    Returns:
        list: Contains all values in metainfo and data.
    """
    return [*self.metainfo_values(), *self.values()]
def all_items(self) -> Iterator[Tuple[str, Any]]:
    """Iterate lazily over every ``(name, value)`` pair.

    Returns:
        iterator: An iterator object whose element is (key, value) tuple
        pairs for ``metainfo`` and ``data``.
    """
    for field in self.all_keys():
        current = getattr(self, field)
        yield field, current
def items(self) -> Iterator[Tuple[str, Any]]:
    """Iterate lazily over the ``(name, value)`` pairs of the data fields.

    Returns:
        iterator: An iterator object whose element is (key, value) tuple
        pairs for ``data``.
    """
    for field in self.keys():
        current = getattr(self, field)
        yield field, current
def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
    """Iterate lazily over the ``(name, value)`` pairs of the metainfo.

    Returns:
        iterator: An iterator object whose element is (key, value) tuple
        pairs for ``metainfo``.
    """
    for field in self.metainfo_keys():
        current = getattr(self, field)
        yield field, current
@property
def metainfo(self) -> dict:
    """dict: A freshly built dict of the metainfo of this data element."""
    return {name: value for name, value in self.metainfo_items()}
def __setattr__(self, name: str, value: Any):
    """Store ``value`` as a data field.

    Plain attribute assignment always creates or updates a *data* field.
    The two internal registries may be assigned exactly once (when they
    do not exist yet) and are immutable afterwards.

    Args:
        name (str): Attribute name.
        value (Any): Value to store.

    Raises:
        AttributeError: If ``name`` is an already existing registry
            attribute.
    """
    if name not in ("_metainfo_fields", "_data_fields"):
        self.set_field(name=name, value=value, field_type="data", dtype=None)
        return

    if hasattr(self, name):
        raise AttributeError(
            f"{name} has been used as a "
            "private attribute, which is immutable."
        )
    super().__setattr__(name, value)
def __delattr__(self, item: str):
    """Delete a field and forget its name.

    Args:
        item (str): The key to delete.

    Raises:
        AttributeError: If ``item`` is one of the internal registries.
    """
    if item in ("_metainfo_fields", "_data_fields"):
        raise AttributeError(
            f"{item} has been used as a " "private attribute, which is immutable."
        )

    super().__delattr__(item)

    # Drop the name from whichever registry tracked it; metainfo is
    # looked at first and at most one registry is touched.
    for registry in (self._metainfo_fields, self._data_fields):
        if item in registry:
            registry.remove(item)
            break


# dict-like methods
__delitem__ = __delattr__
def get(self, key, default=None) -> Any:
    """Return the attribute named ``key``, or ``default`` if it is missing.

    Args:
        key (str): Name of the metainfo or data field.
        default (Any): Fallback value. Defaults to None.
    """
    # Attribute lookup (not ``self.__dict__``) is used on purpose so that
    # property-based fields are found as well.
    try:
        return getattr(self, key)
    except AttributeError:
        return default
def pop(self, *args) -> Any:
    """Remove a field and return its value, like ``dict.pop``.

    Args:
        *args: ``(name,)`` or ``(name, default)``.

    Returns:
        Any: The removed value, or ``default`` when ``name`` is neither a
        metainfo nor a data field and a default was given.

    Raises:
        KeyError: If ``name`` is not a known field and no default is given.
    """
    assert len(args) < 3, "``pop`` get more than 2 arguments"
    key = args[0]

    # Only names tracked as metainfo or data may be popped; the instance
    # ``__dict__`` is never popped blindly.
    for registry in (self._metainfo_fields, self._data_fields):
        if key in registry:
            registry.remove(key)
            return self.__dict__.pop(*args)

    if len(args) == 2:
        # with default value
        return args[1]
    raise KeyError(f"{args[0]} is not contained in metainfo or data")
def __contains__(self, item: str) -> bool:
    """Tell whether ``item`` names a data field or a metainfo field.

    Args:
        item (str): The key to inquire.
    """
    registries = (self._data_fields, self._metainfo_fields)
    return any(item in registry for registry in registries)
def set_field(
    self,
    value: Any,
    name: str,
    dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
    field_type: str = "data",
) -> None:
    """Store ``value`` under ``name`` and record it as metainfo or data.

    Intended to be used from property setters as well as internally.

    Args:
        value (Any): Value to store.
        name (str): Attribute name to store it under.
        dtype (type or tuple of types, optional): When given, ``value``
            must be an instance of it. Defaults to None.
        field_type (str): ``"metainfo"`` or ``"data"``. Defaults to
            ``"data"``.

    Raises:
        AssertionError: If ``field_type`` is invalid or ``value`` fails
            the ``dtype`` check.
        AttributeError: If ``name`` is already registered as the other
            kind of field.
    """
    assert field_type in ["metainfo", "data"]
    if dtype is not None:
        assert isinstance(
            value, dtype
        ), f"{value} should be a {dtype} but got {type(value)}"

    if field_type == "metainfo":
        own_kind, other_kind = "metainfo", "data"
        own_fields, other_fields = self._metainfo_fields, self._data_fields
    else:
        own_kind, other_kind = "data", "metainfo"
        own_fields, other_fields = self._data_fields, self._metainfo_fields

    # A name may live in only one of the two registries.
    if name in other_fields:
        raise AttributeError(
            f"Cannot set {name} to be a field of {own_kind} "
            f"because {name} is already a {other_kind} field"
        )
    own_fields.add(name)

    # Bypass this class's own ``__setattr__`` to avoid recursing back here.
    super().__setattr__(name, value)
| # Tensor-like methods | |||
def to(self, *args, **kwargs) -> "BaseDataElement":
    """Call ``.to(*args, **kwargs)`` on every data value that has it.

    Values without a ``to`` attribute are carried over as they are in the
    element produced by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        if not hasattr(current, "to"):
            continue
        converted.set_data({field: current.to(*args, **kwargs)})
    return converted
| # Tensor-like methods | |||
def cpu(self) -> "BaseDataElement":
    """Call ``.cpu()`` on every tensor or nested element in data.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` are
    converted; everything else is carried over by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        if not isinstance(current, (torch.Tensor, BaseDataElement)):
            continue
        converted.set_data({field: current.cpu()})
    return converted
| # Tensor-like methods | |||
def cuda(self) -> "BaseDataElement":
    """Call ``.cuda()`` on every tensor or nested element in data.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` are
    converted; everything else is carried over by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        if not isinstance(current, (torch.Tensor, BaseDataElement)):
            continue
        converted.set_data({field: current.cuda()})
    return converted
| # Tensor-like methods | |||
def npu(self) -> "BaseDataElement":
    """Call ``.npu()`` on every tensor or nested element in data.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` are
    converted; everything else is carried over by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        if not isinstance(current, (torch.Tensor, BaseDataElement)):
            continue
        converted.set_data({field: current.npu()})
    return converted
def mlu(self) -> "BaseDataElement":
    """Call ``.mlu()`` on every tensor or nested element in data.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` are
    converted; everything else is carried over by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        if not isinstance(current, (torch.Tensor, BaseDataElement)):
            continue
        converted.set_data({field: current.mlu()})
    return converted
| # Tensor-like methods | |||
def detach(self) -> "BaseDataElement":
    """Call ``.detach()`` on every tensor or nested element in data.

    Only values that are ``torch.Tensor`` or ``BaseDataElement`` are
    converted; everything else is carried over by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the detached values.
    """
    converted = self.new()
    for field, current in self.items():
        if not isinstance(current, (torch.Tensor, BaseDataElement)):
            continue
        converted.set_data({field: current.detach()})
    return converted
| # Tensor-like methods | |||
def numpy(self) -> "BaseDataElement":
    """Convert every tensor or nested element in data to NumPy.

    Each ``torch.Tensor`` or ``BaseDataElement`` value goes through
    ``.detach().cpu().numpy()``; everything else is carried over by
    ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        if not isinstance(current, (torch.Tensor, BaseDataElement)):
            continue
        converted.set_data({field: current.detach().cpu().numpy()})
    return converted
def to_tensor(self) -> "BaseDataElement":
    """Convert every ``np.ndarray`` in data to a ``torch.Tensor``.

    Arrays are converted with ``torch.from_numpy``; nested
    ``BaseDataElement`` values are converted recursively. Other values
    are carried over by ``self.new()``.

    Returns:
        BaseDataElement: A new element holding the converted values.
    """
    converted = self.new()
    for field, current in self.items():
        patch = {}
        if isinstance(current, np.ndarray):
            patch[field] = torch.from_numpy(current)
        elif isinstance(current, BaseDataElement):
            patch[field] = current.to_tensor()
        # ``patch`` stays empty for values that need no conversion.
        converted.set_data(patch)
    return converted
def to_dict(self) -> dict:
    """Convert this element (metainfo and data) to a plain dict.

    Nested ``BaseDataElement`` values are converted recursively; all other
    values are placed in the result unchanged.
    """
    plain = {}
    for field, current in self.all_items():
        if isinstance(current, BaseDataElement):
            current = current.to_dict()
        plain[field] = current
    return plain
def __repr__(self) -> str:
    """Return a multi-line description listing metainfo and data fields."""

    def _addindent(s_: str, num_spaces: int) -> str:
        """Indent every line of ``s_`` except the first.

        Modified from `pytorch` https://github.com/pytorch/
        pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
        les/module.py#L29.

        Args:
            s_ (str): The string to add spaces.
            num_spaces (int): The num of space to add.

        Returns:
            str: The string after add indent.
        """
        head, *rest = s_.split("\n")
        # don't do anything for single-line stuff
        if not rest:
            return s_
        pad = num_spaces * " "
        return head + "\n" + "\n".join(pad + line for line in rest)

    def dump(obj: Any) -> str:
        """Render ``obj`` recursively.

        Args:
            obj (Any): The obj to represent.

        Returns:
            str: The represented str.
        """
        if isinstance(obj, dict):
            # One "key: value" entry per line, nested values indented.
            return "".join(
                f"\n{key}: {_addindent(dump(value), 4)}"
                for key, value in obj.items()
            )
        if isinstance(obj, BaseDataElement):
            meta_part = _addindent(dump(dict(obj.metainfo_items())), 4)
            data_part = _addindent(dump(dict(obj.items())), 4)
            body = (
                "\n\n META INFORMATION"
                + meta_part
                + "\n\n DATA FIELDS"
                + data_part
            )
            classname = obj.__class__.__name__
            return f"<{classname}({body}\n) at {hex(id(obj))}>"
        return repr(obj)

    return dump(self)
| @@ -0,0 +1,305 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import itertools | |||
| from collections.abc import Sized | |||
| from typing import Any, List, Union | |||
| import numpy as np | |||
| import torch | |||
| from ..utils import flatten as flatten_list | |||
| from ..utils import to_hashable | |||
| from .base_data_element import BaseDataElement | |||
| BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] | |||
| LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] | |||
| IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
| class ListData(BaseDataElement): | |||
| """Data structure for instance-level annotations or predictions. | |||
| Subclass of :class:`BaseDataElement`. All value in `data_fields` | |||
| should have the same length. This design refer to | |||
| https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 | |||
| ListData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value | |||
| in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, | |||
| and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. | |||
| Examples: | |||
| >>> # custom data structure | |||
| >>> class TmpObject: | |||
| ... def __init__(self, tmp) -> None: | |||
| ... assert isinstance(tmp, list) | |||
| ... self.tmp = tmp | |||
| ... def __len__(self): | |||
| ... return len(self.tmp) | |||
| ... def __getitem__(self, item): | |||
| ... if isinstance(item, int): | |||
| ... if item >= len(self) or item < -len(self): # type:ignore | |||
| ... raise IndexError(f'Index {item} out of range!') | |||
| ... else: | |||
| ... # keep the dimension | |||
| ... item = slice(item, None, len(self)) | |||
| ... return TmpObject(self.tmp[item]) | |||
| ... @staticmethod | |||
| ... def cat(tmp_objs): | |||
| ... assert all(isinstance(results, TmpObject) for results in tmp_objs) | |||
| ... if len(tmp_objs) == 1: | |||
| ... return tmp_objs[0] | |||
| ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] | |||
| ... tmp_list = list(itertools.chain(*tmp_list)) | |||
| ... new_data = TmpObject(tmp_list) | |||
| ... return new_data | |||
| ... def __repr__(self): | |||
| ... return str(self.tmp) | |||
| >>> from mmengine.structures import ListData | |||
| >>> import numpy as np | |||
| >>> import torch | |||
| >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) | |||
| >>> instance_data = ListData(metainfo=img_meta) | |||
| >>> 'img_shape' in instance_data | |||
| True | |||
| >>> instance_data.det_labels = torch.LongTensor([2, 3]) | |||
| >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) | |||
| >>> instance_data.bboxes = torch.rand((2, 4)) | |||
| >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) | |||
| >>> len(instance_data) | |||
| 2 | |||
| >>> print(instance_data) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([2, 3]) | |||
| det_scores: tensor([0.8000, 0.7000]) | |||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||
| [0.8101, 0.3105, 0.5123, 0.6263]]) | |||
| polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] | |||
| ) at 0x7fb492de6280> | |||
| >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] | |||
| >>> sorted_results.det_scores | |||
| tensor([0.7000, 0.8000]) | |||
| >>> print(instance_data[instance_data.det_scores > 0.75]) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([2]) | |||
| det_scores: tensor([0.8000]) | |||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) | |||
| polygons: [[1, 2, 3, 4]] | |||
| ) at 0x7f64ecf0ec40> | |||
| >>> print(instance_data[instance_data.det_scores > 1]) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([], dtype=torch.int64) | |||
| det_scores: tensor([]) | |||
| bboxes: tensor([], size=(0, 4)) | |||
| polygons: [] | |||
| ) at 0x7f660a6a7f70> | |||
| >>> print(instance_data.cat([instance_data, instance_data])) | |||
| <ListData( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| pad_shape: (800, 1216, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([2, 3, 2, 3]) | |||
| det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) | |||
| bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||
| [0.8101, 0.3105, 0.5123, 0.6263], | |||
| [0.4997, 0.7707, 0.0595, 0.4188], | |||
| [0.8101, 0.3105, 0.5123, 0.6263]]) | |||
| polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]] | |||
| ) at 0x7f203542feb0> | |||
| """ | |||
| def __setattr__(self, name: str, value: list): | |||
| """setattr is only used to set data. | |||
| The value must have the attribute of `__len__` and have the same length | |||
| of `ListData`. | |||
| """ | |||
| if name in ("_metainfo_fields", "_data_fields"): | |||
| if not hasattr(self, name): | |||
| super().__setattr__(name, value) | |||
| else: | |||
| raise AttributeError( | |||
| f"{name} has been used as a " "private attribute, which is immutable." | |||
| ) | |||
| else: | |||
| # assert isinstance(value, list), "value must be of type `list`" | |||
| # if len(self) > 0: | |||
| # assert len(value) == len(self), ( | |||
| # "The length of " | |||
| # f"values {len(value)} is " | |||
| # "not consistent with " | |||
| # "the length of this " | |||
| # ":obj:`ListData` " | |||
| # f"{len(self)}" | |||
| # ) | |||
| super().__setattr__(name, value) | |||
| __setitem__ = __setattr__ | |||
| def __getitem__(self, item: IndexType) -> "ListData": | |||
| """ | |||
| Args: | |||
| item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, | |||
| :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): | |||
| Get the corresponding values according to item. | |||
| Returns: | |||
| :obj:`ListData`: Corresponding values. | |||
| """ | |||
| assert isinstance(item, IndexType.__args__) | |||
| if isinstance(item, list): | |||
| item = np.array(item) | |||
| if isinstance(item, np.ndarray): | |||
| # The default int type of numpy is platform dependent, int32 for | |||
| # windows and int64 for linux. `torch.Tensor` requires the index | |||
| # should be int64, therefore we simply convert it to int64 here. | |||
| # More details in https://github.com/numpy/numpy/issues/9464 | |||
| item = item.astype(np.int64) if item.dtype == np.int32 else item | |||
| item = torch.from_numpy(item) | |||
| if isinstance(item, str): | |||
| return getattr(self, item) | |||
| new_data = self.__class__(metainfo=self.metainfo) | |||
| if isinstance(item, torch.Tensor): | |||
| assert item.dim() == 1, "Only support to get the" " values along the first dimension." | |||
| for k, v in self.items(): | |||
| if v is None: | |||
| new_data[k] = None | |||
| elif isinstance(v, torch.Tensor): | |||
| new_data[k] = v[item] | |||
| elif isinstance(v, np.ndarray): | |||
| new_data[k] = v[item.cpu().numpy()] | |||
| elif isinstance(v, (str, list, tuple)) or ( | |||
| hasattr(v, "__getitem__") and hasattr(v, "cat") | |||
| ): | |||
| # convert to indexes from BoolTensor | |||
| if isinstance(item, BoolTypeTensor.__args__): | |||
| indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() | |||
| else: | |||
| indexes = item.cpu().numpy().tolist() | |||
| slice_list = [] | |||
| if indexes: | |||
| for index in indexes: | |||
| slice_list.append(slice(index, None, len(v))) | |||
| else: | |||
| slice_list.append(slice(None, 0, None)) | |||
| r_list = [v[s] for s in slice_list] | |||
| if isinstance(v, (str, list, tuple)): | |||
| new_value = r_list[0] | |||
| for r in r_list[1:]: | |||
| new_value = new_value + r | |||
| else: | |||
| new_value = v.cat(r_list) | |||
| new_data[k] = new_value | |||
| else: | |||
| raise ValueError( | |||
| f"The type of `{k}` is `{type(v)}`, which has no " | |||
| "attribute of `cat`, so it does not " | |||
| "support slice with `bool`" | |||
| ) | |||
| else: | |||
| # item is a slice or int | |||
| for k, v in self.items(): | |||
| if v is None: | |||
| new_data[k] = None | |||
| else: | |||
| new_data[k] = v[item] | |||
| return new_data # type:ignore | |||
| @staticmethod | |||
| def cat(instances_list: List["ListData"]) -> "ListData": | |||
| """Concat the instances of all :obj:`ListData` in the list. | |||
| Note: To ensure that cat returns as expected, make sure that | |||
| all elements in the list must have exactly the same keys. | |||
| Args: | |||
| instances_list (list[:obj:`ListData`]): A list | |||
| of :obj:`ListData`. | |||
| Returns: | |||
| :obj:`ListData` | |||
| """ | |||
| assert all(isinstance(results, ListData) for results in instances_list) | |||
| assert len(instances_list) > 0 | |||
| if len(instances_list) == 1: | |||
| return instances_list[0] | |||
| # metainfo and data_fields must be exactly the | |||
| # same for each element to avoid exceptions. | |||
| field_keys_list = [instances.all_keys() for instances in instances_list] | |||
| assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( | |||
| set(itertools.chain(*field_keys_list)) | |||
| ) == len(field_keys_list[0]), ( | |||
| "There are different keys in " | |||
| "`instances_list`, which may " | |||
| "cause the cat operation " | |||
| "to fail. Please make sure all " | |||
| "elements in `instances_list` " | |||
| "have the exact same key." | |||
| ) | |||
| new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) | |||
| for k in instances_list[0].keys(): | |||
| values = [results[k] for results in instances_list] | |||
| v0 = values[0] | |||
| if isinstance(v0, torch.Tensor): | |||
| new_values = torch.cat(values, dim=0) | |||
| elif isinstance(v0, np.ndarray): | |||
| new_values = np.concatenate(values, axis=0) | |||
| elif isinstance(v0, (str, list, tuple)): | |||
| new_values = v0[:] | |||
| for v in values[1:]: | |||
| new_values += v | |||
| elif hasattr(v0, "cat"): | |||
| new_values = v0.cat(values) | |||
| else: | |||
| raise ValueError( | |||
| f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`" | |||
| ) | |||
| new_data[k] = new_values | |||
| return new_data # type:ignore | |||
| def flatten(self, item: IndexType) -> List: | |||
| """Flatten self[item]. | |||
| Returns: | |||
| list: Flattened data fields. | |||
| """ | |||
| return flatten_list(self[item]) | |||
| def elements_num(self, item: IndexType) -> int: | |||
| """int: The number of elements in self[item].""" | |||
| return len(self.flatten(item)) | |||
| def to_tuple(self, item: IndexType) -> tuple: | |||
| """tuple: The data fields in self[item] converted to tuple.""" | |||
| return to_hashable(self[item]) | |||
def __len__(self) -> int:
    """int: The length of ListData.

    The length is taken from one arbitrary data field (the first one
    yielded when iterating ``self._data_fields``); with no data fields
    the length is 0.
    """
    if len(self._data_fields) == 0:
        return 0
    first_field = next(iter(self._data_fields))
    return len(getattr(self, first_field))
| @@ -1,2 +1,3 @@ | |||
| from .cache import Cache, abl_cache | |||
| from .logger import ABLLogger, print_log | |||
| from .utils import * | |||
| from .utils import * | |||
| @@ -0,0 +1,104 @@ | |||
| import pickle | |||
| import os | |||
| import os.path as osp | |||
| from typing import Callable, Generic, TypeVar | |||
| from .logger import print_log, ABLLogger | |||
K = TypeVar("K")
T = TypeVar("T")

# Slot indexes inside one link ``[PREV, NEXT, KEY, RESULT]`` of the circular
# doubly linked list that records how recently each entry was used.
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # names for the link fields


class Cache(Generic[K, T]):
    """In-memory least-recently-used cache around ``func``.

    Entries live in ``cache_dict`` (key -> link). The links form a circular
    doubly linked list anchored at ``root``: ``root[NEXT]`` is the oldest
    entry and ``root[PREV]`` the most recently used one. Once the number of
    entries reaches ``max_size`` the oldest entry is evicted on every miss.
    """

    def __init__(self, func: Callable[[K], T]):
        """Create cache

        Storage is not allocated here; ``_init_cache`` sets it up from the
        first object the cache is used with.

        :param func: Function this cache evaluates
        """
        self.func = func
        self.has_init = False

    def __getitem__(self, obj, *args) -> T:
        """Delegate to :meth:`get_from_dict`."""
        return self.get_from_dict(obj, *args)

    def clear_cache(self):
        """Invalidate entire cache.

        Safe to call before the cache has been initialised (it is then a
        no-op). The ``hits``/``misses`` counters are left untouched.
        """
        if not self.has_init:
            return
        self.cache_dict.clear()
        # The recency list and the ``full`` flag must be reset together with
        # the dict; otherwise the next eviction would try to delete a key
        # that is no longer stored and raise KeyError.
        self._reset_links()

    def _reset_links(self):
        """Make the recency list empty and mark the cache as not full."""
        self.full = False
        self.root = []  # root of the circular doubly linked list
        self.root[:] = [self.root, self.root, None, None]

    def _init_cache(self, obj):
        """Allocate storage on first use, configured from ``obj``.

        Reads ``obj.key_func`` (turns an argument into a hashable key part)
        and ``obj.max_cache_size`` (entry limit). Later calls are no-ops, so
        the settings of the first object stay in effect.
        """
        if self.has_init:
            return
        self.cache = True
        self.cache_dict = dict()
        self.key_func = obj.key_func
        self.max_size = obj.max_cache_size
        self.hits, self.misses = 0, 0
        self._reset_links()
        self.has_init = True

    def get_from_dict(self, obj, *args) -> T:
        """Implements dict based cache.

        ``args`` must start with ``pred_pseudo_label`` and ``y``; both go
        through ``key_func`` and are combined with any remaining arguments
        (which therefore must be hashable) to form the cache key. On a miss
        ``func(obj, *args)`` is evaluated and its result stored.
        """
        # No-op once initialised; lets the cache be used directly (e.g. via
        # ``__getitem__``) without a prior explicit ``_init_cache`` call.
        self._init_cache(obj)

        pred_pseudo_label, y, *res_args = args
        cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args)
        link = self.cache_dict.get(cache_key)
        if link is not None:
            # Move the link to the front of the circular queue
            link_prev, link_next, _key, result = link
            link_prev[NEXT] = link_next
            link_next[PREV] = link_prev
            last = self.root[PREV]
            last[NEXT] = self.root[PREV] = link
            link[PREV] = last
            link[NEXT] = self.root
            self.hits += 1
            return result

        self.misses += 1
        result = self.func(obj, *args)
        if self.full:
            # Use the old root to store the new key and result.
            oldroot = self.root
            oldroot[KEY] = cache_key
            oldroot[RESULT] = result
            # Empty the oldest link and make it the new root.
            self.root = oldroot[NEXT]
            oldkey = self.root[KEY]
            self.root[KEY] = self.root[RESULT] = None
            # Now update the cache dictionary.
            del self.cache_dict[oldkey]
            self.cache_dict[cache_key] = oldroot
        else:
            # Put result in a new link at the front of the queue.
            last = self.root[PREV]
            link = [last, self.root, cache_key, result]
            last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
            # A non-int ``max_size`` (e.g. None) means the cache never fills.
            if isinstance(self.max_size, int):
                self.full = len(self.cache_dict) >= self.max_size
        return result
def abl_cache():
    """Build a decorator that routes a method through a ``Cache``.

    The decorated callable is invoked as ``wrapper(obj, *args)``. When
    ``obj.use_cache`` is truthy the call is answered by the cache
    (evaluating the original function on a miss); otherwise the original
    function is called directly.

    One ``Cache`` is created per decorated function, so it is shared by
    every object the function is called on.
    NOTE(review): the cache key is built from the arguments only, not from
    ``obj`` - confirm that sharing results between objects is intended.

    Returns:
        Callable: The decorator to apply to the function.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    def decorator(func):
        cache_instance = Cache(func)

        # Preserve the wrapped function's name, docstring and attributes.
        @functools.wraps(func)
        def wrapper(obj, *args):
            if obj.use_cache:
                # No-op after the first call; configures the cache from obj.
                cache_instance._init_cache(obj)
                return cache_instance.get_from_dict(obj, *args)
            return func(obj, *args)

        return wrapper

    return decorator
| @@ -1,6 +1,7 @@ | |||
| import numpy as np | |||
| from itertools import chain | |||
| import numpy as np | |||
| def flatten(nested_list): | |||
| """ | |||
| @@ -15,6 +16,11 @@ def flatten(nested_list): | |||
| ------- | |||
| list | |||
| A flattened version of the input list. | |||
| Raises | |||
| ------ | |||
| TypeError | |||
| If the input object is not a list. | |||
| """ | |||
| if not isinstance(nested_list, list): | |||
| raise TypeError("Input must be of type list.") | |||
| @@ -25,7 +31,7 @@ def flatten(nested_list): | |||
| return list(chain.from_iterable(nested_list)) | |||
| def reform_idx(flattened_list, structured_list): | |||
| def reform_list(flattened_list, structured_list): | |||
| """ | |||
| Reform the index based on structured_list structure. | |||
| @@ -41,6 +47,9 @@ def reform_idx(flattened_list, structured_list): | |||
| list | |||
| A reformed list that mimics the structure of structured_list. | |||
| """ | |||
| # if not isinstance(flattened_list, list): | |||
| # raise TypeError("Input must be of type list.") | |||
| if not isinstance(structured_list[0], (list, tuple)): | |||
| return flattened_list | |||
| @@ -80,7 +89,7 @@ def hamming_dist(pred_pseudo_label, candidates): | |||
| return np.sum(pred_pseudo_label != candidates, axis=1) | |||
| def confidence_dist(pred_prob, candidates_idx): | |||
| def confidence_dist(pred_prob, candidates): | |||
| """ | |||
| Compute the confidence distance between prediction probabilities and candidates. | |||
| @@ -89,7 +98,7 @@ def confidence_dist(pred_prob, candidates_idx): | |||
| pred_prob : list of numpy.ndarray | |||
| Prediction probability distributions, each element is an ndarray | |||
| representing the probability distribution of a particular prediction. | |||
| candidates_idx : list of list of int | |||
| candidates : list of list of int | |||
| Index of candidate labels, each element is a list of indexes being considered | |||
| as a candidate correction. | |||
| @@ -99,8 +108,8 @@ def confidence_dist(pred_prob, candidates_idx): | |||
| Confidence distances computed for each candidate. | |||
| """ | |||
| pred_prob = np.clip(pred_prob, 1e-9, 1) | |||
| _, cols = np.indices((len(candidates_idx), len(candidates_idx[0]))) | |||
| return 1 - np.prod(pred_prob[cols, candidates_idx], axis=1) | |||
| _, cols = np.indices((len(candidates), len(candidates[0]))) | |||
| return 1 - np.prod(pred_prob[cols, candidates], axis=1) | |||
| def block_sample(X, Z, Y, sample_num, seg_idx): | |||
| @@ -154,6 +163,7 @@ def to_hashable(x): | |||
| return tuple(to_hashable(item) for item in x) | |||
| return x | |||
| def restore_from_hashable(x): | |||
| """ | |||
| Convert a nested tuple back to a nested list. | |||
| @@ -170,10 +180,49 @@ def restore_from_hashable(x): | |||
| otherwise the original input. | |||
| """ | |||
| if isinstance(x, tuple): | |||
| return [hashable_to_list(item) for item in x] | |||
| return [restore_from_hashable(item) for item in x] | |||
| return x | |||
def calculate_revision_num(parameter, total_length):
    """
    Convert a float parameter to an integer, based on a total length.

    Parameters
    ----------
    parameter : int or float
        The parameter to convert. If float, it should be between 0 and 1.
        If int, it should be non-negative. If -1, it will be replaced with
        total_length. Integer and floating-point scalars registered with the
        ``numbers`` ABCs (for example NumPy scalars) are accepted as well.
    total_length : int
        The total length to calculate the parameter from if it's a fraction.

    Returns
    -------
    int
        The calculated parameter.

    Raises
    ------
    TypeError
        If parameter is not an int or a float.
    ValueError
        If parameter is a float not in [0, 1] or an int below 0.
    """
    # Function-scope import keeps this block self-contained.
    import numbers

    if not isinstance(parameter, numbers.Real):
        raise TypeError("Parameter must be of type int or float.")

    # -1 is the sentinel for "use the whole length".
    if parameter == -1:
        return total_length

    if isinstance(parameter, numbers.Integral):
        if parameter < 0:
            raise ValueError("If parameter is an int, it must be non-negative.")
        # Normalise integer-like scalars to a plain int.
        return int(parameter)

    if not (0 <= parameter <= 1):
        raise ValueError("If parameter is a float, it must be between 0 and 1.")
    # int() guarantees a plain int even for non-builtin float scalars.
    return int(round(total_length * parameter))
| if __name__ == "__main__": | |||
| A = np.array( | |||
| [ | |||
| @@ -227,4 +276,5 @@ if __name__ == "__main__": | |||
| ) | |||
| B = [[0, 9, 3], [0, 11, 4]] | |||
| print(ori_confidence_dist(A, B)) | |||
| print(confidence_dist(A, B)) | |||
| @@ -19,17 +19,17 @@ class HEDBridge(SimpleBridge): | |||
| def __init__( | |||
| self, | |||
| model: ABLModel, | |||
| abducer: ReasonerBase, | |||
| reasoner: ReasonerBase, | |||
| metric_list: BaseMetric, | |||
| ) -> None: | |||
| super().__init__(model, abducer, metric_list) | |||
| super().__init__(model, reasoner, metric_list) | |||
| def pretrain(self, weights_dir): | |||
| if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")): | |||
| print_log("Pretrain Start", logger="current") | |||
| cls_autoencoder = SymbolNetAutoencoder( | |||
| num_classes=len(self.abducer.kb.pseudo_label_list) | |||
| num_classes=len(self.reasoner.kb.pseudo_label_list) | |||
| ) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| criterion = torch.nn.MSELoss() | |||
| @@ -74,7 +74,7 @@ class HEDBridge(SimpleBridge): | |||
| max_revision=-1, | |||
| require_more_revision=0, | |||
| ): | |||
| return self.abducer.abduce( | |||
| return self.reasoner.abduce( | |||
| (pred_label, pred_prob, pseudo_label, Y), | |||
| max_revision, | |||
| require_more_revision, | |||
| @@ -86,8 +86,8 @@ class HEDBridge(SimpleBridge): | |||
| pred_pseudo_label_list = [] | |||
| abduced_pseudo_label_list = [] | |||
| for _mapping in candidate_mappings: | |||
| self.abducer.mapping = _mapping | |||
| self.abducer.set_remapping() | |||
| self.reasoner.mapping = _mapping | |||
| self.reasoner.set_remapping() | |||
| pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| abduced_pseudo_label = self.abduce_pseudo_label( | |||
| pred_label, pred_prob, pred_pseudo_label, Y, 20 | |||
| @@ -100,8 +100,8 @@ class HEDBridge(SimpleBridge): | |||
| max_revisible_instances = max(mapping_score) | |||
| return_idx = mapping_score.index(max_revisible_instances) | |||
| self.abducer.mapping = candidate_mappings[return_idx] | |||
| self.abducer.set_remapping() | |||
| self.reasoner.mapping = candidate_mappings[return_idx] | |||
| self.reasoner.set_remapping() | |||
| return abduced_pseudo_label_list[return_idx] | |||
| def check_training_impact(self, filtered_X, filtered_abduced_label, X): | |||
| @@ -137,7 +137,7 @@ class HEDBridge(SimpleBridge): | |||
| pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| consistent_num = sum( | |||
| [ | |||
| self.abducer.kb.consist_rule(instance, rule) | |||
| self.reasoner.kb.consist_rule(instance, rule) | |||
| for instance in pred_pseudo_label | |||
| ] | |||
| ) | |||
| @@ -159,11 +159,11 @@ class HEDBridge(SimpleBridge): | |||
| pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| consistent_instance = [] | |||
| for instance in pred_pseudo_label: | |||
| if self.abducer.kb.logic_forward([instance]): | |||
| if self.reasoner.kb.logic_forward([instance]): | |||
| consistent_instance.append(instance) | |||
| if len(consistent_instance) != 0: | |||
| rule = self.abducer.abduce_rules(consistent_instance) | |||
| rule = self.reasoner.abduce_rules(consistent_instance) | |||
| if rule != None: | |||
| rules.append(rule) | |||
| break | |||
| @@ -280,7 +280,7 @@ class HEDBridge(SimpleBridge): | |||
| else: | |||
| if equation_len == min_len: | |||
| print_log( | |||
| "Learned mapping is: " + str(self.abducer.mapping), | |||
| "Learned mapping is: " + str(self.reasoner.mapping), | |||
| logger="current", | |||
| ) | |||
| self.model.load(load_path="./weights/pretrain_weights.pth") | |||
| @@ -2,10 +2,14 @@ | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 1, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import sys\n", | |||
| "\n", | |||
| "sys.setrecursionlimit(10000)\n", | |||
| "\n", | |||
| "import torch\n", | |||
| "import numpy as np\n", | |||
| "import torch.nn as nn\n", | |||
| @@ -23,9 +27,17 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 2, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "11/16 20:43:38 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the HWF example.\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# Initialize logger and print basic information\n", | |||
| "print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", | |||
| @@ -45,21 +57,12 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 3, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "# Initialize knowledge base and reasoner\n", | |||
| "class HWF_KB(KBBase):\n", | |||
| " def __init__(\n", | |||
| " self, \n", | |||
| " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", | |||
| " prebuild_GKB=False,\n", | |||
| " GKB_len_list=[1, 3, 5, 7],\n", | |||
| " max_err=1e-3,\n", | |||
| " use_cache=True\n", | |||
| " ):\n", | |||
| " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||
| "\n", | |||
| " def _valid_candidate(self, formula):\n", | |||
| " if len(formula) % 2 == 0:\n", | |||
| @@ -79,8 +82,8 @@ | |||
| " formula = [mapping[f] for f in formula]\n", | |||
| " return eval(''.join(formula))\n", | |||
| "\n", | |||
| "kb = HWF_KB(prebuild_GKB=True)\n", | |||
| "abducer = ReasonerBase(kb, dist_func='confidence')" | |||
| "kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n", | |||
| "reasoner = ReasonerBase(kb, dist_func='confidence')" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -93,7 +96,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 4, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -106,7 +109,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -126,7 +129,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -146,12 +149,12 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 7, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Add metric\n", | |||
| "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" | |||
| "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -164,7 +167,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 8, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -183,11 +186,11 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 9, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)" | |||
| "bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -200,11 +203,123 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 10, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [1/10] model loss is 0.16911\n", | |||
| "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [2/10] model loss is 0.17734\n", | |||
| "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [3/10] model loss is 0.01907\n", | |||
| "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [4/10] model loss is 0.01403\n", | |||
| "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [5/10] model loss is 0.00509\n", | |||
| "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [6/10] model loss is 0.00713\n", | |||
| "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [7/10] model loss is 0.00455\n", | |||
| "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [8/10] model loss is 0.00946\n", | |||
| "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [9/10] model loss is 0.00957\n", | |||
| "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [10/10] model loss is 0.00323\n", | |||
| "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.997 hwf/semantics_accuracy: 0.985 \n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_1.pth\n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [1/10] model loss is 0.00666\n", | |||
| "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [2/10] model loss is 0.01438\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [3/10] model loss is 0.00450\n", | |||
| "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [4/10] model loss is 0.00764\n", | |||
| "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [5/10] model loss is 0.00644\n", | |||
| "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [6/10] model loss is 0.00189\n", | |||
| "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [7/10] model loss is 0.00397\n", | |||
| "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [8/10] model loss is 0.00936\n", | |||
| "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [9/10] model loss is 0.00960\n", | |||
| "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [10/10] model loss is 0.00572\n", | |||
| "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.999 hwf/semantics_accuracy: 0.995 \n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_2.pth\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [1/10] model loss is 0.00180\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [2/10] model loss is 0.00615\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [3/10] model loss is 0.01000\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [4/10] model loss is 0.00415\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [5/10] model loss is 0.00960\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [6/10] model loss is 0.00697\n", | |||
| "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [7/10] model loss is 0.00977\n", | |||
| "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [8/10] model loss is 0.00734\n", | |||
| "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [9/10] model loss is 0.00922\n", | |||
| "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", | |||
| "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [10/10] model loss is 0.00982\n", | |||
| "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n", | |||
| "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.998 hwf/semantics_accuracy: 0.986 \n", | |||
| "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n", | |||
| "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_3.pth\n", | |||
| "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.994 hwf/semantics_accuracy: 0.970 \n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "bridge.train(train_data, epochs=3, batch_size=1000)\n", | |||
| "bridge.train(train_data, loops=3, segment_size=1000, save_interval=1, save_dir=weights_dir)\n", | |||
| "bridge.test(test_data)" | |||
| ] | |||
| } | |||
| @@ -2,10 +2,12 @@ | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 3, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import os.path as osp\n", | |||
| "\n", | |||
| "import torch.nn as nn\n", | |||
| "import torch\n", | |||
| "\n", | |||
| @@ -13,21 +15,25 @@ | |||
| "\n", | |||
| "from abl.learning import BasicNN, ABLModel\n", | |||
| "from abl.bridge import SimpleBridge\n", | |||
| "from abl.evaluation import SymbolMetric, ABLMetric\n", | |||
| "from abl.utils import ABLLogger\n", | |||
| "from abl.evaluation import SymbolMetric, SemanticsMetric\n", | |||
| "from abl.utils import ABLLogger, print_log\n", | |||
| "\n", | |||
| "from models.nn import LeNet5\n", | |||
| "from examples.models.nn import LeNet5\n", | |||
| "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 4, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize logger\n", | |||
| "logger = ABLLogger.get_instance(\"abl\")" | |||
| "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", | |||
| "\n", | |||
| "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||
| "log_dir = ABLLogger.get_current_instance().log_dir\n", | |||
| "weights_dir = osp.join(log_dir, \"weights\")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -40,22 +46,19 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "# Initialize knowledge base and reasoner\n", | |||
| "class add_KB(KBBase):\n", | |||
| " def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", | |||
| " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||
| "\n", | |||
| " def logic_forward(self, nums):\n", | |||
| " return sum(nums)\n", | |||
| "\n", | |||
| "kb = add_KB(prebuild_GKB=True)\n", | |||
| "kb = add_KB(pseudo_label_list=list(range(10)))\n", | |||
| "\n", | |||
| "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", | |||
| "abducer = ReasonerBase(kb, dist_func=\"confidence\")" | |||
| "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -68,7 +71,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 6, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -81,7 +84,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 7, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -92,8 +95,6 @@ | |||
| " criterion,\n", | |||
| " optimizer,\n", | |||
| " device,\n", | |||
| " save_interval=1,\n", | |||
| " save_dir=logger.save_dir,\n", | |||
| " batch_size=32,\n", | |||
| " num_epochs=1,\n", | |||
| ")" | |||
| @@ -109,7 +110,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 8, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -129,12 +130,12 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 9, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Add metric\n", | |||
| "metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" | |||
| "metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -147,7 +148,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 10, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -170,7 +171,7 @@ | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "bridge = SimpleBridge(model, abducer, metric)" | |||
| "bridge = SimpleBridge(model, reasoner, metric)" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -187,7 +188,7 @@ | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "bridge.train(train_data, epochs=5, batch_size=10000)\n", | |||
| "bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n", | |||
| "bridge.test(test_data)" | |||
| ] | |||
| } | |||
| @@ -208,7 +209,7 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.13" | |||
| "version": "3.8.16" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||