From 6a63f5e5f3de98119450a280289f092c068e40a1 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Tue, 14 Nov 2023 20:23:04 +0800 Subject: [PATCH 01/18] [ENH] add abstract data interface --- abl/bridge/base_bridge.py | 40 +- abl/bridge/simple_bridge.py | 151 ++++--- abl/dataset/__init__.py | 3 +- abl/dataset/bridge_dataset.py | 3 +- abl/dataset/classification_dataset.py | 3 +- abl/dataset/prediction_dataset.py | 56 +++ abl/dataset/regression_dataset.py | 3 +- abl/evaluation/__init__.py | 2 +- abl/evaluation/base_metric.py | 4 +- abl/evaluation/semantics_metric.py | 19 +- abl/evaluation/symbol_metric.py | 3 +- abl/learning/abl_model.py | 99 ++-- abl/learning/basic_nn.py | 80 +++- abl/structures/__init__.py | 2 + abl/structures/base_data_element.py | 629 ++++++++++++++++++++++++++ abl/structures/list_data.py | 321 +++++++++++++ abl/utils/__init__.py | 3 +- abl/utils/cache.py | 112 +++++ abl/utils/utils.py | 90 +++- 19 files changed, 1427 insertions(+), 196 deletions(-) create mode 100644 abl/dataset/prediction_dataset.py create mode 100644 abl/structures/__init__.py create mode 100644 abl/structures/base_data_element.py create mode 100644 abl/structures/list_data.py create mode 100644 abl/utils/cache.py diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index 03054f7..869ea39 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -1,52 +1,64 @@ from abc import ABCMeta, abstractmethod -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple, Union from ..learning import ABLModel from ..reasoning import ReasonerBase +from ..structures import ListData +DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] -class BaseBridge(metaclass=ABCMeta): +class BaseBridge(metaclass=ABCMeta): def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: if not isinstance(model, ABLModel): - raise TypeError("Expected an ABLModel") + raise TypeError( + "Expected an instance of ABLModel, but received 
type: {}".format( + type(model) + ) + ) if not isinstance(abducer, ReasonerBase): - raise TypeError("Expected an ReasonerBase") - + raise TypeError( + "Expected an instance of ReasonerBase, but received type: {}".format( + type(abducer) + ) + ) + self.model = model self.abducer = abducer @abstractmethod - def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: + def predict( + self, data_samples: ListData + ) -> Tuple[List[List[Any]], List[List[Any]]]: """Placeholder for predict labels from input.""" pass @abstractmethod - def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for abduce pseudo labels.""" + pass @abstractmethod - def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: + def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for map label space to symbol space.""" pass @abstractmethod - def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for map symbol space to label space.""" pass - + @abstractmethod - def train(self, train_data): + def train(self, train_data: Union[ListData, DataSet]): """Placeholder for train loop of ABductive Learning.""" pass @abstractmethod - def test(self, test_data): + def valid(self, valid_data: Union[ListData, DataSet]) -> None: """Placeholder for model test.""" pass @abstractmethod - def valid(self, valid_data): + def test(self, test_data: Union[ListData, DataSet]) -> None: """Placeholder for model validation.""" pass - \ No newline at end of file diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 3aecffd..9093bc1 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -1,13 +1,14 @@ -from ..learning import ABLModel -from ..reasoning import ReasonerBase -from 
..evaluation import BaseMetric -from .base_bridge import BaseBridge -from typing import List, Union, Any, Tuple, Dict, Optional +import os.path as osp +from typing import Any, Dict, List, Optional, Tuple, Union + from numpy import ndarray -from torch.utils.data import DataLoader -from ..dataset import BridgeDataset -from ..utils.logger import print_log +from ..evaluation import BaseMetric +from ..learning import ABLModel +from ..reasoning import ReasonerBase +from ..structures import ListData +from ..utils import print_log +from .base_bridge import BaseBridge, DataSet class SimpleBridge(BaseBridge): @@ -20,85 +21,99 @@ class SimpleBridge(BaseBridge): super().__init__(model, abducer) self.metric_list = metric_list - def predict(self, X) -> Tuple[List[List[Any]], ndarray]: - pred_res = self.model.predict(X) - pred_idx, pred_prob = pred_res["label"], pred_res["prob"] - return pred_idx, pred_prob - + # TODO: add abducer.mapping to the property of SimpleBridge + + def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: + self.model.predict(data_samples) + return data_samples["pred_idx"], data_samples.get("pred_prob", None) + def abduce_pseudo_label( self, - pred_prob: ndarray, - pred_pseudo_label: List[List[Any]], - Y: List[Any], + data_samples: ListData, max_revision: int = -1, require_more_revision: int = 0, ) -> List[List[Any]]: - return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) + self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) + return data_samples["abduced_pseudo_label"] def idx_to_pseudo_label( - self, idx: List[List[Any]], mapping: Dict = None + self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.mapping - return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] + pred_idx = data_samples.pred_idx + data_samples.pred_pseudo_label = [ + [mapping[_idx] for _idx in sub_list] for 
sub_list in pred_idx + ] + return data_samples["pred_pseudo_label"] def pseudo_label_to_idx( - self, pseudo_label: List[List[Any]], mapping: Dict = None + self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.remapping - return [ - [mapping[_pseudo_label] for _pseudo_label in sub_list] - for sub_list in pseudo_label + abduced_idx = [ + [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] + for sub_list in data_samples.abduced_pseudo_label ] + data_samples.abduced_idx = abduced_idx + return data_samples["abduced_idx"] + + def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: + data_samples = ListData() + + data_samples.X = X + data_samples.gt_pseudo_label = gt_pseudo_label + data_samples.Y = Y + + return data_samples def train( self, - train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], - epochs: int = 50, - batch_size: Union[int, float] = -1, + train_data: Union[ListData, DataSet], + loops: int = 50, + segment_size: Union[int, float] = -1, eval_interval: int = 1, + save_interval: Optional[int] = None, + save_dir: Optional[str] = None, ): - dataset = BridgeDataset(*train_data) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) - - for epoch in range(epochs): - for seg_idx, (X, Z, Y) in enumerate(data_loader): - pred_idx, pred_prob = self.predict(X) - pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) - abduced_pseudo_label = self.abduce_pseudo_label( - pred_prob, pred_pseudo_label, Y - ) - abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) - loss = self.model.train(X, abduced_label) + if isinstance(train_data, ListData): + data_samples = train_data + else: + data_samples = self.data_preprocess(*train_data) + + for loop in range(loops): + for seg_idx in range((len(data_samples) - 1) // segment_size + 1): + 
sub_data_samples = data_samples[ + seg_idx * segment_size : (seg_idx + 1) * segment_size + ] + self.predict(sub_data_samples) + self.idx_to_pseudo_label(sub_data_samples) + self.abduce_pseudo_label(sub_data_samples) + self.pseudo_label_to_idx(sub_data_samples) + loss = self.model.train(sub_data_samples) print_log( - f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", + f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", logger="current", ) - if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: - print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") + if (loop + 1) % eval_interval == 0 or loop == loops - 1: + print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") self.valid(train_data) - def _valid(self, data_loader): - for X, Z, Y in data_loader: - pred_idx, pred_prob = self.predict(X) - pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) - data_samples = dict( - pred_idx=pred_idx, - pred_prob=pred_prob, - pred_pseudo_label=pred_pseudo_label, - gt_pseudo_label=Z, - Y=Y, - logic_forward=self.abducer.kb.logic_forward, - ) + if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): + print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") + self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) + + def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: + for seg_idx in range((len(data_samples) - 1) // batch_size + 1): + sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] + self.predict(sub_data_samples) + self.idx_to_pseudo_label(sub_data_samples) + for metric in self.metric_list: - metric.process(data_samples) + metric.process(sub_data_samples) res = dict() for metric in self.metric_list: @@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge): msg += k + f": {v:.3f} " 
print_log(msg, logger="current") - def valid(self, valid_data, batch_size=1000): - dataset = BridgeDataset(*valid_data) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) - self._valid(data_loader) - - def test(self, test_data, batch_size=1000): - self.valid(test_data, batch_size) + def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None: + if not isinstance(valid_data, ListData): + data_samples = self.data_preprocess(*valid_data) + else: + data_samples = valid_data + self._valid(data_samples, batch_size=batch_size) + + def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None: + self.valid(test_data, batch_size=batch_size) diff --git a/abl/dataset/__init__.py b/abl/dataset/__init__.py index 6be0df1..a487476 100644 --- a/abl/dataset/__init__.py +++ b/abl/dataset/__init__.py @@ -1,3 +1,4 @@ from .bridge_dataset import BridgeDataset from .classification_dataset import ClassificationDataset -from .regression_dataset import RegressionDataset \ No newline at end of file +from .prediction_dataset import PredictionDataset +from .regression_dataset import RegressionDataset diff --git a/abl/dataset/bridge_dataset.py b/abl/dataset/bridge_dataset.py index bb0ce98..a7d32c5 100644 --- a/abl/dataset/bridge_dataset.py +++ b/abl/dataset/bridge_dataset.py @@ -1,5 +1,6 @@ +from typing import Any, List, Tuple + from torch.utils.data import Dataset -from typing import List, Any, Tuple class BridgeDataset(Dataset): diff --git a/abl/dataset/classification_dataset.py b/abl/dataset/classification_dataset.py index 28f9299..1663642 100644 --- a/abl/dataset/classification_dataset.py +++ b/abl/dataset/classification_dataset.py @@ -1,6 +1,7 @@ +from typing import Any, Callable, List, Tuple + import torch from torch.utils.data import Dataset -from typing import List, Any, Tuple, Callable class ClassificationDataset(Dataset): diff --git 
a/abl/dataset/prediction_dataset.py b/abl/dataset/prediction_dataset.py new file mode 100644 index 0000000..8e3c717 --- /dev/null +++ b/abl/dataset/prediction_dataset.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, List, Tuple + +import torch +from torch.utils.data import Dataset + + +class PredictionDataset(Dataset): + def __init__(self, X: List[Any], transform: Callable[..., Any] = None): + """ + Initialize the dataset used for classification task. + + Parameters + ---------- + X : List[Any] + The input data. + transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version. Defaults to None. + """ + if not isinstance(X, list): + raise ValueError("X should be of type list.") + + self.X = X + self.transform = transform + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: + """ + Get the item at the given index. + + Parameters + ---------- + index : int + The index of the item to get. + + Returns + ------- + Tuple[Any, torch.Tensor] + A tuple containing the object and its label. 
+ """ + if index >= len(self): + raise ValueError("index range error") + + x = self.X[index] + if self.transform is not None: + x = self.transform(x) + return x diff --git a/abl/dataset/regression_dataset.py b/abl/dataset/regression_dataset.py index 8cf136c..118ac65 100644 --- a/abl/dataset/regression_dataset.py +++ b/abl/dataset/regression_dataset.py @@ -1,6 +1,7 @@ +from typing import Any, List, Tuple + import torch from torch.utils.data import Dataset -from typing import List, Any, Tuple class RegressionDataset(Dataset): diff --git a/abl/evaluation/__init__.py b/abl/evaluation/__init__.py index a849d68..3106412 100644 --- a/abl/evaluation/__init__.py +++ b/abl/evaluation/__init__.py @@ -1,3 +1,3 @@ from .base_metric import BaseMetric -from .symbol_metric import SymbolMetric from .semantics_metric import SemanticsMetric +from .symbol_metric import SymbolMetric diff --git a/abl/evaluation/base_metric.py b/abl/evaluation/base_metric.py index 44364f8..e18f452 100644 --- a/abl/evaluation/base_metric.py +++ b/abl/evaluation/base_metric.py @@ -1,8 +1,8 @@ +import logging from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Sequence -from ..utils import print_log -import logging +from ..utils import print_log class BaseMetric(metaclass=ABCMeta): diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 3333daf..718cfea 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,25 +1,22 @@ from typing import Optional, Sequence + +from ..reasoning import BaseKB from .base_metric import BaseMetric -class ABLMetric(): - pass class SemanticsMetric(BaseMetric): - def __init__(self, prefix: Optional[str] = None) -> None: + def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None: super().__init__(prefix) + self.kb = kb def process(self, data_samples: Sequence[dict]) -> None: - pred_pseudo_label = data_samples["pred_pseudo_label"] - gt_Y = data_samples["Y"] - 
logic_forward = data_samples["logic_forward"] - - for pred_z, y in zip(pred_pseudo_label, gt_Y): - if logic_forward(pred_z) == y: + for data_sample in data_samples: + if self.kb.check_equal(data_sample, data_sample["Y"][0]): self.results.append(1) else: self.results.append(0) - + def compute_metrics(self, results: list) -> dict: metrics = dict() metrics["semantics_accuracy"] = sum(results) / len(results) - return metrics \ No newline at end of file + return metrics diff --git a/abl/evaluation/symbol_metric.py b/abl/evaluation/symbol_metric.py index 3c0c216..c2d7938 100644 --- a/abl/evaluation/symbol_metric.py +++ b/abl/evaluation/symbol_metric.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, Callable +from typing import Optional, Sequence + from .base_metric import BaseMetric diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index f1512fe..6685cc4 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -10,8 +10,10 @@ # # ================================================================# import pickle -from utils import flatten, reform_idx -from typing import List, Any, Optional +from typing import Any, Dict + +from ..structures import ListData +from ..utils import reform_idx class ABLModel: @@ -30,7 +32,7 @@ class ABLModel: Methods ------- - predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict + predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict Predict the labels and probabilities for the given data. valid(X: List[List[Any]], Y: List[Any]) -> float Calculate the accuracy score for the given data. @@ -42,20 +44,13 @@ class ABLModel: Load the model from a file. 
""" - def __init__(self, base_model) -> None: - self.classifier_list = [] - self.classifier_list.append(base_model) + def __init__(self, base_model: Any) -> None: + if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): + raise NotImplementedError("The base_model should implement fit and predict methods.") - if not ( - hasattr(base_model, "fit") - and hasattr(base_model, "predict") - and hasattr(base_model, "score") - ): - raise NotImplementedError( - "base_model should have fit, predict and score methods." - ) + self.base_model = base_model - def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: + def predict(self, data_samples: ListData) -> Dict: """ Predict the labels and probabilities for the given data. @@ -63,53 +58,30 @@ class ABLModel: ---------- X : List[List[Any]] The data to predict on. - mapping : Optional[dict], optional - A mapping dictionary to map labels to their original values, by default None. Returns ------- dict A dictionary containing the predicted labels and probabilities. """ - model = self.classifier_list[0] - data_X = flatten(X) + model = self.base_model + data_X = data_samples.flatten("X") if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) label = prob.argmax(axis=1) - prob = reform_idx(prob, X) + prob = reform_idx(prob, data_samples["X"]) else: prob = None label = model.predict(X=data_X) + label = reform_idx(label, data_samples["X"]) - if mapping is not None: - label = [mapping[y] for y in label] - - label = reform_idx(label, X) + data_samples.pred_idx = label + if prob is not None: + data_samples.pred_prob = prob return {"label": label, "prob": prob} - def valid(self, X: List[List[Any]], Y: List[Any]) -> float: - """ - Calculate the accuracy for the given data. - - Parameters - ---------- - X : List[List[Any]] - The data to calculate the accuracy on. - Y : List[Any] - The true labels for the given data. - - Returns - ------- - float - The accuracy score for the given data. 
- """ - data_X = flatten(X) - data_Y = flatten(Y) - score = self.classifier_list[0].score(X=data_X, y=data_Y) - return score - - def train(self, X: List[List[Any]], Y: List[Any]) -> float: + def train(self, data_samples: ListData) -> float: """ Train the model on the given data. @@ -125,29 +97,30 @@ class ABLModel: float The loss value of the trained model. """ - data_X = flatten(X) - data_Y = flatten(Y) - return self.classifier_list[0].fit(X=data_X, y=data_Y) + data_X = data_samples.flatten("X") + data_y = data_samples.flatten("abduced_idx") + return self.base_model.fit(X=data_X, y=data_y) def _model_operation(self, operation: str, *args, **kwargs): - model = self.classifier_list[0] + model = self.base_model if hasattr(model, operation): method = getattr(model, operation) method(*args, **kwargs) else: - try: - if not f"{operation}_path" in kwargs.keys(): - raise ValueError(f"'{operation}_path' should not be None") - if operation == "save": - with open(kwargs["save_path"], 'wb') as file: - pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) - elif operation == "load": - with open(kwargs["load_path"], 'rb') as file: - self.classifier_list[0] = pickle.load(file) - except: - raise NotImplementedError( - f"{type(model).__name__} object doesn't have the {operation} method" - ) + if not f"{operation}_path" in kwargs.keys(): + raise ValueError(f"'{operation}_path' should not be None") + else: + try: + if operation == "save": + with open(kwargs["save_path"], "wb") as file: + pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) + elif operation == "load": + with open(kwargs["load_path"], "rb") as file: + self.base_model = pickle.load(file) + except: + raise NotImplementedError( + f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." 
+ ) def save(self, *args, **kwargs) -> None: """ diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index b9b0b36..b1da93c 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -10,14 +10,16 @@ # # ================================================================# -import torch +import os +import logging +from typing import Any, Callable, List, Optional, T, Tuple + import numpy +import torch from torch.utils.data import DataLoader -from ..utils.logger import print_log -from ..dataset import ClassificationDataset -import os -from typing import List, Any, T, Optional, Callable, Tuple +from ..dataset import ClassificationDataset, PredictionDataset +from ..utils.logger import print_log class BasicNN: @@ -99,9 +101,7 @@ class BasicNN: loss_value = self.train_epoch(data_loader) if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: if self.save_dir is None: - raise ValueError( - "save_dir should not be None if save_interval is not None." - ) + raise ValueError("save_dir should not be None if save_interval is not None.") self.save(epoch + 1) if self.stop_loss is not None and loss_value < self.stop_loss: break @@ -191,7 +191,7 @@ class BasicNN: with torch.no_grad(): results = [] - for data, _ in data_loader: + for data in data_loader: data = data.to(device) out = model(data) results.append(out) @@ -199,7 +199,10 @@ class BasicNN: return torch.cat(results, axis=0) def predict( - self, data_loader: DataLoader = None, X: List[Any] = None + self, + data_loader: DataLoader = None, + X: List[Any] = None, + test_transform: Callable[..., Any] = None, ) -> numpy.ndarray: """ Predict the class of the input data. 
@@ -218,11 +221,28 @@ class BasicNN: """ if data_loader is None: - data_loader = self._data_loader(X) + if test_transform is None: + print_log( + "Transform used in the training phase will be used in prediction.", + "current", + level=logging.WARNING, + ) + dataset = PredictionDataset(X, self.transform) + else: + dataset = PredictionDataset(X, test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=int(self.num_workers), + collate_fn=self.collate_fn, + ) return self._predict(data_loader).argmax(axis=1).cpu().numpy() def predict_proba( - self, data_loader: DataLoader = None, X: List[Any] = None + self, + data_loader: DataLoader = None, + X: List[Any] = None, + test_transform: Callable[..., Any] = None, ) -> numpy.ndarray: """ Predict the probability of each class for the input data. @@ -241,7 +261,21 @@ class BasicNN: """ if data_loader is None: - data_loader = self._data_loader(X) + if test_transform is None: + print_log( + "Transform used in the training phase will be used in prediction.", + "current", + level=logging.WARNING, + ) + dataset = PredictionDataset(X, self.transform) + else: + dataset = PredictionDataset(X, test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=int(self.num_workers), + collate_fn=self.collate_fn, + ) return self._predict(data_loader).softmax(axis=1).cpu().numpy() def _score(self, data_loader) -> Tuple[float, float]: @@ -313,15 +347,14 @@ class BasicNN: if data_loader is None: data_loader = self._data_loader(X, y) mean_loss, accuracy = self._score(data_loader) - print_log( - f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" - ) + print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") return accuracy def _data_loader( self, X: List[Any], y: List[int] = None, + shuffle: bool = True, ) -> DataLoader: """ Generate a DataLoader for user-provided input and target data. 
@@ -350,7 +383,7 @@ class BasicNN: data_loader = DataLoader( dataset, batch_size=self.batch_size, - shuffle=True, + shuffle=shuffle, num_workers=int(self.num_workers), collate_fn=self.collate_fn, ) @@ -368,14 +401,13 @@ class BasicNN: The path to save the model, by default None. """ if self.save_dir is None and save_path is None: - raise ValueError( - "'save_dir' and 'save_path' should not be None simultaneously." - ) + raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.") - if save_path is None: - save_path = os.path.join( - self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth" - ) + if save_path is not None: + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + else: + save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth") if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) diff --git a/abl/structures/__init__.py b/abl/structures/__init__.py new file mode 100644 index 0000000..52b5af3 --- /dev/null +++ b/abl/structures/__init__.py @@ -0,0 +1,2 @@ +from .base_data_element import BaseDataElement +from .list_data import ListData \ No newline at end of file diff --git a/abl/structures/base_data_element.py b/abl/structures/base_data_element.py new file mode 100644 index 0000000..03176d1 --- /dev/null +++ b/abl/structures/base_data_element.py @@ -0,0 +1,629 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Any, Iterator, Optional, Tuple, Type, Union + +import numpy as np +import torch + + +class BaseDataElement: + """A base data interface that supports Tensor-like and dict-like + operations. + + A typical data elements refer to predicted results or ground truth labels + on a task, such as predicted bboxes, instance masks, semantic + segmentation masks, etc. 
Because groundtruth labels and predicted results + often have similar properties (for example, the predicted bboxes and the + groundtruth bboxes), MMEngine uses the same abstract data interface to + encapsulate predicted results and groundtruth labels, and it is recommended + to use different name conventions to distinguish them, such as using + ``gt_instances`` and ``pred_instances`` to distinguish between labels and + predicted results. Additionally, we distinguish data elements at instance + level, pixel level, and label level. Each of these types has its own + characteristics. Therefore, MMEngine defines the base class + ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and + ``LabelData`` inheriting from ``BaseDataElement`` to represent different + types of ground truth labels or predictions. + + Another common data element is sample data. A sample data consists of input + data (such as an image) and its annotations and predictions. In general, + an image can have multiple types of annotations and/or predictions at the + same time (for example, both pixel-level semantic segmentation annotations + and instance-level detection bboxes annotations). All labels and + predictions of a training sample are often passed between Dataset, Model, + Visualizer, and Evaluator components. In order to simplify the interface + between components, we can treat them as a large data element and + encapsulate them. Such data elements are generally called XXDataSample in + the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` + allows `BaseDataElement` as its attribute. Such a class generally + encapsulates all the data of a sample in the algorithm library, and its + attributes generally are various types of data elements. For example, + MMDetection is assigned by the BaseDataElement to encapsulate all the data + elements of the sample labeling and prediction of a sample in the + algorithm library. 
+ + The attributes in ``BaseDataElement`` are divided into two parts, + the ``metainfo`` and the ``data`` respectively. + + - ``metainfo``: Usually contains the + information about the image such as filename, + image_shape, pad_shape, etc. The attributes can be accessed or + modified by dict-like or object-like operations, such as + ``.`` (for data access and modification), ``in``, ``del``, + ``pop(str)``, ``get(str)``, ``metainfo_keys()``, + ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for + set or change key-value pairs in metainfo). + + - ``data``: Annotations or model predictions are + stored. The attributes can be accessed or modified by + dict-like or object-like operations, such as + ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, + ``values()``, ``items()``. Users can also apply tensor-like + methods to all :obj:`torch.Tensor` in the ``data_fields``, + such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, + ``to_tensor()``, ``.detach()``. + + Args: + metainfo (dict, optional): A dict contains the meta information + of single image, such as ``dict(img_shape=(512, 512, 3), + scale_factor=(1, 1, 1, 1))``. Defaults to None. + kwargs (dict, optional): A dict contains annotations of single image or + model predictions. Defaults to None. + + Examples: + >>> import torch + >>> from mmengine.structures import BaseDataElement + >>> gt_instances = BaseDataElement() + >>> bboxes = torch.rand((5, 4)) + >>> scores = torch.rand((5,)) + >>> img_id = 0 + >>> img_shape = (800, 1333) + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=img_id, img_shape=img_shape), + ... bboxes=bboxes, scores=scores) + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) + + >>> # new + >>> gt_instances1 = gt_instances.new( + ... metainfo=dict(img_id=1, img_shape=(640, 640)), + ... bboxes=torch.rand((5, 4)), + ... 
scores=torch.rand((5,))) + >>> gt_instances2 = gt_instances1.new() + + >>> # add and process property + >>> gt_instances = BaseDataElement() + >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) + >>> assert 'img_shape' in gt_instances.metainfo_keys() + >>> assert 'img_shape' in gt_instances + >>> assert 'img_shape' not in gt_instances.keys() + >>> assert 'img_shape' in gt_instances.all_keys() + >>> print(gt_instances.img_shape) + (100, 100) + >>> gt_instances.scores = torch.rand((5,)) + >>> assert 'scores' in gt_instances.keys() + >>> assert 'scores' in gt_instances + >>> assert 'scores' in gt_instances.all_keys() + >>> assert 'scores' not in gt_instances.metainfo_keys() + >>> print(gt_instances.scores) + tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) + >>> gt_instances.bboxes = torch.rand((5, 4)) + >>> assert 'bboxes' in gt_instances.keys() + >>> assert 'bboxes' in gt_instances + >>> assert 'bboxes' in gt_instances.all_keys() + >>> assert 'bboxes' not in gt_instances.metainfo_keys() + >>> print(gt_instances.bboxes) + tensor([[0.0900, 0.0424, 0.1755, 0.4469], + [0.8648, 0.0592, 0.3484, 0.0913], + [0.5808, 0.1909, 0.6165, 0.7088], + [0.5490, 0.4209, 0.9416, 0.2374], + [0.3652, 0.1218, 0.8805, 0.7523]]) + + >>> # delete and change property + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=0, img_shape=(640, 640)), + ... 
bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) + >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) + >>> gt_instances.img_shape # (1280, 1280) + >>> gt_instances.bboxes = gt_instances.bboxes * 2 + >>> gt_instances.get('img_shape', None) # (1280, 1280) + >>> gt_instances.get('bboxes', None) # 6x4 tensor + >>> del gt_instances.img_shape + >>> del gt_instances.bboxes + >>> assert 'img_shape' not in gt_instances + >>> assert 'bboxes' not in gt_instances + >>> gt_instances.pop('img_shape', None) # None + >>> gt_instances.pop('bboxes', None) # None + + >>> # Tensor-like + >>> cuda_instances = gt_instances.cuda() + >>> cuda_instances = gt_instances.to('cuda:0') + >>> cpu_instances = cuda_instances.cpu() + >>> cpu_instances = cuda_instances.to('cpu') + >>> fp16_instances = cuda_instances.to( + ... device=None, dtype=torch.float16, non_blocking=False, + ... copy=False, memory_format=torch.preserve_format) + >>> cpu_instances = cuda_instances.detach() + >>> np_instances = cpu_instances.numpy() + + >>> # print + >>> metainfo = dict(img_shape=(800, 1196, 3)) + >>> gt_instances = BaseDataElement( + ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + >>> sample = BaseDataElement(metainfo=metainfo, + ... gt_instances=gt_instances) + >>> print(sample) + + ) at 0x7f0fea49e130> + + >>> # inheritance + >>> class DetDataSample(BaseDataElement): + ... @property + ... def proposals(self): + ... return self._proposals + ... @proposals.setter + ... def proposals(self, value): + ... self.set_field(value, '_proposals', dtype=BaseDataElement) + ... @proposals.deleter + ... def proposals(self): + ... del self._proposals + ... @property + ... def gt_instances(self): + ... return self._gt_instances + ... @gt_instances.setter + ... def gt_instances(self, value): + ... self.set_field(value, '_gt_instances', + ... dtype=BaseDataElement) + ... @gt_instances.deleter + ... def gt_instances(self): + ... del self._gt_instances + ... @property + ... 
def pred_instances(self): + ... return self._pred_instances + ... @pred_instances.setter + ... def pred_instances(self, value): + ... self.set_field(value, '_pred_instances', + ... dtype=BaseDataElement) + ... @pred_instances.deleter + ... def pred_instances(self): + ... del self._pred_instances + >>> det_sample = DetDataSample() + >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) + >>> det_sample.proposals = proposals + >>> assert 'proposals' in det_sample + >>> assert det_sample.proposals == proposals + >>> del det_sample.proposals + >>> assert 'proposals' not in det_sample + >>> with self.assertRaises(AssertionError): + ... det_sample.proposals = torch.rand((5, 4)) + """ + + def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: + self._metainfo_fields: set = set() + self._data_fields: set = set() + + if metainfo is not None: + self.set_metainfo(metainfo=metainfo) + if kwargs: + self.set_data(kwargs) + + def set_metainfo(self, metainfo: dict) -> None: + """Set or change key-value pairs in ``metainfo_field`` by parameter + ``metainfo``. + + Args: + metainfo (dict): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. + """ + assert isinstance( + metainfo, dict + ), f"metainfo should be a ``dict`` but got {type(metainfo)}" + meta = copy.deepcopy(metainfo) + for k, v in meta.items(): + self.set_field(name=k, value=v, field_type="metainfo", dtype=None) + + def set_data(self, data: dict) -> None: + """Set or change key-value pairs in ``data_field`` by parameter + ``data``. + + Args: + data (dict): A dict contains annotations of image or + model predictions. + """ + assert isinstance(data, dict), f"data should be a `dict` but got {data}" + for k, v in data.items(): + # Use `setattr()` rather than `self.set_field` to allow `set_data` + # to set property method. 
+ setattr(self, k, v) + + def update(self, instance: "BaseDataElement") -> None: + """The update() method updates the BaseDataElement with the elements + from another BaseDataElement object. + + Args: + instance (BaseDataElement): Another BaseDataElement object for + update the current object. + """ + assert isinstance( + instance, BaseDataElement + ), f"instance should be a `BaseDataElement` but got {type(instance)}" + self.set_metainfo(dict(instance.metainfo_items())) + self.set_data(dict(instance.items())) + + def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": + """Return a new data element with same type. If ``metainfo`` and + ``data`` are None, the new data element will have same metainfo and + data. If metainfo or data is not None, the new result will overwrite it + with the input value. + + Args: + metainfo (dict, optional): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. + Defaults to None. + kwargs (dict): A dict contains annotations of image or + model predictions. + + Returns: + BaseDataElement: A new data element with same type. + """ + new_data = self.__class__() + + if metainfo is not None: + new_data.set_metainfo(metainfo) + else: + new_data.set_metainfo(dict(self.metainfo_items())) + if kwargs: + new_data.set_data(kwargs) + else: + new_data.set_data(dict(self.items())) + return new_data + + def clone(self): + """Deep copy the current data element. + + Returns: + BaseDataElement: The copy of current data element. + """ + clone_data = self.__class__() + clone_data.set_metainfo(dict(self.metainfo_items())) + clone_data.set_data(dict(self.items())) + return clone_data + + def keys(self) -> list: + """ + Returns: + list: Contains all keys in data_fields. + """ + # We assume that the name of the attribute related to property is + # '_' + the name of the property. We use this rule to filter out + # private keys. 
+ # TODO: Use a more robust way to solve this problem + private_keys = { + "_" + key + for key in self._data_fields + if isinstance(getattr(type(self), key, None), property) + } + return list(self._data_fields - private_keys) + + def metainfo_keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo_fields. + """ + return list(self._metainfo_fields) + + def values(self) -> list: + """ + Returns: + list: Contains all values in data. + """ + return [getattr(self, k) for k in self.keys()] + + def metainfo_values(self) -> list: + """ + Returns: + list: Contains all values in metainfo. + """ + return [getattr(self, k) for k in self.metainfo_keys()] + + def all_keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo and data. + """ + return self.metainfo_keys() + self.keys() + + def all_values(self) -> list: + """ + Returns: + list: Contains all values in metainfo and data. + """ + return self.metainfo_values() + self.values() + + def all_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: An iterator object whose element is (key, value) tuple + pairs for ``metainfo`` and ``data``. + """ + for k in self.all_keys(): + yield (k, getattr(self, k)) + + def items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: An iterator object whose element is (key, value) tuple + pairs for ``data``. + """ + for k in self.keys(): + yield (k, getattr(self, k)) + + def metainfo_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: An iterator object whose element is (key, value) tuple + pairs for ``metainfo``. 
+ """ + for k in self.metainfo_keys(): + yield (k, getattr(self, k)) + + @property + def metainfo(self) -> dict: + """dict: A dict contains metainfo of current data element.""" + return dict(self.metainfo_items()) + + def __setattr__(self, name: str, value: Any): + """setattr is only used to set data.""" + if name in ("_metainfo_fields", "_data_fields"): + if not hasattr(self, name): + super().__setattr__(name, value) + else: + raise AttributeError( + f"{name} has been used as a " + "private attribute, which is immutable." + ) + else: + self.set_field(name=name, value=value, field_type="data", dtype=None) + + def __delattr__(self, item: str): + """Delete the item in dataelement. + + Args: + item (str): The key to delete. + """ + if item in ("_metainfo_fields", "_data_fields"): + raise AttributeError( + f"{item} has been used as a " "private attribute, which is immutable." + ) + super().__delattr__(item) + if item in self._metainfo_fields: + self._metainfo_fields.remove(item) + elif item in self._data_fields: + self._data_fields.remove(item) + + # dict-like methods + __delitem__ = __delattr__ + + def get(self, key, default=None) -> Any: + """Get property in data and metainfo as the same as python.""" + # Use `getattr()` rather than `self.__dict__.get()` to allow getting + # properties. 
+ return getattr(self, key, default) + + def pop(self, *args) -> Any: + """Pop property in data and metainfo as the same as python.""" + assert len(args) < 3, "``pop`` get more than 2 arguments" + name = args[0] + if name in self._metainfo_fields: + self._metainfo_fields.remove(args[0]) + return self.__dict__.pop(*args) + + elif name in self._data_fields: + self._data_fields.remove(args[0]) + return self.__dict__.pop(*args) + + # with default value + elif len(args) == 2: + return args[1] + else: + # don't just use 'self.__dict__.pop(*args)' for only popping key in + # metainfo or data + raise KeyError(f"{args[0]} is not contained in metainfo or data") + + def __contains__(self, item: str) -> bool: + """Whether the item is in dataelement. + + Args: + item (str): The key to inquire. + """ + return item in self._data_fields or item in self._metainfo_fields + + def set_field( + self, + value: Any, + name: str, + dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, + field_type: str = "data", + ) -> None: + """Special method for set union field, used as property.setter + functions.""" + assert field_type in ["metainfo", "data"] + if dtype is not None: + assert isinstance( + value, dtype + ), f"{value} should be a {dtype} but got {type(value)}" + + if field_type == "metainfo": + if name in self._data_fields: + raise AttributeError( + f"Cannot set {name} to be a field of metainfo " + f"because {name} is already a data field" + ) + self._metainfo_fields.add(name) + else: + if name in self._metainfo_fields: + raise AttributeError( + f"Cannot set {name} to be a field of data " + f"because {name} is already a metainfo field" + ) + self._data_fields.add(name) + super().__setattr__(name, value) + + # Tensor-like methods + def to(self, *args, **kwargs) -> "BaseDataElement": + """Apply same name function to all tensors in data_fields.""" + new_data = self.new() + for k, v in self.items(): + if hasattr(v, "to"): + v = v.to(*args, **kwargs) + data = {k: v} + 
new_data.set_data(data) + return new_data + + # Tensor-like methods + def cpu(self) -> "BaseDataElement": + """Convert all tensors to CPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cpu() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def cuda(self) -> "BaseDataElement": + """Convert all tensors to GPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cuda() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def npu(self) -> "BaseDataElement": + """Convert all tensors to NPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.npu() + data = {k: v} + new_data.set_data(data) + return new_data + + def mlu(self) -> "BaseDataElement": + """Convert all tensors to MLU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.mlu() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def detach(self) -> "BaseDataElement": + """Detach all tensors in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach() + data = {k: v} + new_data.set_data(data) + return new_data + + # Tensor-like methods + def numpy(self) -> "BaseDataElement": + """Convert all tensors to np.ndarray in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach().cpu().numpy() + data = {k: v} + new_data.set_data(data) + return new_data + + def to_tensor(self) -> "BaseDataElement": + """Convert all np.ndarray to tensor in data.""" + new_data = self.new() + for k, v in self.items(): + data = {} + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + data[k] = v + 
elif isinstance(v, BaseDataElement): + v = v.to_tensor() + data[k] = v + new_data.set_data(data) + return new_data + + def to_dict(self) -> dict: + """Convert BaseDataElement to dict.""" + return { + k: v.to_dict() if isinstance(v, BaseDataElement) else v + for k, v in self.all_items() + } + + def __repr__(self) -> str: + """Represent the object.""" + + def _addindent(s_: str, num_spaces: int) -> str: + """This func is modified from `pytorch` https://github.com/pytorch/ + pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu + les/module.py#L29. + + Args: + s_ (str): The string to add spaces. + num_spaces (int): The num of space to add. + + Returns: + str: The string after add indent. + """ + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) # type: ignore + s = first + "\n" + s # type: ignore + return s # type: ignore + + def dump(obj: Any) -> str: + """Represent the object. + + Args: + obj (Any): The obj to represent. + + Returns: + str: The represented str. + """ + _repr = "" + if isinstance(obj, dict): + for k, v in obj.items(): + _repr += f"\n{k}: {_addindent(dump(v), 4)}" + elif isinstance(obj, BaseDataElement): + _repr += "\n\n META INFORMATION" + metainfo_items = dict(obj.metainfo_items()) + _repr += _addindent(dump(metainfo_items), 4) + _repr += "\n\n DATA FIELDS" + items = dict(obj.items()) + _repr += _addindent(dump(items), 4) + classname = obj.__class__.__name__ + _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" + else: + _repr += repr(obj) + return _repr + + return dump(self) diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py new file mode 100644 index 0000000..2571a13 --- /dev/null +++ b/abl/structures/list_data.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]


# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class ListData(BaseDataElement):
    """Data element whose data fields are equally long lists.

    Subclass of :class:`BaseDataElement`. Every data field assigned through
    attribute or item syntax must be a ``list`` and must have the same length
    as the fields already stored (both are enforced with assertions in
    ``__setattr__``). Metainfo is stored by the base class and is not subject
    to these checks. The design follows
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501

    On top of the base class, ``ListData`` supports indexing (by field name,
    integer, slice, list, ``numpy.ndarray`` or 1-D index/bool tensor) and
    concatenation through :meth:`cat`.

    Examples:
        >>> data = ListData(metainfo=dict(source="demo"))
        >>> data.pred_idx = [[1, 2], [3, 4]]
        >>> data["Y"] = [3, 7]
        >>> len(data)
        2
        >>> data["Y"]
        [3, 7]
        >>> data[0].Y
        [3]
        >>> ListData.cat([data, data]).Y
        [3, 7, 3, 7]
    """

    def __setattr__(self, name: str, value: list):
        """Store ``value`` as a data field.

        The two bookkeeping sets may be assigned exactly once; any other name
        must receive a ``list`` whose length matches the current length of
        this object (when it is not empty).
        """
        if name in ("_metainfo_fields", "_data_fields"):
            if hasattr(self, name):
                raise AttributeError(
                    f"{name} has been used as a private attribute, which is immutable."
                )
            super().__setattr__(name, value)
            return

        assert isinstance(value, list), "value must be of type `list`"

        if len(self) > 0:
            assert len(value) == len(self), (
                f"The length of values {len(value)} is not consistent with "
                f"the length of this :obj:`ListData` {len(self)}"
            )
        super().__setattr__(name, value)

    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> "ListData":
        """Look up a field by name, or select entries along the first axis.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                A field name returns the stored value itself; every other
                index type selects entries of each data field.

        Returns:
            :obj:`ListData`: A new object of the same class holding the
            selected entries (or the raw field value when ``item`` is a str).
        """
        assert isinstance(item, IndexType.__args__)

        # A field name is a plain attribute lookup.
        if isinstance(item, str):
            return getattr(self, item)

        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            if item.dtype == np.int32:
                item = item.astype(np.int64)
            item = torch.from_numpy(item)

        if isinstance(item, int):
            total = len(self)
            if item >= total or item < -total:  # type:ignore
                raise IndexError(f"Index {item} out of range!")
            # A stride equal to the length yields exactly one element while
            # keeping the container type (the dimension is preserved).
            item = slice(item, None, total)

        selected = self.__class__(metainfo=self.metainfo)

        if not isinstance(item, torch.Tensor):
            # item is a slice
            for key, value in self.items():
                selected[key] = value[item]
            return selected  # type:ignore

        assert item.dim() == 1, "Only support to get the values along the first dimension."
        is_mask = isinstance(item, BoolTypeTensor.__args__)
        if is_mask:
            assert len(item) == len(self), (
                f"The shape of the input(BoolTensor) {len(item)} "
                "does not match the shape of the indexed tensor "
                f"in results_field {len(self)} at first dimension."
            )

        for key, value in self.items():
            if isinstance(value, torch.Tensor):
                selected[key] = value[item]
            elif isinstance(value, np.ndarray):
                selected[key] = value[item.cpu().numpy()]
            elif isinstance(value, (str, list, tuple)) or (
                hasattr(value, "__getitem__") and hasattr(value, "cat")
            ):
                # convert to indexes from BoolTensor
                if is_mask:
                    positions = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                else:
                    positions = item.cpu().numpy().tolist()
                # One single-element slice per requested position; an empty
                # selection falls back to one empty slice so the result keeps
                # the container type.
                windows = [slice(pos, None, len(value)) for pos in positions] or [
                    slice(None, 0, None)
                ]
                pieces = [value[window] for window in windows]
                if isinstance(value, (str, list, tuple)):
                    merged = pieces[0]
                    for piece in pieces[1:]:
                        merged = merged + piece
                else:
                    merged = value.cat(pieces)
                selected[key] = merged
            else:
                raise ValueError(
                    f"The type of `{key}` is `{type(value)}`, which has no "
                    "attribute of `cat`, so it does not "
                    "support slice with `bool`"
                )
        return selected  # type:ignore

    @staticmethod
    def cat(instances_list: List["ListData"]) -> "ListData":
        """Concatenate several :obj:`ListData` objects field by field.

        Note: every element must expose exactly the same metainfo and data
        keys; this is checked with an assertion.

        Args:
            instances_list (list[:obj:`ListData`]): The objects to join. A
                single-element list is returned as that element itself.

        Returns:
            :obj:`ListData`: An object of the first element's class carrying
            the first element's metainfo and the concatenated data fields.
        """
        assert all(isinstance(results, ListData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # metainfo and data_fields must be exactly the
        # same for each element to avoid exceptions.
        keys_per_instance = [instances.all_keys() for instances in instances_list]
        same_key_count = len({len(keys) for keys in keys_per_instance}) == 1
        same_key_names = len(set(itertools.chain.from_iterable(keys_per_instance))) == len(
            keys_per_instance[0]
        )
        assert same_key_count and same_key_names, (
            "There are different keys in `instances_list`, which may "
            "cause the cat operation to fail. Please make sure all "
            "elements in `instances_list` have the exact same key."
        )

        first = instances_list[0]
        merged = first.__class__(metainfo=first.metainfo)
        for key in first.keys():
            parts = [results[key] for results in instances_list]
            head = parts[0]
            if isinstance(head, torch.Tensor):
                joined = torch.cat(parts, dim=0)
            elif isinstance(head, np.ndarray):
                joined = np.concatenate(parts, axis=0)
            elif isinstance(head, (str, list, tuple)):
                # Start from a copy so the first instance's field is untouched.
                joined = head[:]
                for part in parts[1:]:
                    joined += part
            elif hasattr(head, "cat"):
                joined = head.cat(parts)
            else:
                raise ValueError(
                    f"The type of `{key}` is `{type(head)}` which has no "
                    "attribute of `cat`"
                )
            merged[key] = joined
        return merged  # type:ignore

    def flatten(self, item: IndexType) -> List:
        """Flatten ``self[item]`` with the package's ``flatten`` helper.

        Returns:
            list: Flattened data fields.
        """
        return flatten_list(self[item])

    def elements_num(self, item: IndexType) -> int:
        """int: The number of elements in the flattened ``self[item]``."""
        return len(self.flatten(item))

    def to_tuple(self, item: IndexType) -> tuple:
        """tuple: ``self[item]`` passed through ``to_hashable``."""
        return to_hashable(self[item])

    def __len__(self) -> int:
        """int: Length of an arbitrary stored data field, or 0 when empty."""
        if not self._data_fields:
            return 0
        any_field = next(iter(self._data_fields))
        return len(getattr(self, any_field))
K = TypeVar("K")
T = TypeVar("T")
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # names for the link fields


class Cache(Generic[K, T]):
    """Optional in-memory LRU cache placed in front of a function.

    Cached results live in ``cache_dict``. Each entry is also a link
    ``[PREV, NEXT, KEY, RESULT]`` of a circular doubly linked list whose
    sentinel is ``root``; the list is ordered from least to most recently
    used, which is what eviction relies on once ``maxsize`` entries exist.
    """

    def __init__(
        self,
        func: Callable[[K], T],
        cache: bool,
        cache_file: Union[None, str, PathLike],
        key_func: Callable[[K], Hashable] = lambda x: x,
        max_size: int = 4096,
    ):
        """Create cache

        :param func: Function this cache evaluates
        :param cache: If true, do in memory caching.
        :param cache_file: If not None, path of a pickled ``dict`` (cache key ->
            result) that is loaded into the cache on start-up. Caching is
            switched on whenever a file is given, even if ``cache`` is false.
        :param key_func: Convert the key into a hashable object if needed
        :param max_size: Maximum number of entries kept in memory. It is
            enlarged by the number of entries loaded from ``cache_file``.
        """
        self.func = func
        self.key_func = key_func
        self.cache = cache
        if cache is True or cache_file is not None:
            print_log("Caching is activated", logger="current")
            self._init_cache(cache_file, max_size)
            self.first = self.get_from_dict
        else:
            # No caching: every lookup goes straight to the wrapped function.
            self.first = self.func

    def __getitem__(self, item: K, *args) -> T:
        return self.first(item, *args)

    def invalidate(self):
        """Invalidate entire cache.

        Drops every in-memory entry and resets the LRU bookkeeping (linked
        list, ``full`` flag and hit/miss counters). ``maxsize`` is kept. The
        file given as ``cache_file`` is only ever read by this class, so
        nothing on disk is touched. Calling this while caching is disabled is
        a no-op.
        """
        if getattr(self, "cache_dict", None) is None:
            # `_init_cache` never ran, so there is nothing to clear.
            return
        self.cache_dict.clear()
        self.hits, self.misses = 0, 0
        self.full = False
        # Re-point the sentinel at itself; the old links become unreachable.
        self.root[:] = [self.root, self.root, None, None]

    def _init_cache(self, cache_file, max_size):
        """Set up the dict, counters and linked list; preload ``cache_file``."""
        self.cache = True
        self.cache_dict = dict()

        self.hits, self.misses, self.maxsize = 0, 0, max_size
        self.full = False
        self.root = []  # root of the circular doubly linked list
        self.root[:] = [self.root, self.root, None, None]

        if cache_file is not None:
            # NOTE(security): unpickling can execute arbitrary code. Only
            # point `cache_file` at files produced by a trusted source.
            with open(cache_file, "rb") as f:
                cache_dict_from_file = pickle.load(f)
            # Make room so that preloaded entries do not eat into `max_size`.
            self.maxsize += len(cache_dict_from_file)
            print_log(
                f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current"
            )
            for cache_key, result in cache_dict_from_file.items():
                last = self.root[PREV]
                link = [last, self.root, cache_key, result]
                last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link

    def get(self, item: K, *args) -> T:
        """Same as ``cache[item]``; extra positional args are forwarded."""
        return self.first(item, *args)

    def get_from_dict(self, item: K, *args) -> T:
        """Implements dict based cache.

        The key is ``(key_func(item), *args)``. On a hit the entry is moved to
        the most-recently-used position; on a miss ``func`` is evaluated and
        the result stored, evicting the least recently used entry when the
        cache is full.
        """
        cache_key = (self.key_func(item), *args)
        link = self.cache_dict.get(cache_key)
        if link is not None:
            # Move the link to the front of the circular queue
            link_prev, link_next, _key, result = link
            link_prev[NEXT] = link_next
            link_next[PREV] = link_prev
            last = self.root[PREV]
            last[NEXT] = self.root[PREV] = link
            link[PREV] = last
            link[NEXT] = self.root
            self.hits += 1
            return result
        self.misses += 1

        result = self.func(item, *args)

        if self.full:
            # Use the old root to store the new key and result.
            oldroot = self.root
            oldroot[KEY] = cache_key
            oldroot[RESULT] = result
            # Empty the oldest link and make it the new root.
            self.root = oldroot[NEXT]
            oldkey = self.root[KEY]
            # Hold a reference so the evicted value is only released after
            # the bookkeeping below has finished.
            oldresult = self.root[RESULT]  # noqa: F841
            self.root[KEY] = self.root[RESULT] = None
            # Now update the cache dictionary.
            del self.cache_dict[oldkey]
            self.cache_dict[cache_key] = oldroot
        else:
            # Put result in a new link at the front of the queue.
            last = self.root[PREV]
            link = [last, self.root, cache_key, result]
            last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
            if isinstance(self.maxsize, int):
                self.full = len(self.cache_dict) >= self.maxsize
        return result
def check_equal(a, b, max_err=0):
    """
    Check whether two numbers a and b are equal within a maximum allowable error.

    Parameters
    ----------
    a, b : int or float
        The numbers to compare. Any real number type is accepted, which
        includes NumPy scalars such as ``np.int64`` and ``np.float32``.
    max_err : int or float, optional
        The maximum allowable absolute difference between a and b for them to be considered equal.
        Default is 0, meaning the numbers must be exactly equal.

    Returns
    -------
    bool
        True if a and b are equal within the allowable error, False otherwise.

    Raises
    ------
    TypeError
        If a or b is not a real number.
    """
    import numbers

    # `numbers.Real` covers int, float and bool as before, plus the NumPy
    # scalar types that NumPy registers with the numeric tower.
    if not (isinstance(a, numbers.Real) and isinstance(b, numbers.Real)):
        raise TypeError("Input values must be int or float.")

    # Normalise to a built-in bool; NumPy comparisons yield `np.bool_`.
    return bool(abs(a - b) <= max_err)
def calculate_revision_num(parameter, total_length):
    """
    Turn a revision setting into an absolute count.

    Parameters
    ----------
    parameter : int or float
        ``-1`` means "everything" and yields ``total_length``. A float is
        read as a fraction of ``total_length`` and must lie in [0, 1]. Any
        other int is returned unchanged and must be non-negative.
    total_length : int
        The length that a fractional ``parameter`` is scaled by.

    Returns
    -------
    int
        The resulting count.

    Raises
    ------
    TypeError
        If parameter is not an int or a float.
    ValueError
        If parameter is a float not in [0, 1] or an int below 0.
    """
    if not isinstance(parameter, (int, float)):
        raise TypeError("Parameter must be of type int or float.")

    # The sentinel is checked before any range validation.
    if parameter == -1:
        return total_length

    if isinstance(parameter, float):
        if not 0 <= parameter <= 1:
            raise ValueError("If parameter is a float, it must be between 0 and 1.")
        return round(total_length * parameter)

    if parameter < 0:
        raise ValueError("If parameter is an int, it must be non-negative.")
    return parameter
b/abl/utils/__init__.py @@ -1,3 +1,3 @@ -from .cache import Cache +from .cache import Cache, abl_cache from .logger import ABLLogger, print_log from .utils import * diff --git a/abl/utils/cache.py b/abl/utils/cache.py index f4b3b0c..dbf60d0 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -1,6 +1,5 @@ import pickle from os import PathLike -from pathlib import Path from typing import Callable, Generic, Hashable, TypeVar, Union from .logger import print_log @@ -14,8 +13,8 @@ class Cache(Generic[K, T]): def __init__( self, func: Callable[[K], T], - cache: bool, - cache_file: Union[None, str, PathLike], + cache: bool = True, + cache_file: Union[None, str, PathLike] = None, key_func: Callable[[K], Hashable] = lambda x: x, max_size: int = 4096, ): @@ -67,11 +66,12 @@ class Cache(Generic[K, T]): link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - def get(self, item: K, *args) -> T: - return self.first(item, *args) + def get(self, obj, item: K, *args) -> T: + return self.first(obj, item, *args) - def get_from_dict(self, item: K, *args) -> T: + def get_from_dict(self, obj, item: K, *args) -> T: """Implements dict based cache.""" + # result = self.func(obj, item, *args) cache_key = (self.key_func(item), *args) link = self.cache_dict.get(cache_key) if link is not None: @@ -87,7 +87,7 @@ class Cache(Generic[K, T]): return result self.misses += 1 - result = self.func(item, *args) + result = self.func(obj, item, *args) if self.full: # Use the old root to store the new key and result. 
@@ -110,3 +110,20 @@ class Cache(Generic[K, T]): if isinstance(self.maxsize, int): self.full = len(self.cache_dict) >= self.maxsize return result + + +def abl_cache( + cache: bool = True, + cache_file: Union[None, str, PathLike] = None, + key_func: Callable[[K], Hashable] = lambda x: x, + max_size: int = 4096, +): + def decorator(func): + cache_instance = Cache(func, cache, cache_file, key_func, max_size) + + def wrapper(self, *args, **kwargs): + return cache_instance.get(self, *args, **kwargs) + + return wrapper + + return decorator From b9f507910a35b0dc06fb5518bb076dc0e058c239 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 15 Nov 2023 14:04:23 +0800 Subject: [PATCH 03/18] add __init__ in reasoning --- abl/reasoning/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/abl/reasoning/__init__.py b/abl/reasoning/__init__.py index 8930758..f2ff627 100644 --- a/abl/reasoning/__init__.py +++ b/abl/reasoning/__init__.py @@ -1,2 +1,2 @@ -from .reasoner import ReasonerBase -from .kb import KBBase, prolog_KB \ No newline at end of file +from .kb import KBBase, GroundKB, PrologKB +from .reasoner import ReasonerBase \ No newline at end of file From e2b0b330afd27fba6d9613ea1cf0e3adf27edf00 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 15 Nov 2023 14:07:42 +0800 Subject: [PATCH 04/18] [MNT] support ListData in reasoning --- abl/evaluation/semantics_metric.py | 4 +- abl/reasoning/kb.py | 2 +- abl/reasoning/reasoner.py | 26 ++++---- abl/utils/utils.py | 4 +- examples/mnist_add/mnist_add_example.ipynb | 71 ++++++++++++++++------ 5 files changed, 70 insertions(+), 37 deletions(-) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 718cfea..ae7aac8 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,11 +1,11 @@ from typing import Optional, Sequence -from ..reasoning import BaseKB +from ..reasoning import KBBase from .base_metric import BaseMetric class 
SemanticsMetric(BaseMetric): - def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None: + def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None: super().__init__(prefix) self.kb = kb diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 37ba5b6..b626504 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -9,7 +9,7 @@ from functools import lru_cache import numpy as np import pyswip -from abl.utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable +from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable class KBBase(ABC): diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 9cc24f0..686e9dd 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,6 +1,6 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from abl.utils.utils import ( +from ..utils.utils import ( confidence_dist, flatten, reform_idx, @@ -191,7 +191,7 @@ class ReasonerBase: return max_revision def abduce( - self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 + self, data_sample, max_revision=-1, require_more_revision=0 ): """ Perform abductive reasoning on the given prediction data. @@ -219,9 +219,13 @@ class ReasonerBase: A revised pseudo label through abductive reasoning, which is consistent with the knowledge base. 
""" - symbol_num = len(flatten(pred_pseudo_label)) + symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - + + pred_pseudo_label = data_sample.pred_pseudo_label[0] + pred_prob = data_sample.pred_prob[0] + y = data_sample.Y[0] + if self.use_zoopt: solution = self.zoopt_get_solution( symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num @@ -237,20 +241,18 @@ class ReasonerBase: return candidate def batch_abduce( - self, pred_probs, pred_pseudo_labels, Ys, max_revision=-1, require_more_revision=0 + self, data_samples, max_revision=-1, require_more_revision=0 ): """ Perform abductive reasoning on the given prediction data in batches. For detailed information, refer to `abduce`. """ - return [ - self.abduce( - pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision - ) - for pred_prob, pred_pseudo_label, Y in zip( - pred_probs, pred_pseudo_labels, Ys - ) + abduced_pseudo_label = [ + self.abduce(data_sample, max_revision, require_more_revision) + for data_sample in data_samples ] + data_samples.abduced_pseudo_label = abduced_pseudo_label + return abduced_pseudo_label # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 8192bf9..1480045 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -192,7 +192,7 @@ def to_hashable(x): return x -def hashable_to_list(x): +def restore_from_hashable(x): """ Convert a nested tuple back to a nested list. @@ -208,7 +208,7 @@ def hashable_to_list(x): otherwise the original input. 
""" if isinstance(x, tuple): - return [hashable_to_list(item) for item in x] + return [restore_from_hashable(item) for item in x] return x diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 146bd88..fc06d34 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -13,16 +13,16 @@ "\n", "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", - "from abl.evaluation import SymbolMetric, ABLMetric\n", + "from abl.evaluation import SymbolMetric\n", "from abl.utils import ABLLogger\n", "\n", - "from models.nn import LeNet5\n", + "from examples.models.nn import LeNet5\n", "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -40,19 +40,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Initialize knowledge base and abducer\n", "class add_KB(KBBase):\n", - " def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", - " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", + " def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):\n", + " super().__init__(pseudo_label_list, max_err, use_cache)\n", "\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = add_KB(prebuild_GKB=True)\n", + "kb = add_KB()\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", "abducer = ReasonerBase(kb, dist_func=\"confidence\")" @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 
4, "metadata": {}, "outputs": [], "source": [ @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,6 @@ " optimizer,\n", " device,\n", " save_interval=1,\n", - " save_dir=logger.save_dir,\n", " batch_size=32,\n", " num_epochs=1,\n", ")" @@ -109,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -129,12 +128,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Add metric\n", - "metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" + "metric = [SymbolMetric(prefix=\"mnist_add\")]" ] }, { @@ -147,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -166,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -181,15 +180,47 @@ "### Train and Test" ] }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/15 13:36:00 - abl - WARNING - Transform used in the training phase will be used in prediction.\n" + ] + }, + { + "ename": "TypeError", + "evalue": "Input must be of type list.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb 单元格 17\u001b[0m line \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m bridge\u001b[39m.\u001b[39;49mtrain(train_data, loops\u001b[39m=\u001b[39;49m\u001b[39m5\u001b[39;49m, segment_size\u001b[39m=\u001b[39;49m\u001b[39m10000\u001b[39;49m)\n\u001b[1;32m 
2\u001b[0m bridge\u001b[39m.\u001b[39mtest(test_data)\n", + "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:92\u001b[0m, in \u001b[0;36mSimpleBridge.train\u001b[0;34m(self, train_data, loops, segment_size, eval_interval, save_interval, save_dir)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredict(sub_data_samples)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39midx_to_pseudo_label(sub_data_samples)\n\u001b[0;32m---> 92\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce_pseudo_label(sub_data_samples)\n\u001b[1;32m 93\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpseudo_label_to_idx(sub_data_samples)\n\u001b[1;32m 94\u001b[0m loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mtrain(sub_data_samples)\n", + "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:36\u001b[0m, in \u001b[0;36mSimpleBridge.abduce_pseudo_label\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce_pseudo_label\u001b[39m(\n\u001b[1;32m 31\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 32\u001b[0m data_samples: ListData,\n\u001b[1;32m 33\u001b[0m max_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m,\n\u001b[1;32m 34\u001b[0m require_more_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m,\n\u001b[1;32m 35\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[List[Any]]:\n\u001b[0;32m---> 36\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabducer\u001b[39m.\u001b[39;49mbatch_abduce(data_samples, max_revision, require_more_revision)\n\u001b[1;32m 37\u001b[0m \u001b[39mreturn\u001b[39;00m data_samples[\u001b[39m\"\u001b[39m\u001b[39mabduced_pseudo_label\u001b[39m\u001b[39m\"\u001b[39m]\n", + "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:246\u001b[0m, in 
\u001b[0;36mReasonerBase.batch_abduce\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 247\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", + "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:247\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[0;32m--> 247\u001b[0m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", + "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:222\u001b[0m, in \u001b[0;36mReasonerBase.abduce\u001b[0;34m(self, pred_prob, pred_pseudo_label, y, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce\u001b[39m(\n\u001b[1;32m 194\u001b[0m \u001b[39mself\u001b[39m, pred_prob, pred_pseudo_label, y, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 195\u001b[0m ):\n\u001b[1;32m 196\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data.\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[39m knowledge base.\u001b[39;00m\n\u001b[1;32m 221\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 222\u001b[0m symbol_num \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(flatten(pred_pseudo_label))\n\u001b[1;32m 223\u001b[0m max_revision_num \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_max_revision_num(max_revision, symbol_num)\n\u001b[1;32m 225\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39muse_zoopt:\n", + "File \u001b[0;32m~/ABL-Package/abl/utils/utils.py:26\u001b[0m, in \u001b[0;36mflatten\u001b[0;34m(nested_list)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[39mFlattens a nested list.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m If the input object is not a list.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 25\u001b[0m 
\u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list, \u001b[39mlist\u001b[39m):\n\u001b[0;32m---> 26\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput must be of type list.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m nested_list \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list[\u001b[39m0\u001b[39m], (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)):\n\u001b[1;32m 29\u001b[0m \u001b[39mreturn\u001b[39;00m nested_list\n", + "\u001b[0;31mTypeError\u001b[0m: Input must be of type list." + ] + } + ], + "source": [ + "bridge.train(train_data, loops=5, segment_size=10000)\n", + "bridge.test(test_data)" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "bridge.train(train_data, epochs=5, batch_size=10000)\n", - "bridge.test(test_data)" - ] + "source": [] } ], "metadata": { From c38787969441fc18f6b7403e35760aa8c79d74bb Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 15 Nov 2023 14:12:28 +0800 Subject: [PATCH 05/18] [MNT] support abl_cache --- abl/reasoning/kb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index b626504..e1df939 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -10,7 +10,7 @@ import numpy as np import pyswip from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable - +from ..utils.cache import abl_cache class KBBase(ABC): """ @@ -179,7 +179,7 @@ class KBBase(ABC): candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) return candidates - @lru_cache(maxsize=4096) + @abl_cache(max_size=4096) def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): """ `_abduce_by_search` with cache. 
From c95faf043d3c2a2ad6fcccf85cf7c04573c3d306 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 15:56:42 +0800 Subject: [PATCH 06/18] [ENH] integrate choice of cache in to abl_cache --- abl/reasoning/kb.py | 201 +++++++++++++++++++++++--------------------- abl/utils/cache.py | 43 ++++------ 2 files changed, 121 insertions(+), 123 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index e1df939..3e5ca9e 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -12,6 +12,7 @@ import pyswip from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable from ..utils.cache import abl_cache + class KBBase(ABC): """ Base class for knowledge base. @@ -21,35 +22,36 @@ class KBBase(ABC): pseudo_label_list : list List of possible pseudo labels. max_err : float, optional - The upper tolerance limit when comparing the similarity between a candidate's logical - result. This is only applicable when the logical result is of a numerical type. - This is particularly relevant for regression problems where exact matches might not be - feasible. Defaults to 1e-10. + The upper tolerance limit when comparing the similarity between a candidate's logical + result. This is only applicable when the logical result is of a numerical type. + This is particularly relevant for regression problems where exact matches might not be + feasible. Defaults to 1e-10. use_cache : bool, optional - Whether to use a cache for previously abduced candidates to speed up subsequent + Whether to use a cache for previously abduced candidates to speed up subsequent operations. Defaults to True. - + Notes ----- - Users should inherit from this base class to build their own knowledge base. For the - user-build KB (an inherited subclass), it's only required for the user to provide the - `pseudo_label_list` and override the `logic_forward` function (specifying how to - perform logical reasoning). After that, other operations (e.g. 
how to perform abductive - reasoning) will be automatically set up. + Users should inherit from this base class to build their own knowledge base. For the + user-build KB (an inherited subclass), it's only required for the user to provide the + `pseudo_label_list` and override the `logic_forward` function (specifying how to + perform logical reasoning). After that, other operations (e.g. how to perform abductive + reasoning) will be automatically set up. """ + def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): if not isinstance(pseudo_label_list, list): raise TypeError("pseudo_label_list should be list") self.pseudo_label_list = pseudo_label_list self.max_err = max_err - self.use_cache = use_cache + self.use_cache = use_cache @abstractmethod def logic_forward(self, pseudo_label): """ - How to perform (deductive) logical reasoning, i.e. matching each pseudo label to + How to perform (deductive) logical reasoning, i.e. matching each pseudo label to their logical result. Users are required to provide this. - + Parameters ---------- pred_pseudo_label : List[Any] @@ -70,23 +72,22 @@ class KBBase(ABC): max_revision_num : int The upper limit on the number of revisions. require_more_revision : int, optional - Specifies additional number of revisions permitted beyond the minimum required. + Specifies additional number of revisions permitted beyond the minimum required. Defaults to 0. Returns ------- List[List[Any]] - A list of candidates, i.e. revised pseudo labels that are consistent with the + A list of candidates, i.e. revised pseudo labels that are consistent with the knowledge base. 
""" - if self.use_cache: - return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), - to_hashable(y), - max_revision_num, require_more_revision) - else: - return self._abduce_by_search(pred_pseudo_label, y, - max_revision_num, require_more_revision) - + # if self.use_cache: + # return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), + # to_hashable(y), + # max_revision_num, require_more_revision) + # else: + return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) + def _check_equal(self, logic_result, y): """ Check whether the logical result of a candidate is equal to the ground truth @@ -94,12 +95,12 @@ class KBBase(ABC): """ if logic_result == None: return False - + if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): return abs(logic_result - y) <= self.max_err else: return logic_result == y - + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """ Revise the predicted pseudo label at specified index positions. @@ -125,7 +126,7 @@ class KBBase(ABC): def _revision(self, revision_num, pred_pseudo_label, y): """ - For a specified number of pseudo label to revise, iterate through all possible + For a specified number of pseudo label to revise, iterate through all possible indices to find any candidates that are consistent with the knowledge base. """ new_candidates = [] @@ -136,12 +137,13 @@ class KBBase(ABC): new_candidates.extend(candidates) return new_candidates - def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + @abl_cache(max_size=4096) + def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): """ - Perform abductive reasoning by exhastive search. Specifically, begin with 0 and - continuously increase the number of pseudo labels to revise, until candidates + Perform abductive reasoning by exhastive search. 
Specifically, begin with 0 and + continuously increase the number of pseudo labels to revise, until candidates that are consistent with the knowledge base are found. - + Parameters ---------- pred_pseudo_label : List[Any] @@ -151,16 +153,16 @@ class KBBase(ABC): max_revision_num : int The upper limit on the number of revisions. require_more_revision : int - If larger than 0, then after having found any candidates consistent with the - knowledge base, continue to increase the number pseudo labels to revise to + If larger than 0, then after having found any candidates consistent with the + knowledge base, continue to increase the number pseudo labels to revise to get more possible consistent candidates. Returns ------- List[List[Any]] - A list of candidates, i.e. revised pseudo label that are consistent with the + A list of candidates, i.e. revised pseudo label that are consistent with the knowledge base. - """ + """ candidates = [] for revision_num in range(len(pred_pseudo_label) + 1): if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y): @@ -173,20 +175,22 @@ class KBBase(ABC): if revision_num >= max_revision_num: return [] - for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): + for revision_num in range( + min_revision_num + 1, min_revision_num + require_more_revision + 1 + ): if revision_num > max_revision_num: return candidates candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) return candidates - - @abl_cache(max_size=4096) - def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): - """ - `_abduce_by_search` with cache. 
- """ - pred_pseudo_label = restore_from_hashable(pred_pseudo_label) - y = restore_from_hashable(y) - return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) + + # @abl_cache(max_size=4096) + # def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + # """ + # `_abduce_by_search` with cache. + # """ + # pred_pseudo_label = restore_from_hashable(pred_pseudo_label) + # y = restore_from_hashable(y) + # return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) def __repr__(self): return ( @@ -195,13 +199,13 @@ class KBBase(ABC): f"max_err={self.max_err!r}, " f"use_cache={self.use_cache!r}." ) - - + + class GroundKB(KBBase): """ - Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon - class initialization, storing all potential candidates along with their respective - logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. + Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon + class initialization, storing all potential candidates along with their respective + logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. Parameters ---------- @@ -211,15 +215,16 @@ class GroundKB(KBBase): List of possible lengths of pseudo label. max_err : float, optional Refer to class `KBBase`. - + Notes ----- - Users can also inherit from this class to build their own knowledge base. Similar - to `KBBase`, users are only required to provide the `pseudo_label_list` and override + Users can also inherit from this class to build their own knowledge base. Similar + to `KBBase`, users are only required to provide the `pseudo_label_list` and override the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. - After that, other operations (e.g. auto-construction of GKB, and how to perform + After that, other operations (e.g. 
auto-construction of GKB, and how to perform abductive reasoning) will be automatically set up. """ + def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10): super().__init__(pseudo_label_list, max_err) if not isinstance(GKB_len_list, list): @@ -229,7 +234,6 @@ class GroundKB(KBBase): X, Y = self._get_GKB() for x, y in zip(X, Y): self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) - def _get_XY_list(self, args): pre_x, post_x_it = args[0], args[1] @@ -259,21 +263,21 @@ class GroundKB(KBBase): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - if Y and isinstance(Y[0], (int, float)): + if Y and isinstance(Y[0], (int, float)): X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) return X, Y - + def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): """ - Perform abductive reasoning by directly retrieving consistent candidates from - the prebuilt GKB. In this way, the time-consuming exhaustive search can be + Perform abductive reasoning by directly retrieving consistent candidates from + the prebuilt GKB. In this way, the time-consuming exhaustive search can be avoided. - This is an overridden function. For more information about the parameters and + This is an overridden function. For more information about the parameters and returns, refer to the function of the same name in class `KBBase`. """ if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: return [] - + all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) if len(all_candidates) == 0: return [] @@ -284,29 +288,30 @@ class GroundKB(KBBase): idxs = np.where(cost_list <= revision_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates - + def _find_candidate_GKB(self, pred_pseudo_label, y): """ - Retrieve consistent candidates from the prebuilt GKB. 
For numerical logical results, - return all candidates whose logical results fall within the + Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, + return all candidates whose logical results fall within the [y - max_err, y + max_err] range. """ if isinstance(y, (int, float)): potential_candidates = self.GKB[len(pred_pseudo_label)] key_list = list(potential_candidates.keys()) - + low_key = bisect.bisect_left(key_list, y - self.max_err) high_key = bisect.bisect_right(key_list, y + self.max_err) - all_candidates = [candidate - for key in key_list[low_key:high_key] - for candidate in potential_candidates[key]] + all_candidates = [ + candidate + for key in key_list[low_key:high_key] + for candidate in potential_candidates[key] + ] return all_candidates - + else: return self.GKB[len(pred_pseudo_label)][y] - - + def __repr__(self): return ( f"{self.__class__.__name__} is a KB with " @@ -321,78 +326,80 @@ class GroundKB(KBBase): class PrologKB(KBBase): """ Knowledge base provided by a Prolog (.pl) file. - + Parameters ---------- pseudo_label_list : list Refer to class `KBBase`. - pl_file : - Prolog file containing the KB. + pl_file : + Prolog file containing the KB. max_err : float, optional Refer to class `KBBase`. - + Notes ----- - Users can instantiate this class to build their own knowledge base. During the + Users can instantiate this class to build their own knowledge base. During the instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`. - To use the default logic forward and abductive reasoning methods in this class, in the - Prolog (.pl) file, there needs to be a rule which is strictly formatted as + To use the default logic forward and abductive reasoning methods in this class, in the + Prolog (.pl) file, there needs to be a rule which is strictly formatted as `logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`. 
- For specifics, refer to the `logic_forward` and `get_query_string` functions in this + For specifics, refer to the `logic_forward` and `get_query_string` functions in this class. Users are also welcome to override related functions for more flexible support. """ + def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list) self.pl_file = pl_file self.prolog = pyswip.Prolog() - + if not os.path.exists(self.pl_file): raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") self.prolog.consult(self.pl_file) def logic_forward(self, pseudo_labels): """ - Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the - returned `Res` as the logical results. To use this default function, there must be - a Prolog `log_forward` method in the pl file to perform logical. reasoning. + Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the + returned `Res` as the logical results. To use this default function, there must be + a Prolog `log_forward` method in the pl file to perform logical. reasoning. Otherwise, users would override this function. """ - result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] - if result == 'true': + result = list(self.prolog.query("logic_forward(%s, Res)." 
% pseudo_labels))[0]["Res"] + if result == "true": return True - elif result == 'false': + elif result == "false": return False return result - + def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx): import re + revision_pred_pseudo_label = pred_pseudo_label.copy() revision_pred_pseudo_label = flatten(revision_pred_pseudo_label) - + for idx in revision_idx: - revision_pred_pseudo_label[idx] = 'P' + str(idx) + revision_pred_pseudo_label[idx] = "P" + str(idx) revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) - + regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) - + def get_query_string(self, pred_pseudo_label, y, revision_idx): """ - Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set - the returned `Revise_labels` together with the kept labels as the candidates. This is + Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set + the returned `Revise_labels` together with the kept labels as the candidates. This is a default fuction for demo, users would override this function to adapt to their own - Prolog file. + Prolog file. """ query_string = "logic_forward(" query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) key_is_none_flag = y is None or (type(y) == list and y[0] is None) query_string += ",%s)." % y if not key_is_none_flag else ")." return query_string - + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """ Revise the predicted pseudo label at specified index positions by querying Prolog. - This is an overridden function. For more information about the parameters, refer to + This is an overridden function. For more information about the parameters, refer to the function of the same name in class `KBBase`. 
""" candidates = [] @@ -414,4 +421,4 @@ class PrologKB(KBBase): f"pseudo_label_list={self.pseudo_label_list!r}, " f"defined by " f"Prolog file {self.pl_file!r}." - ) \ No newline at end of file + ) diff --git a/abl/utils/cache.py b/abl/utils/cache.py index dbf60d0..a342d0a 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -3,6 +3,7 @@ from os import PathLike from typing import Callable, Generic, Hashable, TypeVar, Union from .logger import print_log +from .utils import to_hashable K = TypeVar("K") T = TypeVar("T") @@ -13,7 +14,6 @@ class Cache(Generic[K, T]): def __init__( self, func: Callable[[K], T], - cache: bool = True, cache_file: Union[None, str, PathLike] = None, key_func: Callable[[K], Hashable] = lambda x: x, max_size: int = 4096, @@ -27,23 +27,15 @@ class Cache(Generic[K, T]): """ self.func = func self.key_func = key_func - self.cache = cache - if cache is True or cache_file is not None: - print_log("Caching is activated", logger="current") - self._init_cache(cache_file, max_size) - self.first = self.get_from_dict - else: - self.first = self.func - def __getitem__(self, item: K, *args) -> T: - return self.first(item, *args) + self._init_cache(cache_file, max_size) + + def __getitem__(self, obj, *args) -> T: + return self.get_from_dict(obj, *args) - def invalidate(self): + def clear_cache(self): """Invalidate entire cache.""" self.cache_dict.clear() - if self.cache_file: - for p in self.cache_root.iterdir(): - p.unlink() def _init_cache(self, cache_file, max_size): self.cache = True @@ -66,13 +58,10 @@ class Cache(Generic[K, T]): link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - def get(self, obj, item: K, *args) -> T: - return self.first(obj, item, *args) - - def get_from_dict(self, obj, item: K, *args) -> T: + def get_from_dict(self, obj, *args) -> T: """Implements dict based cache.""" - # result = self.func(obj, item, *args) - cache_key = (self.key_func(item), *args) + 
pred_pseudo_label, y, *res_args = args + cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) link = self.cache_dict.get(cache_key) if link is not None: # Move the link to the front of the circular queue @@ -87,7 +76,7 @@ class Cache(Generic[K, T]): return result self.misses += 1 - result = self.func(obj, item, *args) + result = self.func(obj, *args) if self.full: # Use the old root to store the new key and result. @@ -113,16 +102,18 @@ class Cache(Generic[K, T]): def abl_cache( - cache: bool = True, cache_file: Union[None, str, PathLike] = None, - key_func: Callable[[K], Hashable] = lambda x: x, + key_func: Callable[[K], Hashable] = to_hashable, max_size: int = 4096, ): def decorator(func): - cache_instance = Cache(func, cache, cache_file, key_func, max_size) + cache_instance = Cache(func, cache_file, key_func, max_size) - def wrapper(self, *args, **kwargs): - return cache_instance.get(self, *args, **kwargs) + def wrapper(obj, *args): + if obj.use_cache: + return cache_instance.get_from_dict(obj, *args) + else: + return func(obj, *args) return wrapper From 7e5292eccbe3f03ae1660bb0d631a44f7cb8b5c4 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 16:55:21 +0800 Subject: [PATCH 07/18] [MNT] use parameters of kb to initialize abl_cache --- abl/reasoning/kb.py | 16 ++++++++++++++-- abl/utils/cache.py | 45 +++++++++++++++++++++------------------------ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 3e5ca9e..1bca1b3 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -39,12 +39,24 @@ class KBBase(ABC): reasoning) will be automatically set up. 
""" - def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): + def __init__( + self, + pseudo_label_list, + max_err=1e-10, + use_cache=True, + cache_file=None, + key_func=to_hashable, + max_cache_size=4096, + ): if not isinstance(pseudo_label_list, list): raise TypeError("pseudo_label_list should be list") self.pseudo_label_list = pseudo_label_list self.max_err = max_err + self.use_cache = use_cache + self.cache_file = cache_file + self.key_func = key_func + self.max_cache_size = max_cache_size @abstractmethod def logic_forward(self, pseudo_label): @@ -137,7 +149,7 @@ class KBBase(ABC): new_candidates.extend(candidates) return new_candidates - @abl_cache(max_size=4096) + @abl_cache() def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): """ Perform abductive reasoning by exhastive search. Specifically, begin with 0 and diff --git a/abl/utils/cache.py b/abl/utils/cache.py index a342d0a..418c93f 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -11,13 +11,7 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields class Cache(Generic[K, T]): - def __init__( - self, - func: Callable[[K], T], - cache_file: Union[None, str, PathLike] = None, - key_func: Callable[[K], Hashable] = lambda x: x, - max_size: int = 4096, - ): + def __init__(self, func: Callable[[K], T]): """Create cache :param func: Function this cache evaluates @@ -26,9 +20,7 @@ class Cache(Generic[K, T]): :param key_func: Convert the key into a hashable object if needed """ self.func = func - self.key_func = key_func - - self._init_cache(cache_file, max_size) + self.has_init = False def __getitem__(self, obj, *args) -> T: return self.get_from_dict(obj, *args) @@ -37,27 +29,35 @@ class Cache(Generic[K, T]): """Invalidate entire cache.""" self.cache_dict.clear() - def _init_cache(self, cache_file, max_size): + def _init_cache(self, obj): + if self.has_init: + return + self.cache = True self.cache_dict = dict() + self.key_func = 
obj.key_func + self.cache_file = obj.cache_file + self.max_size = obj.max_cache_size - self.hits, self.misses, self.maxsize = 0, 0, max_size + self.hits, self.misses = 0, 0 self.full = False self.root = [] # root of the circular doubly linked list self.root[:] = [self.root, self.root, None, None] - if cache_file is not None: - with open(cache_file, "rb") as f: + if self.cache_file is not None: + with open(self.cache_file, "rb") as f: cache_dict_from_file = pickle.load(f) - self.maxsize += len(cache_dict_from_file) + self.max_size += len(cache_dict_from_file) print_log( - f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current" + f"Max size of the cache has been enlarged to {self.max_size}.", logger="current" ) for cache_key, result in cache_dict_from_file.items(): last = self.root[PREV] link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link + self.has_init = True + def get_from_dict(self, obj, *args) -> T: """Implements dict based cache.""" pred_pseudo_label, y, *res_args = args @@ -96,21 +96,18 @@ class Cache(Generic[K, T]): last = self.root[PREV] link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - if isinstance(self.maxsize, int): - self.full = len(self.cache_dict) >= self.maxsize + if isinstance(self.max_size, int): + self.full = len(self.cache_dict) >= self.max_size return result -def abl_cache( - cache_file: Union[None, str, PathLike] = None, - key_func: Callable[[K], Hashable] = to_hashable, - max_size: int = 4096, -): +def abl_cache(): def decorator(func): - cache_instance = Cache(func, cache_file, key_func, max_size) + cache_instance = Cache(func) def wrapper(obj, *args): if obj.use_cache: + cache_instance._init_cache(obj) return cache_instance.get_from_dict(obj, *args) else: return func(obj, *args) From aae657a3389be5db22f7a3ab95b3bc63b1611b1d Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 21:38:51 +0800 
Subject: [PATCH 08/18] [MNT] enable save of cache_dict for abl_cache --- abl/utils/cache.py | 16 +++- examples/mnist_add/mnist_add_example.ipynb | 94 +++++++++++++--------- 2 files changed, 70 insertions(+), 40 deletions(-) diff --git a/abl/utils/cache.py b/abl/utils/cache.py index 418c93f..ff8ad1a 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -1,9 +1,9 @@ import pickle -from os import PathLike -from typing import Callable, Generic, Hashable, TypeVar, Union +import os +import os.path as osp +from typing import Callable, Generic, TypeVar -from .logger import print_log -from .utils import to_hashable +from .logger import print_log, ABLLogger K = TypeVar("K") T = TypeVar("T") @@ -98,6 +98,14 @@ class Cache(Generic[K, T]): last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link if isinstance(self.max_size, int): self.full = len(self.cache_dict) >= self.max_size + if self.full: + log_dir = ABLLogger.get_current_instance().log_dir + cache_dir = osp.join(log_dir, "cache") + os.makedirs(cache_dir, exist_ok=True) + cache_path = osp.join(cache_dir, "cache.pth") + with open(cache_path, "wb") as file: + pickle.dump(self.cache_dict, file, protocol=pickle.HIGHEST_PROTOCOL) + print_log(f"Cache will be saved to {cache_path}", logger="current") return result diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index fc06d34..e94a347 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -6,6 +6,8 @@ "metadata": {}, "outputs": [], "source": [ + "import os.path as osp\n", + "\n", "import torch.nn as nn\n", "import torch\n", "\n", @@ -14,7 +16,7 @@ "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", "from abl.evaluation import SymbolMetric\n", - "from abl.utils import ABLLogger\n", + "from abl.utils import ABLLogger, print_log\n", "\n", "from examples.models.nn import LeNet5\n", "from examples.mnist_add.datasets.get_mnist_add 
import get_mnist_add" @@ -24,10 +26,22 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/15 21:35:55 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the MNIST Add example.\n" + ] + } + ], "source": [ "# Initialize logger\n", - "logger = ABLLogger.get_instance(\"abl\")" + "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", + "\n", + "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", + "log_dir = ABLLogger.get_current_instance().log_dir\n", + "weights_dir = osp.join(log_dir, \"weights\")" ] }, { @@ -46,13 +60,10 @@ "source": [ "# Initialize knowledge base and abducer\n", "class add_KB(KBBase):\n", - " def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):\n", - " super().__init__(pseudo_label_list, max_err, use_cache)\n", - "\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = add_KB()\n", + "kb = add_KB(pseudo_label_list=list(range(10)))\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", "abducer = ReasonerBase(kb, dist_func=\"confidence\")" @@ -92,7 +103,6 @@ " criterion,\n", " optimizer,\n", " device,\n", - " save_interval=1,\n", " batch_size=32,\n", " num_epochs=1,\n", ")" @@ -182,45 +192,57 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "11/15 13:36:00 - abl - WARNING - Transform used in the training phase will be used in prediction.\n" - ] - }, - { - "ename": "TypeError", - "evalue": "Input must be of type list.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - 
"\u001b[1;32m/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb 单元格 17\u001b[0m line \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m bridge\u001b[39m.\u001b[39;49mtrain(train_data, loops\u001b[39m=\u001b[39;49m\u001b[39m5\u001b[39;49m, segment_size\u001b[39m=\u001b[39;49m\u001b[39m10000\u001b[39;49m)\n\u001b[1;32m 2\u001b[0m bridge\u001b[39m.\u001b[39mtest(test_data)\n", - "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:92\u001b[0m, in \u001b[0;36mSimpleBridge.train\u001b[0;34m(self, train_data, loops, segment_size, eval_interval, save_interval, save_dir)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredict(sub_data_samples)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39midx_to_pseudo_label(sub_data_samples)\n\u001b[0;32m---> 92\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce_pseudo_label(sub_data_samples)\n\u001b[1;32m 93\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpseudo_label_to_idx(sub_data_samples)\n\u001b[1;32m 94\u001b[0m loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mtrain(sub_data_samples)\n", - "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:36\u001b[0m, in \u001b[0;36mSimpleBridge.abduce_pseudo_label\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce_pseudo_label\u001b[39m(\n\u001b[1;32m 31\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 32\u001b[0m data_samples: ListData,\n\u001b[1;32m 33\u001b[0m max_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m,\n\u001b[1;32m 34\u001b[0m require_more_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m,\n\u001b[1;32m 35\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[List[Any]]:\n\u001b[0;32m---> 36\u001b[0m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabducer\u001b[39m.\u001b[39;49mbatch_abduce(data_samples, max_revision, require_more_revision)\n\u001b[1;32m 37\u001b[0m \u001b[39mreturn\u001b[39;00m data_samples[\u001b[39m\"\u001b[39m\u001b[39mabduced_pseudo_label\u001b[39m\u001b[39m\"\u001b[39m]\n", - "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:246\u001b[0m, in \u001b[0;36mReasonerBase.batch_abduce\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 247\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", - "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:247\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 
243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[0;32m--> 247\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", - "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:222\u001b[0m, in \u001b[0;36mReasonerBase.abduce\u001b[0;34m(self, pred_prob, pred_pseudo_label, y, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce\u001b[39m(\n\u001b[1;32m 194\u001b[0m \u001b[39mself\u001b[39m, pred_prob, pred_pseudo_label, y, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 195\u001b[0m ):\n\u001b[1;32m 196\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data.\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[39m knowledge base.\u001b[39;00m\n\u001b[1;32m 221\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 222\u001b[0m symbol_num \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(flatten(pred_pseudo_label))\n\u001b[1;32m 223\u001b[0m max_revision_num \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_max_revision_num(max_revision, symbol_num)\n\u001b[1;32m 225\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39muse_zoopt:\n", - "File \u001b[0;32m~/ABL-Package/abl/utils/utils.py:26\u001b[0m, in 
\u001b[0;36mflatten\u001b[0;34m(nested_list)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[39mFlattens a nested list.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m If the input object is not a list.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list, \u001b[39mlist\u001b[39m):\n\u001b[0;32m---> 26\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput must be of type list.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m nested_list \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list[\u001b[39m0\u001b[39m], (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)):\n\u001b[1;32m 29\u001b[0m \u001b[39mreturn\u001b[39;00m nested_list\n", - "\u001b[0;31mTypeError\u001b[0m: Input must be of type list." 
+ "11/15 21:36:18 - abl - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - Transform used in the training phase will be used in prediction.\n", + "11/15 21:36:21 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [1/3] model loss is 1.80390\n", + "11/15 21:36:24 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [2/3] model loss is 1.41898\n", + "11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [3/3] model loss is 1.08221\n", + "11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n", + "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.590 \n", + "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n", + "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_1.pth\n", + "11/15 21:36:29 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [1/3] model loss is 0.65210\n", + "11/15 21:36:31 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [2/3] model loss is 0.13546\n", + "11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [3/3] model loss is 0.08060\n", + "11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n", + "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.982 \n", + "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n", + "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_2.pth\n", + "11/15 21:36:35 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [1/3] model loss is 0.06446\n", + "11/15 21:36:37 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [2/3] model loss is 0.05224\n", + "11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [3/3] model loss is 0.05119\n", + "11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n", + "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.989 \n", + "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n", + "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_3.pth\n", + "11/15 21:36:42 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [1/3] model loss is 0.04667\n", + "11/15 21:36:44 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [2/3] model loss is 0.04027\n", + "11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [3/3] model loss is 0.03672\n", + "11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [4]\n", + "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.990 \n", + "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [4]\n", + "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_4.pth\n", + "11/15 21:36:48 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [1/3] model loss is 0.03381\n", + "11/15 21:36:50 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [2/3] model loss is 0.03333\n", + "11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [3/3] model loss is 0.03195\n", + "11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: 
loop(val) [5]\n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.992 \n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [5]\n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_5.pth\n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.988 \n" ] } ], "source": [ - "bridge.train(train_data, loops=5, segment_size=10000)\n", + "bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -239,7 +261,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" }, "orig_nbformat": 4, "vscode": { From 110b1494554a784ffe2af21914ac20c1c3d63b9a Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 22:06:33 +0800 Subject: [PATCH 09/18] [MNT] unify index of ListData, remove seg of valid --- abl/bridge/simple_bridge.py | 32 ++++----- abl/evaluation/semantics_metric.py | 2 +- abl/evaluation/symbol_metric.py | 4 +- abl/learning/abl_model.py | 4 +- abl/utils/cache.py | 2 +- examples/mnist_add/mnist_add_example.ipynb | 76 ++++------------------ 6 files changed, 34 insertions(+), 86 deletions(-) diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 9093bc1..ac39df1 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -25,7 +25,7 @@ class SimpleBridge(BaseBridge): def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: self.model.predict(data_samples) - return data_samples["pred_idx"], data_samples.get("pred_prob", None) + return data_samples.pred_idx, data_samples.get("pred_prob", 
None) def abduce_pseudo_label( self, @@ -34,7 +34,7 @@ class SimpleBridge(BaseBridge): require_more_revision: int = 0, ) -> List[List[Any]]: self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) - return data_samples["abduced_pseudo_label"] + return data_samples.abduced_pseudo_label def idx_to_pseudo_label( self, data_samples: ListData, mapping: Optional[Dict] = None @@ -45,7 +45,7 @@ class SimpleBridge(BaseBridge): data_samples.pred_pseudo_label = [ [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx ] - return data_samples["pred_pseudo_label"] + return data_samples.pred_pseudo_label def pseudo_label_to_idx( self, data_samples: ListData, mapping: Optional[Dict] = None @@ -57,7 +57,7 @@ class SimpleBridge(BaseBridge): for sub_list in data_samples.abduced_pseudo_label ] data_samples.abduced_idx = abduced_idx - return data_samples["abduced_idx"] + return data_samples.abduced_idx def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: data_samples = ListData() @@ -104,16 +104,16 @@ class SimpleBridge(BaseBridge): if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") - self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) + self.model.save( + save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") + ) - def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: - for seg_idx in range((len(data_samples) - 1) // batch_size + 1): - sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] - self.predict(sub_data_samples) - self.idx_to_pseudo_label(sub_data_samples) + def _valid(self, data_samples: ListData) -> None: + self.predict(data_samples) + self.idx_to_pseudo_label(data_samples) - for metric in self.metric_list: - metric.process(sub_data_samples) + for metric in self.metric_list: + 
metric.process(data_samples) res = dict() for metric in self.metric_list: @@ -123,12 +123,12 @@ class SimpleBridge(BaseBridge): msg += k + f": {v:.3f} " print_log(msg, logger="current") - def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None: + def valid(self, valid_data: Union[ListData, DataSet]) -> None: if not isinstance(valid_data, ListData): data_samples = self.data_preprocess(*valid_data) else: data_samples = valid_data - self._valid(data_samples, batch_size=batch_size) + self._valid(data_samples) - def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None: - self.valid(test_data, batch_size=batch_size) + def test(self, test_data: Union[ListData, DataSet]) -> None: + self.valid(test_data) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index ae7aac8..21ecabf 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -11,7 +11,7 @@ class SemanticsMetric(BaseMetric): def process(self, data_samples: Sequence[dict]) -> None: for data_sample in data_samples: - if self.kb.check_equal(data_sample, data_sample["Y"][0]): + if self.kb.check_equal(data_sample, data_sample.Y[0]): self.results.append(1) else: self.results.append(0) diff --git a/abl/evaluation/symbol_metric.py b/abl/evaluation/symbol_metric.py index c2d7938..a160bb2 100644 --- a/abl/evaluation/symbol_metric.py +++ b/abl/evaluation/symbol_metric.py @@ -8,9 +8,9 @@ class SymbolMetric(BaseMetric): super().__init__(prefix) def process(self, data_samples: Sequence[dict]) -> None: - pred_pseudo_label = data_samples["pred_pseudo_label"] + pred_pseudo_label = data_samples.pred_pseudo_label - gt_pseudo_label = data_samples["gt_pseudo_label"] + gt_pseudo_label = data_samples.gt_pseudo_label if not len(pred_pseudo_label) == len(gt_pseudo_label): raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal") diff --git a/abl/learning/abl_model.py 
b/abl/learning/abl_model.py index 6685cc4..ab7bfb7 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -69,11 +69,11 @@ class ABLModel: if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) label = prob.argmax(axis=1) - prob = reform_idx(prob, data_samples["X"]) + prob = reform_idx(prob, data_samples.X) else: prob = None label = model.predict(X=data_X) - label = reform_idx(label, data_samples["X"]) + label = reform_idx(label, data_samples.X) data_samples.pred_idx = label if prob is not None: diff --git a/abl/utils/cache.py b/abl/utils/cache.py index ff8ad1a..bbd15cf 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -102,7 +102,7 @@ class Cache(Generic[K, T]): log_dir = ABLLogger.get_current_instance().log_dir cache_dir = osp.join(log_dir, "cache") os.makedirs(cache_dir, exist_ok=True) - cache_path = osp.join(cache_dir, "cache.pth") + cache_path = osp.join(cache_dir, "abduce_by_search_cache_res.pth") with open(cache_path, "wb") as file: pickle.dump(self.cache_dict, file, protocol=pickle.HIGHEST_PROTOCOL) print_log(f"Cache will be saved to {cache_path}", logger="current") diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index e94a347..295a3d2 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -24,17 +24,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "11/15 21:35:55 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the MNIST Add example.\n" - ] - } - ], + "outputs": [], "source": [ "# Initialize logger\n", "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", @@ -54,7 +46,7 @@ }, { "cell_type": 
"code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -79,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -118,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -138,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -156,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -175,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -192,53 +184,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "11/15 21:36:18 - abl - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - Transform used in the training phase will be used in prediction.\n", - "11/15 21:36:21 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [1/3] model loss is 1.80390\n", - "11/15 21:36:24 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [2/3] model loss is 1.41898\n", - "11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [3/3] model loss is 1.08221\n", - "11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n", - "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.590 \n", - "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n", - "11/15 21:36:27 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_1.pth\n", - "11/15 21:36:29 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [1/3] model loss is 0.65210\n", - "11/15 21:36:31 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [2/3] model loss is 0.13546\n", - "11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [3/3] model loss is 0.08060\n", - "11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n", - "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.982 \n", - "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n", - "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_2.pth\n", - "11/15 21:36:35 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [1/3] model loss is 0.06446\n", - "11/15 21:36:37 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [2/3] model loss is 0.05224\n", - "11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [3/3] model loss is 0.05119\n", - "11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n", - "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.989 \n", - "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n", - "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_3.pth\n", - "11/15 21:36:42 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [1/3] model loss is 0.04667\n", - "11/15 21:36:44 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [2/3] model loss is 0.04027\n", - "11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [3/3] model loss is 0.03672\n", - "11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [4]\n", - "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.990 \n", - "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [4]\n", - "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_4.pth\n", - "11/15 21:36:48 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [1/3] model loss is 0.03381\n", - "11/15 21:36:50 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [2/3] model loss is 0.03333\n", - "11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [3/3] model loss is 0.03195\n", - "11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [5]\n", - "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.992 \n", - "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [5]\n", - "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_5.pth\n", - "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.988 \n" - ] - } - ], + "outputs": [], "source": [ "bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" From c87d3772f5f2c9241bd716d6293688235eef2874 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 22:41:12 +0800 Subject: [PATCH 10/18] [FIX] fix bug in 
SemanticMetric --- abl/evaluation/semantics_metric.py | 6 ++++-- examples/mnist_add/mnist_add_example.ipynb | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 21ecabf..271ca1c 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -10,8 +10,10 @@ class SemanticsMetric(BaseMetric): self.kb = kb def process(self, data_samples: Sequence[dict]) -> None: - for data_sample in data_samples: - if self.kb.check_equal(data_sample, data_sample.Y[0]): + pred_psedudo_label_list = data_samples.pred_pseudo_label + y_list = data_samples.Y + for pred_psedudo_label, y in zip(pred_psedudo_label_list, y_list): + if self.kb._check_equal(self.kb.logic_forward(pred_psedudo_label), y): self.results.append(1) else: self.results.append(0) diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 295a3d2..0927cb5 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -15,7 +15,7 @@ "\n", "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", - "from abl.evaluation import SymbolMetric\n", + "from abl.evaluation import SymbolMetric, SemanticsMetric\n", "from abl.utils import ABLLogger, print_log\n", "\n", "from examples.models.nn import LeNet5\n", @@ -135,7 +135,7 @@ "outputs": [], "source": [ "# Add metric\n", - "metric = [SymbolMetric(prefix=\"mnist_add\")]" + "metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]" ] }, { From c7ac3ba0b8859d9e4976e83bf62ebd05a95d6743 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 16 Nov 2023 09:17:39 +0800 Subject: [PATCH 11/18] [FIX] fix typo --- abl/evaluation/semantics_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 271ca1c..0e82ee9 100644 
--- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -12,8 +12,8 @@ class SemanticsMetric(BaseMetric): def process(self, data_samples: Sequence[dict]) -> None: pred_psedudo_label_list = data_samples.pred_pseudo_label y_list = data_samples.Y - for pred_psedudo_label, y in zip(pred_psedudo_label_list, y_list): - if self.kb._check_equal(self.kb.logic_forward(pred_psedudo_label), y): + for pred_pseudo_label, y in zip(pred_psedudo_label_list, y_list): + if self.kb._check_equal(self.kb.logic_forward(pred_pseudo_label), y): self.results.append(1) else: self.results.append(0) From dcb24dec55c87ad9d2fc9c662230fb7a22daee3b Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 16 Nov 2023 09:18:08 +0800 Subject: [PATCH 12/18] [FIX] fix typo --- abl/evaluation/semantics_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 0e82ee9..14c4f46 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -10,9 +10,9 @@ class SemanticsMetric(BaseMetric): self.kb = kb def process(self, data_samples: Sequence[dict]) -> None: - pred_psedudo_label_list = data_samples.pred_pseudo_label + pred_pseudo_label_list = data_samples.pred_pseudo_label y_list = data_samples.Y - for pred_pseudo_label, y in zip(pred_psedudo_label_list, y_list): + for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list): if self.kb._check_equal(self.kb.logic_forward(pred_pseudo_label), y): self.results.append(1) else: From fa0d2a7dd0a5e1864cd36cca50d6e0b43df797a3 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 16 Nov 2023 09:22:19 +0800 Subject: [PATCH 13/18] [FIX] remove obsolete function in utils --- abl/utils/utils.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 1480045..1a6f615 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -144,34 +144,6 @@ def 
block_sample(X, Z, Y, sample_num, seg_idx): return (data[start_idx:end_idx] for data in (X, Z, Y)) -def check_equal(a, b, max_err=0): - """ - Check whether two numbers a and b are equal within a maximum allowable error. - - Parameters - ---------- - a, b : int or float - The numbers to compare. - max_err : int or float, optional - The maximum allowable absolute difference between a and b for them to be considered equal. - Default is 0, meaning the numbers must be exactly equal. - - Returns - ------- - bool - True if a and b are equal within the allowable error, False otherwise. - - Raises - ------ - TypeError - If a or b are not of type int or float. - """ - if not (isinstance(a, (int, float)) and isinstance(b, (int, float))): - raise TypeError("Input values must be int or float.") - - return abs(a - b) <= max_err - - def to_hashable(x): """ Convert a nested list to a nested tuple so it is hashable. From 3b92fd5b419b850a8248bd81faaa0e3d9c43cc4a Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 16 Nov 2023 15:53:53 +0800 Subject: [PATCH 14/18] [MNT] resolve some comments --- abl/learning/abl_model.py | 6 ++--- abl/learning/basic_nn.py | 50 +++++++++++++-------------------------- abl/reasoning/kb.py | 6 ++--- abl/reasoning/reasoner.py | 4 ++-- abl/utils/utils.py | 2 +- 5 files changed, 26 insertions(+), 42 deletions(-) diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index ab7bfb7..97775c0 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -13,7 +13,7 @@ import pickle from typing import Any, Dict from ..structures import ListData -from ..utils import reform_idx +from ..utils import reform_list class ABLModel: @@ -69,11 +69,11 @@ class ABLModel: if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) label = prob.argmax(axis=1) - prob = reform_idx(prob, data_samples.X) + prob = reform_list(prob, data_samples.X) else: prob = None label = model.predict(X=data_X) - label = reform_idx(label, data_samples.X) + 
label = reform_list(label, data_samples.X) data_samples.pred_idx = label if prob is not None: diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index b1da93c..0b43fcb 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -66,7 +66,8 @@ class BasicNN: num_workers: int = 0, save_interval: Optional[int] = None, save_dir: Optional[str] = None, - transform: Callable[..., Any] = None, + train_transform: Callable[..., Any] = None, + test_transform: Callable[..., Any] = None, collate_fn: Callable[[List[T]], Any] = None, ) -> None: self.model = model.to(device) @@ -79,9 +80,18 @@ class BasicNN: self.num_workers = num_workers self.save_interval = save_interval self.save_dir = save_dir - self.transform = transform + self.train_transform = train_transform + self.test_transform = test_transform self.collate_fn = collate_fn + if self.train_transform is not None and self.test_transform is None: + print_log( + "Transform used in the training phase will be used in prediction.", + "current", + level=logging.WARNING, + ) + self.test_transform = self.train_transform + def _fit(self, data_loader) -> float: """ Internal method to fit the model on data for n epochs, with early stopping. @@ -198,12 +208,7 @@ class BasicNN: return torch.cat(results, axis=0) - def predict( - self, - data_loader: DataLoader = None, - X: List[Any] = None, - test_transform: Callable[..., Any] = None, - ) -> numpy.ndarray: + def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: """ Predict the class of the input data. 
@@ -221,15 +226,7 @@ class BasicNN: """ if data_loader is None: - if test_transform is None: - print_log( - "Transform used in the training phase will be used in prediction.", - "current", - level=logging.WARNING, - ) - dataset = PredictionDataset(X, self.transform) - else: - dataset = PredictionDataset(X, test_transform) + dataset = PredictionDataset(X, self.test_transform) data_loader = DataLoader( dataset, batch_size=self.batch_size, @@ -238,12 +235,7 @@ class BasicNN: ) return self._predict(data_loader).argmax(axis=1).cpu().numpy() - def predict_proba( - self, - data_loader: DataLoader = None, - X: List[Any] = None, - test_transform: Callable[..., Any] = None, - ) -> numpy.ndarray: + def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: """ Predict the probability of each class for the input data. @@ -261,15 +253,7 @@ class BasicNN: """ if data_loader is None: - if test_transform is None: - print_log( - "Transform used in the training phase will be used in prediction.", - "current", - level=logging.WARNING, - ) - dataset = PredictionDataset(X, self.transform) - else: - dataset = PredictionDataset(X, test_transform) + dataset = PredictionDataset(X, self.test_transform) data_loader = DataLoader( dataset, batch_size=self.batch_size, @@ -379,7 +363,7 @@ class BasicNN: if not (len(y) == len(X)): raise ValueError("X and y should have equal length.") - dataset = ClassificationDataset(X, y, transform=self.transform) + dataset = ClassificationDataset(X, y, transform=self.train_transform) data_loader = DataLoader( dataset, batch_size=self.batch_size, diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 1bca1b3..10aa559 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -9,7 +9,7 @@ from functools import lru_cache import numpy as np import pyswip -from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable +from ..utils.utils import flatten, reform_list, hamming_dist, 
to_hashable, restore_from_hashable from ..utils.cache import abl_cache @@ -390,7 +390,7 @@ class PrologKB(KBBase): for idx in revision_idx: revision_pred_pseudo_label[idx] = "P" + str(idx) - revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) + revision_pred_pseudo_label = reform_list(revision_pred_pseudo_label, pred_pseudo_label) regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) @@ -423,7 +423,7 @@ class PrologKB(KBBase): candidate = pred_pseudo_label.copy() for i, idx in enumerate(revision_idx): candidate[idx] = c[i] - candidate = reform_idx(candidate, save_pred_pseudo_label) + candidate = reform_list(candidate, save_pred_pseudo_label) candidates.append(candidate) return candidates diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 686e9dd..2e57570 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -3,7 +3,7 @@ from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import ( confidence_dist, flatten, - reform_idx, + reform_list, hamming_dist, ) @@ -542,7 +542,7 @@ if __name__ == "__main__": return candidate def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): - all_revision_flag = reform_idx(sol.get_x(), pred_res) + all_revision_flag = reform_list(sol.get_x(), pred_res) lefted_idxs = [i for i in range(len(pred_res))] candidate_size = [] while lefted_idxs: diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 1480045..0b5dff4 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -31,7 +31,7 @@ def flatten(nested_list): return list(chain.from_iterable(nested_list)) -def reform_idx(flattened_list, structured_list): +def reform_list(flattened_list, structured_list): """ Reform the index based on structured_list structure. 
From d72fc51bbd4566af60729980ee95734b5ca9f247 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 16 Nov 2023 16:04:01 +0800 Subject: [PATCH 15/18] [ENH] refine reasoning test --- abl/reasoning/reasoner.py | 222 ++++++++++++++++---------------------- 1 file changed, 94 insertions(+), 128 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 2e57570..b16a595 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -219,13 +219,13 @@ class ReasonerBase: A revised pseudo label through abductive reasoning, which is consistent with the knowledge base. """ - symbol_num = data_sample.elements_num("pred_pseudo_label") - max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - pred_pseudo_label = data_sample.pred_pseudo_label[0] pred_prob = data_sample.pred_prob[0] y = data_sample.Y[0] + symbol_num = len(flatten(pred_pseudo_label)) + max_revision_num = self._get_max_revision_num(max_revision, symbol_num) + if self.use_zoopt: solution = self.zoopt_get_solution( symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num @@ -275,12 +275,11 @@ class ReasonerBase: if __name__ == "__main__": from kb import KBBase, GroundKB, PrologKB - - prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + from abl.structures import ListData - prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + ################################ + # Test for MNIST Add reasoning # + ################################ class AddKB(KBBase): def __init__(self, pseudo_label_list=list(range(10)), @@ -290,38 +289,54 @@ if __name__ == "__main__": def logic_forward(self, nums): return sum(nums) - class AddGroundKB(GroundKB): + class AddGroundKB(GroundKB, AddKB): def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): super().__init__(pseudo_label_list, GKB_len_list) + + def logic_forward(self, nums): + return sum(nums) + + def 
logic_forward(self, nums): return sum(nums) def test_add(reasoner): - res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + # favor 1 in first one + prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + # favor 7 in first one + prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + data_samples_add = ListData() + data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] + data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] + data_samples_add.Y = [8, 8, 17, 10] + + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) + print(res) + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) print(res) + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) + print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] print() - print("AddKB with GKB:") + print("AddGroundKB:") kb = AddGroundKB() reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) - print("AddKB without GKB:") + print("AddKB:") kb = AddKB() reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) - print("AddKB without GKB, no cache") + print("AddKB, no cache") kb = AddKB(use_cache=False) reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) @@ -339,45 +354,20 @@ if __name__ == "__main__": ) reasoner = 
ReasonerBase(kb, "confidence", use_zoopt=True) test_add(reasoner) - - print("AddKB with multiple inputs at once:") - multiple_prob = [[ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ], - [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ]] - - kb = AddKB() - reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=1, - ) - print(res) - print() - + + ################################ + #### Test for HWF reasoning #### + ################################ + class HwfKB(KBBase): def __init__( self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"], max_err=1e-3, + use_cache=False, ): - super().__init__(pseudo_label_list, max_err) + super().__init__(pseudo_label_list, max_err, use_cache) def _valid_candidate(self, formula): if len(formula) % 2 == 0: @@ -397,7 +387,7 @@ if __name__ == "__main__": formula = [mapping[f] for f in formula] return eval("".join(formula)) - class HwfGroundKB(GroundKB): + class HwfGroundKB(GroundKB, HwfKB): def __init__( self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", @@ -407,6 +397,7 @@ if __name__ == "__main__": ): super().__init__(pseudo_label_list, GKB_len_list, max_err) + def _valid_candidate(self, formula): if len(formula) % 2 == 0: return False @@ -416,6 +407,17 @@ if __name__ == "__main__": if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: return False return True + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: + return False + if i % 2 != 0 and formula[i] not in ["+", "-", 
"times", "div"]: + return False + return True + def logic_forward(self, formula): if not self._valid_candidate(formula): @@ -425,88 +427,56 @@ if __name__ == "__main__": formula = [mapping[f] for f in formula] return eval("".join(formula)) + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return None + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) + formula = [mapping[f] for f in formula] + return eval("".join(formula)) + def test_hwf(reasoner): - res = reasoner.batch_abduce( - [None], - [["5", "+", "2"]], - [3], - max_revision=2, - require_more_revision=0, - ) + data_samples_hwf = ListData() + data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] + data_samples_hwf.pred_prob = [None, None, None, None] + data_samples_hwf.Y = [3, 64, 65, 3.17] + + res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) print(res) - res = reasoner.batch_abduce( - [None], - [["5", "+", "9"]], - [65], - max_revision=3, - require_more_revision=0, - ) + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) print(res) - res = reasoner.batch_abduce( - [None], - [["5", "8", "8", "8", "8"]], - [3.17], - max_revision=5, - require_more_revision=3, - ) + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) print(res) print() - def test_hwf_multiple(reasoner, max_revisions): - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[0], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[1], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 65], - 
max_revision=max_revisions[2], - require_more_revision=0, - ) - print(res) - print() - print("HwfKB with GKB, max_err=0.1") + print("HwfGroundKB, max_err=0.1:") kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB without GKB, max_err=0.1") + print("HwfKB, max_err=0.1:") kb = HwfKB(max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB with GKB, max_err=1") + print("HwfGroundKB, max_err=1:") kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB without GKB, max_err=1") + print("HwfKB, max_err=1:") kb = HwfKB(max_err=1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - - print("HwfKB with multiple inputs at once:") - kb = HwfKB(max_err=0.1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf_multiple(reasoner, max_revisions=[1,3,3]) - print("max_revision is float") - test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) - + + ################################ + #### Test for HED reasoning #### + ################################ + + class HedKB(PrologKB): def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list, pl_file) @@ -599,28 +569,24 @@ if __name__ == "__main__": inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] - print("HedKB logic forward") - print(kb.logic_forward(consist_exs)) + print("HedKB logic forward:") + print(kb.logic_forward(consist_exs), end=" ") print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) print() - print("HedKB consist rule") - print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) + print("HedKB consist rule:") + print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) print() + data_sample_hed = ListData() + data_sample_hed.pred_pseudo_label = 
[consist_exs, inconsist_exs1, inconsist_exs2] + data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] + data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] + print("HedReasoner abduce") - res = reasoner.abduce( - [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) - ) - print(res) + res = reasoner.batch_abduce(data_sample_hed) + for r in res: + print(r) print() print("HedReasoner abduce rules") From 5887cbdb97cf16ab2eddf443aedad12f752eb638 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 16 Nov 2023 16:23:48 +0800 Subject: [PATCH 16/18] [ENH] remove list and len constraints in ListData --- abl/bridge/simple_bridge.py | 2 +- abl/learning/abl_model.py | 3 +- abl/reasoning/reasoner.py | 6 ++-- abl/structures/list_data.py | 66 ++++++++++++++----------------------- 4 files changed, 30 insertions(+), 47 deletions(-) diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index ac39df1..76a492d 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -25,7 +25,7 @@ class SimpleBridge(BaseBridge): def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: self.model.predict(data_samples) - return data_samples.pred_idx, data_samples.get("pred_prob", None) + return data_samples.pred_idx, data_samples.pred_prob def abduce_pseudo_label( self, diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 97775c0..bcf03df 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -76,8 +76,7 @@ class ABLModel: label = reform_list(label, data_samples.X) data_samples.pred_idx = label - if prob is not None: - 
data_samples.pred_prob = prob + data_samples.pred_prob = prob return {"label": label, "prob": prob} diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 2e57570..962c087 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -222,9 +222,9 @@ class ReasonerBase: symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - pred_pseudo_label = data_sample.pred_pseudo_label[0] - pred_prob = data_sample.pred_prob[0] - y = data_sample.Y[0] + pred_pseudo_label = data_sample.pred_pseudo_label + pred_prob = data_sample.pred_prob + y = data_sample.Y if self.use_zoopt: solution = self.zoopt_get_solution( diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py index 2571a13..a53ffc5 100644 --- a/abl/structures/list_data.py +++ b/abl/structures/list_data.py @@ -132,22 +132,21 @@ class ListData(BaseDataElement): super().__setattr__(name, value) else: raise AttributeError( - f"{name} has been used as a " - "private attribute, which is immutable." + f"{name} has been used as a " "private attribute, which is immutable." 
) else: - assert isinstance(value, list), "value must be of type `list`" + # assert isinstance(value, list), "value must be of type `list`" - if len(self) > 0: - assert len(value) == len(self), ( - "The length of " - f"values {len(value)} is " - "not consistent with " - "the length of this " - ":obj:`ListData` " - f"{len(self)}" - ) + # if len(self) > 0: + # assert len(value) == len(self), ( + # "The length of " + # f"values {len(value)} is " + # "not consistent with " + # "the length of this " + # ":obj:`ListData` " + # f"{len(self)}" + # ) super().__setattr__(name, value) __setitem__ = __setattr__ @@ -176,32 +175,15 @@ class ListData(BaseDataElement): if isinstance(item, str): return getattr(self, item) - if isinstance(item, int): - if item >= len(self) or item < -len(self): # type:ignore - raise IndexError(f"Index {item} out of range!") - else: - # keep the dimension - item = slice(item, None, len(self)) - new_data = self.__class__(metainfo=self.metainfo) + if isinstance(item, torch.Tensor): - assert item.dim() == 1, ( - "Only support to get the" " values along the first dimension." - ) - if isinstance(item, BoolTypeTensor.__args__): - assert len(item) == len(self), ( - "The shape of the " - "input(BoolTensor) " - f"{len(item)} " - "does not match the shape " - "of the indexed tensor " - "in results_field " - f"{len(self)} at " - "first dimension." - ) + assert item.dim() == 1, "Only support to get the" " values along the first dimension." 
for k, v in self.items(): - if isinstance(v, torch.Tensor): + if v is None: + new_data[k] = None + elif isinstance(v, torch.Tensor): new_data[k] = v[item] elif isinstance(v, np.ndarray): new_data[k] = v[item.cpu().numpy()] @@ -235,9 +217,12 @@ class ListData(BaseDataElement): ) else: - # item is a slice + # item is a slice or int for k, v in self.items(): - new_data[k] = v[item] + if v is None: + new_data[k] = None + else: + new_data[k] = v[item] return new_data # type:ignore @staticmethod @@ -289,8 +274,7 @@ class ListData(BaseDataElement): new_values = v0.cat(values) else: raise ValueError( - f"The type of `{k}` is `{type(v0)}` which has no " - "attribute of `cat`" + f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`" ) new_data[k] = new_values return new_data # type:ignore @@ -302,15 +286,15 @@ class ListData(BaseDataElement): list: Flattened data fields. """ return flatten_list(self[item]) - + def elements_num(self, item: IndexType) -> int: """int: The number of elements in self[item].""" return len(self.flatten(item)) - + def to_tuple(self, item: IndexType) -> tuple: """tuple: The data fields in self[item] converted to tuple.""" return to_hashable(self[item]) - + def __len__(self) -> int: """int: The length of ListData.""" if len(self._data_fields) > 0: From 3c02f4e1f34e401f3cd3c56813779a2f3a2f2025 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 16 Nov 2023 20:33:10 +0800 Subject: [PATCH 17/18] [MNT] delete cache_file --- abl/reasoning/kb.py | 2 -- abl/utils/cache.py | 21 --------------------- 2 files changed, 23 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 10aa559..8ef6a25 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -44,7 +44,6 @@ class KBBase(ABC): pseudo_label_list, max_err=1e-10, use_cache=True, - cache_file=None, key_func=to_hashable, max_cache_size=4096, ): @@ -54,7 +53,6 @@ class KBBase(ABC): self.max_err = max_err self.use_cache = use_cache - self.cache_file = cache_file 
self.key_func = key_func self.max_cache_size = max_cache_size diff --git a/abl/utils/cache.py b/abl/utils/cache.py index bbd15cf..a927e1f 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -36,7 +36,6 @@ class Cache(Generic[K, T]): self.cache = True self.cache_dict = dict() self.key_func = obj.key_func - self.cache_file = obj.cache_file self.max_size = obj.max_cache_size self.hits, self.misses = 0, 0 @@ -44,18 +43,6 @@ class Cache(Generic[K, T]): self.root = [] # root of the circular doubly linked list self.root[:] = [self.root, self.root, None, None] - if self.cache_file is not None: - with open(self.cache_file, "rb") as f: - cache_dict_from_file = pickle.load(f) - self.max_size += len(cache_dict_from_file) - print_log( - f"Max size of the cache has been enlarged to {self.max_size}.", logger="current" - ) - for cache_key, result in cache_dict_from_file.items(): - last = self.root[PREV] - link = [last, self.root, cache_key, result] - last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - self.has_init = True def get_from_dict(self, obj, *args) -> T: @@ -98,14 +85,6 @@ class Cache(Generic[K, T]): last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link if isinstance(self.max_size, int): self.full = len(self.cache_dict) >= self.max_size - if self.full: - log_dir = ABLLogger.get_current_instance().log_dir - cache_dir = osp.join(log_dir, "cache") - os.makedirs(cache_dir, exist_ok=True) - cache_path = osp.join(cache_dir, "abduce_by_search_cache_res.pth") - with open(cache_path, "wb") as file: - pickle.dump(self.cache_dict, file, protocol=pickle.HIGHEST_PROTOCOL) - print_log(f"Cache will be saved to {cache_path}", logger="current") return result From e00714c70592d940105b051146cf802d626b4285 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 16 Nov 2023 21:34:16 +0800 Subject: [PATCH 18/18] [MNT] change abducer to reasoner --- abl/bridge/base_bridge.py | 8 +- abl/bridge/simple_bridge.py | 12 +- abl/learning/basic_nn.py | 13 +- 
examples/hed/hed_bridge.py | 24 +-- examples/hwf/hwf_example.ipynb | 169 +++++++++++++++++---- examples/mnist_add/mnist_add_example.ipynb | 6 +- 6 files changed, 171 insertions(+), 61 deletions(-) diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index 869ea39..4ea87ac 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -9,22 +9,22 @@ DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] class BaseBridge(metaclass=ABCMeta): - def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: + def __init__(self, model: ABLModel, reasoner: ReasonerBase) -> None: if not isinstance(model, ABLModel): raise TypeError( "Expected an instance of ABLModel, but received type: {}".format( type(model) ) ) - if not isinstance(abducer, ReasonerBase): + if not isinstance(reasoner, ReasonerBase): raise TypeError( "Expected an instance of ReasonerBase, but received type: {}".format( - type(abducer) + type(reasoner) ) ) self.model = model - self.abducer = abducer + self.reasoner = reasoner @abstractmethod def predict( diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 76a492d..ff33376 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -15,13 +15,13 @@ class SimpleBridge(BaseBridge): def __init__( self, model: ABLModel, - abducer: ReasonerBase, + reasoner: ReasonerBase, metric_list: List[BaseMetric], ) -> None: - super().__init__(model, abducer) + super().__init__(model, reasoner) self.metric_list = metric_list - # TODO: add abducer.mapping to the property of SimpleBridge + # TODO: add reasoner.mapping to the property of SimpleBridge def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: self.model.predict(data_samples) @@ -33,14 +33,14 @@ class SimpleBridge(BaseBridge): max_revision: int = -1, require_more_revision: int = 0, ) -> List[List[Any]]: - self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) + 
self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision) return data_samples.abduced_pseudo_label def idx_to_pseudo_label( self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: - mapping = self.abducer.mapping + mapping = self.reasoner.mapping pred_idx = data_samples.pred_idx data_samples.pred_pseudo_label = [ [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx @@ -51,7 +51,7 @@ class SimpleBridge(BaseBridge): self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: - mapping = self.abducer.remapping + mapping = self.reasoner.remapping abduced_idx = [ [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] for sub_list in data_samples.abduced_pseudo_label diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 0b43fcb..115b098 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -92,7 +92,7 @@ class BasicNN: ) self.test_transform = self.train_transform - def _fit(self, data_loader) -> float: + def _fit(self, data_loader: DataLoader) -> float: """ Internal method to fit the model on data for n epochs, with early stopping. @@ -180,7 +180,7 @@ class BasicNN: return total_loss / total_num - def _predict(self, data_loader) -> torch.Tensor: + def _predict(self, data_loader: DataLoader) -> torch.Tensor: """ Internal method to predict the outputs given a DataLoader. @@ -262,7 +262,7 @@ class BasicNN: ) return self._predict(data_loader).softmax(axis=1).cpu().numpy() - def _score(self, data_loader) -> Tuple[float, float]: + def _score(self, data_loader: DataLoader) -> Tuple[float, float]: """ Internal method to compute loss and accuracy for the data provided through a DataLoader. 
@@ -334,12 +334,7 @@ class BasicNN: print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") return accuracy - def _data_loader( - self, - X: List[Any], - y: List[int] = None, - shuffle: bool = True, - ) -> DataLoader: + def _data_loader(self, X: List[Any], y: List[int] = None, shuffle: bool = True) -> DataLoader: """ Generate a DataLoader for user-provided input and target data. diff --git a/examples/hed/hed_bridge.py b/examples/hed/hed_bridge.py index e93d46c..b0f401f 100644 --- a/examples/hed/hed_bridge.py +++ b/examples/hed/hed_bridge.py @@ -19,17 +19,17 @@ class HEDBridge(SimpleBridge): def __init__( self, model: ABLModel, - abducer: ReasonerBase, + reasoner: ReasonerBase, metric_list: BaseMetric, ) -> None: - super().__init__(model, abducer, metric_list) + super().__init__(model, reasoner, metric_list) def pretrain(self, weights_dir): if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")): print_log("Pretrain Start", logger="current") cls_autoencoder = SymbolNetAutoencoder( - num_classes=len(self.abducer.kb.pseudo_label_list) + num_classes=len(self.reasoner.kb.pseudo_label_list) ) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") criterion = torch.nn.MSELoss() @@ -74,7 +74,7 @@ class HEDBridge(SimpleBridge): max_revision=-1, require_more_revision=0, ): - return self.abducer.abduce( + return self.reasoner.abduce( (pred_label, pred_prob, pseudo_label, Y), max_revision, require_more_revision, @@ -86,8 +86,8 @@ class HEDBridge(SimpleBridge): pred_pseudo_label_list = [] abduced_pseudo_label_list = [] for _mapping in candidate_mappings: - self.abducer.mapping = _mapping - self.abducer.set_remapping() + self.reasoner.mapping = _mapping + self.reasoner.set_remapping() pred_pseudo_label = self.label_to_pseudo_label(pred_label) abduced_pseudo_label = self.abduce_pseudo_label( pred_label, pred_prob, pred_pseudo_label, Y, 20 @@ -100,8 +100,8 @@ class HEDBridge(SimpleBridge): max_revisible_instances = 
max(mapping_score) return_idx = mapping_score.index(max_revisible_instances) - self.abducer.mapping = candidate_mappings[return_idx] - self.abducer.set_remapping() + self.reasoner.mapping = candidate_mappings[return_idx] + self.reasoner.set_remapping() return abduced_pseudo_label_list[return_idx] def check_training_impact(self, filtered_X, filtered_abduced_label, X): @@ -137,7 +137,7 @@ class HEDBridge(SimpleBridge): pred_pseudo_label = self.label_to_pseudo_label(pred_label) consistent_num = sum( [ - self.abducer.kb.consist_rule(instance, rule) + self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label ] ) @@ -159,11 +159,11 @@ class HEDBridge(SimpleBridge): pred_pseudo_label = self.label_to_pseudo_label(pred_label) consistent_instance = [] for instance in pred_pseudo_label: - if self.abducer.kb.logic_forward([instance]): + if self.reasoner.kb.logic_forward([instance]): consistent_instance.append(instance) if len(consistent_instance) != 0: - rule = self.abducer.abduce_rules(consistent_instance) + rule = self.reasoner.abduce_rules(consistent_instance) if rule != None: rules.append(rule) break @@ -280,7 +280,7 @@ class HEDBridge(SimpleBridge): else: if equation_len == min_len: print_log( - "Learned mapping is: " + str(self.abducer.mapping), + "Learned mapping is: " + str(self.reasoner.mapping), logger="current", ) self.model.load(load_path="./weights/pretrain_weights.pth") diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index 932a25e..482fbd1 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -2,10 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ + "import sys\n", + "\n", + "sys.setrecursionlimit(10000)\n", + "\n", "import torch\n", "import numpy as np\n", "import torch.nn as nn\n", @@ -23,9 +27,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - 
"outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/16 20:43:38 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the HWF example.\n" + ] + } + ], "source": [ "# Initialize logger and print basic information\n", "print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", @@ -45,21 +57,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "# Initialize knowledge base and abducer\n", + "# Initialize knowledge base and reasoner\n", "class HWF_KB(KBBase):\n", - " def __init__(\n", - " self, \n", - " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", - " prebuild_GKB=False,\n", - " GKB_len_list=[1, 3, 5, 7],\n", - " max_err=1e-3,\n", - " use_cache=True\n", - " ):\n", - " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", "\n", " def _valid_candidate(self, formula):\n", " if len(formula) % 2 == 0:\n", @@ -79,8 +82,8 @@ " formula = [mapping[f] for f in formula]\n", " return eval(''.join(formula))\n", "\n", - "kb = HWF_KB(prebuild_GKB=True)\n", - "abducer = ReasonerBase(kb, dist_func='confidence')" + "kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n", + "reasoner = ReasonerBase(kb, dist_func='confidence')" ] }, { @@ -93,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -106,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -146,12 +149,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], 
"source": [ "# Add metric\n", - "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" + "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]" ] }, { @@ -164,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -183,11 +186,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)" + "bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)" ] }, { @@ -200,11 +203,123 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [1/10] model loss is 0.16911\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:03 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [2/10] model loss is 0.17734\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [3/10] model loss is 0.01907\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [4/10] model loss is 0.01403\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [5/10] model loss is 0.00509\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [6/10] model loss is 0.00713\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [7/10] model loss is 0.00455\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [8/10] model loss is 0.00946\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [9/10] model loss is 0.00957\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [10/10] model loss is 0.00323\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.997 hwf/semantics_accuracy: 0.985 \n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_1.pth\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [1/10] model loss is 0.00666\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [2/10] model loss is 0.01438\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [3/10] model loss is 0.00450\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [4/10] model loss is 0.00764\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [5/10] model loss is 0.00644\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 
20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [6/10] model loss is 0.00189\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [7/10] model loss is 0.00397\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [8/10] model loss is 0.00936\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [9/10] model loss is 0.00960\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [10/10] model loss is 0.00572\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.999 hwf/semantics_accuracy: 0.995 \n", + "11/16 20:44:16 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_2.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [1/10] model loss is 0.00180\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [2/10] model loss is 0.00615\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [3/10] model loss is 0.01000\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [4/10] model loss is 0.00415\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 
20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [5/10] model loss is 0.00960\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [6/10] model loss is 0.00697\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [7/10] model loss is 0.00977\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [8/10] model loss is 0.00734\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [9/10] model loss is 0.00922\n", + "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [10/10] model loss is 0.00982\n", + "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.998 hwf/semantics_accuracy: 0.986 \n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_loop_3.pth\n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.994 hwf/semantics_accuracy: 0.970 \n" + ] + } + ], "source": [ - "bridge.train(train_data, epochs=3, batch_size=1000)\n", + "bridge.train(train_data, loops=3, segment_size=1000, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] } diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 0927cb5..845424b 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -50,7 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize knowledge base and abducer\n", + "# Initialize knowledge base and reasoner\n", "class add_KB(KBBase):\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", @@ -58,7 +58,7 @@ "kb = add_KB(pseudo_label_list=list(range(10)))\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", - "abducer = ReasonerBase(kb, dist_func=\"confidence\")" + "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" ] }, { @@ -171,7 +171,7 @@ "metadata": {}, "outputs": [], "source": [ - "bridge = SimpleBridge(model, abducer, metric)" + "bridge = SimpleBridge(model, reasoner, metric)" ] }, {