diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index 03054f7..4ea87ac 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -1,52 +1,64 @@ from abc import ABCMeta, abstractmethod -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple, Union from ..learning import ABLModel from ..reasoning import ReasonerBase +from ..structures import ListData +DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] -class BaseBridge(metaclass=ABCMeta): - def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: +class BaseBridge(metaclass=ABCMeta): + def __init__(self, model: ABLModel, reasoner: ReasonerBase) -> None: if not isinstance(model, ABLModel): - raise TypeError("Expected an ABLModel") - if not isinstance(abducer, ReasonerBase): - raise TypeError("Expected an ReasonerBase") - + raise TypeError( + "Expected an instance of ABLModel, but received type: {}".format( + type(model) + ) + ) + if not isinstance(reasoner, ReasonerBase): + raise TypeError( + "Expected an instance of ReasonerBase, but received type: {}".format( + type(reasoner) + ) + ) + self.model = model - self.abducer = abducer + self.reasoner = reasoner @abstractmethod - def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: + def predict( + self, data_samples: ListData + ) -> Tuple[List[List[Any]], List[List[Any]]]: """Placeholder for predict labels from input.""" pass @abstractmethod - def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for abduce pseudo labels.""" + pass @abstractmethod - def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: + def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for map label space to symbol space.""" pass @abstractmethod - def pseudo_label_to_idx(self, pseudo_label: 
List[List[Any]]) -> List[List[Any]]: + def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for map symbol space to label space.""" pass - + @abstractmethod - def train(self, train_data): + def train(self, train_data: Union[ListData, DataSet]): """Placeholder for train loop of ABductive Learning.""" pass @abstractmethod - def test(self, test_data): + def valid(self, valid_data: Union[ListData, DataSet]) -> None: """Placeholder for model test.""" pass @abstractmethod - def valid(self, valid_data): + def test(self, test_data: Union[ListData, DataSet]) -> None: """Placeholder for model validation.""" pass - \ No newline at end of file diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 3aecffd..ff33376 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -1,104 +1,119 @@ -from ..learning import ABLModel -from ..reasoning import ReasonerBase -from ..evaluation import BaseMetric -from .base_bridge import BaseBridge -from typing import List, Union, Any, Tuple, Dict, Optional +import os.path as osp +from typing import Any, Dict, List, Optional, Tuple, Union + from numpy import ndarray -from torch.utils.data import DataLoader -from ..dataset import BridgeDataset -from ..utils.logger import print_log +from ..evaluation import BaseMetric +from ..learning import ABLModel +from ..reasoning import ReasonerBase +from ..structures import ListData +from ..utils import print_log +from .base_bridge import BaseBridge, DataSet class SimpleBridge(BaseBridge): def __init__( self, model: ABLModel, - abducer: ReasonerBase, + reasoner: ReasonerBase, metric_list: List[BaseMetric], ) -> None: - super().__init__(model, abducer) + super().__init__(model, reasoner) self.metric_list = metric_list - def predict(self, X) -> Tuple[List[List[Any]], ndarray]: - pred_res = self.model.predict(X) - pred_idx, pred_prob = pred_res["label"], pred_res["prob"] - return pred_idx, pred_prob - + # TODO: add reasoner.mapping to 
the property of SimpleBridge + + def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: + self.model.predict(data_samples) + return data_samples.pred_idx, data_samples.pred_prob + def abduce_pseudo_label( self, - pred_prob: ndarray, - pred_pseudo_label: List[List[Any]], - Y: List[Any], + data_samples: ListData, max_revision: int = -1, require_more_revision: int = 0, ) -> List[List[Any]]: - return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) + self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision) + return data_samples.abduced_pseudo_label def idx_to_pseudo_label( - self, idx: List[List[Any]], mapping: Dict = None + self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: - mapping = self.abducer.mapping - return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] + mapping = self.reasoner.mapping + pred_idx = data_samples.pred_idx + data_samples.pred_pseudo_label = [ + [mapping[_idx] for _idx in sub_list] for sub_list in pred_idx + ] + return data_samples.pred_pseudo_label def pseudo_label_to_idx( - self, pseudo_label: List[List[Any]], mapping: Dict = None + self, data_samples: ListData, mapping: Optional[Dict] = None ) -> List[List[Any]]: if mapping is None: - mapping = self.abducer.remapping - return [ - [mapping[_pseudo_label] for _pseudo_label in sub_list] - for sub_list in pseudo_label + mapping = self.reasoner.remapping + abduced_idx = [ + [mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] + for sub_list in data_samples.abduced_pseudo_label ] + data_samples.abduced_idx = abduced_idx + return data_samples.abduced_idx + + def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: + data_samples = ListData() + + data_samples.X = X + data_samples.gt_pseudo_label = gt_pseudo_label + data_samples.Y = Y + + return data_samples def train( self, - train_data: 
Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], - epochs: int = 50, - batch_size: Union[int, float] = -1, + train_data: Union[ListData, DataSet], + loops: int = 50, + segment_size: Union[int, float] = -1, eval_interval: int = 1, + save_interval: Optional[int] = None, + save_dir: Optional[str] = None, ): - dataset = BridgeDataset(*train_data) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) - - for epoch in range(epochs): - for seg_idx, (X, Z, Y) in enumerate(data_loader): - pred_idx, pred_prob = self.predict(X) - pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) - abduced_pseudo_label = self.abduce_pseudo_label( - pred_prob, pred_pseudo_label, Y - ) - abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) - loss = self.model.train(X, abduced_label) + if isinstance(train_data, ListData): + data_samples = train_data + else: + data_samples = self.data_preprocess(*train_data) + + for loop in range(loops): + for seg_idx in range((len(data_samples) - 1) // segment_size + 1): + sub_data_samples = data_samples[ + seg_idx * segment_size : (seg_idx + 1) * segment_size + ] + self.predict(sub_data_samples) + self.idx_to_pseudo_label(sub_data_samples) + self.abduce_pseudo_label(sub_data_samples) + self.pseudo_label_to_idx(sub_data_samples) + loss = self.model.train(sub_data_samples) print_log( - f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", + f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", logger="current", ) - if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: - print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") + if (loop + 1) % eval_interval == 0 or loop == loops - 1: + print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") self.valid(train_data) - def _valid(self, 
data_loader): - for X, Z, Y in data_loader: - pred_idx, pred_prob = self.predict(X) - pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) - data_samples = dict( - pred_idx=pred_idx, - pred_prob=pred_prob, - pred_pseudo_label=pred_pseudo_label, - gt_pseudo_label=Z, - Y=Y, - logic_forward=self.abducer.kb.logic_forward, - ) - for metric in self.metric_list: - metric.process(data_samples) + if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): + print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") + self.model.save( + save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") + ) + + def _valid(self, data_samples: ListData) -> None: + self.predict(data_samples) + self.idx_to_pseudo_label(data_samples) + + for metric in self.metric_list: + metric.process(data_samples) res = dict() for metric in self.metric_list: @@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge): msg += k + f": {v:.3f} " print_log(msg, logger="current") - def valid(self, valid_data, batch_size=1000): - dataset = BridgeDataset(*valid_data) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) - self._valid(data_loader) - - def test(self, test_data, batch_size=1000): - self.valid(test_data, batch_size) + def valid(self, valid_data: Union[ListData, DataSet]) -> None: + if not isinstance(valid_data, ListData): + data_samples = self.data_preprocess(*valid_data) + else: + data_samples = valid_data + self._valid(data_samples) + + def test(self, test_data: Union[ListData, DataSet]) -> None: + self.valid(test_data) diff --git a/abl/dataset/__init__.py b/abl/dataset/__init__.py index 6be0df1..a487476 100644 --- a/abl/dataset/__init__.py +++ b/abl/dataset/__init__.py @@ -1,3 +1,4 @@ from .bridge_dataset import BridgeDataset from .classification_dataset import ClassificationDataset -from .regression_dataset import RegressionDataset \ No newline at end of file 
+from .prediction_dataset import PredictionDataset +from .regression_dataset import RegressionDataset diff --git a/abl/dataset/bridge_dataset.py b/abl/dataset/bridge_dataset.py index bb0ce98..a7d32c5 100644 --- a/abl/dataset/bridge_dataset.py +++ b/abl/dataset/bridge_dataset.py @@ -1,5 +1,6 @@ +from typing import Any, List, Tuple + from torch.utils.data import Dataset -from typing import List, Any, Tuple class BridgeDataset(Dataset): diff --git a/abl/dataset/classification_dataset.py b/abl/dataset/classification_dataset.py index 28f9299..1663642 100644 --- a/abl/dataset/classification_dataset.py +++ b/abl/dataset/classification_dataset.py @@ -1,6 +1,7 @@ +from typing import Any, Callable, List, Tuple + import torch from torch.utils.data import Dataset -from typing import List, Any, Tuple, Callable class ClassificationDataset(Dataset): diff --git a/abl/dataset/prediction_dataset.py b/abl/dataset/prediction_dataset.py new file mode 100644 index 0000000..8e3c717 --- /dev/null +++ b/abl/dataset/prediction_dataset.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, List, Tuple + +import torch +from torch.utils.data import Dataset + + +class PredictionDataset(Dataset): + def __init__(self, X: List[Any], transform: Callable[..., Any] = None): + """ + Initialize the dataset used for classification task. + + Parameters + ---------- + X : List[Any] + The input data. + transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version. Defaults to None. + """ + if not isinstance(X, list): + raise ValueError("X should be of type list.") + + self.X = X + self.transform = transform + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: + """ + Get the item at the given index. + + Parameters + ---------- + index : int + The index of the item to get. 
+ + Returns + ------- + Tuple[Any, torch.Tensor] + A tuple containing the object and its label. + """ + if index >= len(self): + raise ValueError("index range error") + + x = self.X[index] + if self.transform is not None: + x = self.transform(x) + return x diff --git a/abl/dataset/regression_dataset.py b/abl/dataset/regression_dataset.py index 8cf136c..118ac65 100644 --- a/abl/dataset/regression_dataset.py +++ b/abl/dataset/regression_dataset.py @@ -1,6 +1,7 @@ +from typing import Any, List, Tuple + import torch from torch.utils.data import Dataset -from typing import List, Any, Tuple class RegressionDataset(Dataset): diff --git a/abl/evaluation/__init__.py b/abl/evaluation/__init__.py index a849d68..3106412 100644 --- a/abl/evaluation/__init__.py +++ b/abl/evaluation/__init__.py @@ -1,3 +1,3 @@ from .base_metric import BaseMetric -from .symbol_metric import SymbolMetric from .semantics_metric import SemanticsMetric +from .symbol_metric import SymbolMetric diff --git a/abl/evaluation/base_metric.py b/abl/evaluation/base_metric.py index 44364f8..e18f452 100644 --- a/abl/evaluation/base_metric.py +++ b/abl/evaluation/base_metric.py @@ -1,8 +1,8 @@ +import logging from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Sequence -from ..utils import print_log -import logging +from ..utils import print_log class BaseMetric(metaclass=ABCMeta): diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 3333daf..14c4f46 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,25 +1,24 @@ from typing import Optional, Sequence + +from ..reasoning import KBBase from .base_metric import BaseMetric -class ABLMetric(): - pass class SemanticsMetric(BaseMetric): - def __init__(self, prefix: Optional[str] = None) -> None: + def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None: super().__init__(prefix) + self.kb = kb def process(self, data_samples: Sequence[dict]) -> 
None: - pred_pseudo_label = data_samples["pred_pseudo_label"] - gt_Y = data_samples["Y"] - logic_forward = data_samples["logic_forward"] - - for pred_z, y in zip(pred_pseudo_label, gt_Y): - if logic_forward(pred_z) == y: + pred_pseudo_label_list = data_samples.pred_pseudo_label + y_list = data_samples.Y + for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list): + if self.kb._check_equal(self.kb.logic_forward(pred_pseudo_label), y): self.results.append(1) else: self.results.append(0) - + def compute_metrics(self, results: list) -> dict: metrics = dict() metrics["semantics_accuracy"] = sum(results) / len(results) - return metrics \ No newline at end of file + return metrics diff --git a/abl/evaluation/symbol_metric.py b/abl/evaluation/symbol_metric.py index 3c0c216..a160bb2 100644 --- a/abl/evaluation/symbol_metric.py +++ b/abl/evaluation/symbol_metric.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, Callable +from typing import Optional, Sequence + from .base_metric import BaseMetric @@ -7,9 +8,9 @@ class SymbolMetric(BaseMetric): super().__init__(prefix) def process(self, data_samples: Sequence[dict]) -> None: - pred_pseudo_label = data_samples["pred_pseudo_label"] + pred_pseudo_label = data_samples.pred_pseudo_label - gt_pseudo_label = data_samples["gt_pseudo_label"] + gt_pseudo_label = data_samples.gt_pseudo_label if not len(pred_pseudo_label) == len(gt_pseudo_label): raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal") diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index f1512fe..bcf03df 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -10,8 +10,10 @@ # # ================================================================# import pickle -from utils import flatten, reform_idx -from typing import List, Any, Optional +from typing import Any, Dict + +from ..structures import ListData +from ..utils import reform_list class ABLModel: @@ -30,7 +32,7 @@ class ABLModel: Methods 
------- - predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict + predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict Predict the labels and probabilities for the given data. valid(X: List[List[Any]], Y: List[Any]) -> float Calculate the accuracy score for the given data. @@ -42,20 +44,13 @@ class ABLModel: Load the model from a file. """ - def __init__(self, base_model) -> None: - self.classifier_list = [] - self.classifier_list.append(base_model) + def __init__(self, base_model: Any) -> None: + if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): + raise NotImplementedError("The base_model should implement fit and predict methods.") - if not ( - hasattr(base_model, "fit") - and hasattr(base_model, "predict") - and hasattr(base_model, "score") - ): - raise NotImplementedError( - "base_model should have fit, predict and score methods." - ) + self.base_model = base_model - def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: + def predict(self, data_samples: ListData) -> Dict: """ Predict the labels and probabilities for the given data. @@ -63,53 +58,29 @@ class ABLModel: ---------- X : List[List[Any]] The data to predict on. - mapping : Optional[dict], optional - A mapping dictionary to map labels to their original values, by default None. Returns ------- dict A dictionary containing the predicted labels and probabilities. 
""" - model = self.classifier_list[0] - data_X = flatten(X) + model = self.base_model + data_X = data_samples.flatten("X") if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) label = prob.argmax(axis=1) - prob = reform_idx(prob, X) + prob = reform_list(prob, data_samples.X) else: prob = None label = model.predict(X=data_X) + label = reform_list(label, data_samples.X) - if mapping is not None: - label = [mapping[y] for y in label] - - label = reform_idx(label, X) + data_samples.pred_idx = label + data_samples.pred_prob = prob return {"label": label, "prob": prob} - def valid(self, X: List[List[Any]], Y: List[Any]) -> float: - """ - Calculate the accuracy for the given data. - - Parameters - ---------- - X : List[List[Any]] - The data to calculate the accuracy on. - Y : List[Any] - The true labels for the given data. - - Returns - ------- - float - The accuracy score for the given data. - """ - data_X = flatten(X) - data_Y = flatten(Y) - score = self.classifier_list[0].score(X=data_X, y=data_Y) - return score - - def train(self, X: List[List[Any]], Y: List[Any]) -> float: + def train(self, data_samples: ListData) -> float: """ Train the model on the given data. @@ -125,29 +96,30 @@ class ABLModel: float The loss value of the trained model. 
""" - data_X = flatten(X) - data_Y = flatten(Y) - return self.classifier_list[0].fit(X=data_X, y=data_Y) + data_X = data_samples.flatten("X") + data_y = data_samples.flatten("abduced_idx") + return self.base_model.fit(X=data_X, y=data_y) def _model_operation(self, operation: str, *args, **kwargs): - model = self.classifier_list[0] + model = self.base_model if hasattr(model, operation): method = getattr(model, operation) method(*args, **kwargs) else: - try: - if not f"{operation}_path" in kwargs.keys(): - raise ValueError(f"'{operation}_path' should not be None") - if operation == "save": - with open(kwargs["save_path"], 'wb') as file: - pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) - elif operation == "load": - with open(kwargs["load_path"], 'rb') as file: - self.classifier_list[0] = pickle.load(file) - except: - raise NotImplementedError( - f"{type(model).__name__} object doesn't have the {operation} method" - ) + if not f"{operation}_path" in kwargs.keys(): + raise ValueError(f"'{operation}_path' should not be None") + else: + try: + if operation == "save": + with open(kwargs["save_path"], "wb") as file: + pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) + elif operation == "load": + with open(kwargs["load_path"], "rb") as file: + self.base_model = pickle.load(file) + except: + raise NotImplementedError( + f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." 
+ ) def save(self, *args, **kwargs) -> None: """ diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index b9b0b36..115b098 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -10,14 +10,16 @@ # # ================================================================# -import torch +import os +import logging +from typing import Any, Callable, List, Optional, T, Tuple + import numpy +import torch from torch.utils.data import DataLoader -from ..utils.logger import print_log -from ..dataset import ClassificationDataset -import os -from typing import List, Any, T, Optional, Callable, Tuple +from ..dataset import ClassificationDataset, PredictionDataset +from ..utils.logger import print_log class BasicNN: @@ -64,7 +66,8 @@ class BasicNN: num_workers: int = 0, save_interval: Optional[int] = None, save_dir: Optional[str] = None, - transform: Callable[..., Any] = None, + train_transform: Callable[..., Any] = None, + test_transform: Callable[..., Any] = None, collate_fn: Callable[[List[T]], Any] = None, ) -> None: self.model = model.to(device) @@ -77,10 +80,19 @@ class BasicNN: self.num_workers = num_workers self.save_interval = save_interval self.save_dir = save_dir - self.transform = transform + self.train_transform = train_transform + self.test_transform = test_transform self.collate_fn = collate_fn - def _fit(self, data_loader) -> float: + if self.train_transform is not None and self.test_transform is None: + print_log( + "Transform used in the training phase will be used in prediction.", + "current", + level=logging.WARNING, + ) + self.test_transform = self.train_transform + + def _fit(self, data_loader: DataLoader) -> float: """ Internal method to fit the model on data for n epochs, with early stopping. 
@@ -99,9 +111,7 @@ class BasicNN: loss_value = self.train_epoch(data_loader) if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: if self.save_dir is None: - raise ValueError( - "save_dir should not be None if save_interval is not None." - ) + raise ValueError("save_dir should not be None if save_interval is not None.") self.save(epoch + 1) if self.stop_loss is not None and loss_value < self.stop_loss: break @@ -170,7 +180,7 @@ class BasicNN: return total_loss / total_num - def _predict(self, data_loader) -> torch.Tensor: + def _predict(self, data_loader: DataLoader) -> torch.Tensor: """ Internal method to predict the outputs given a DataLoader. @@ -191,16 +201,14 @@ class BasicNN: with torch.no_grad(): results = [] - for data, _ in data_loader: + for data in data_loader: data = data.to(device) out = model(data) results.append(out) return torch.cat(results, axis=0) - def predict( - self, data_loader: DataLoader = None, X: List[Any] = None - ) -> numpy.ndarray: + def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: """ Predict the class of the input data. @@ -218,12 +226,16 @@ class BasicNN: """ if data_loader is None: - data_loader = self._data_loader(X) + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=int(self.num_workers), + collate_fn=self.collate_fn, + ) return self._predict(data_loader).argmax(axis=1).cpu().numpy() - def predict_proba( - self, data_loader: DataLoader = None, X: List[Any] = None - ) -> numpy.ndarray: + def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: """ Predict the probability of each class for the input data. 
@@ -241,10 +253,16 @@ class BasicNN: """ if data_loader is None: - data_loader = self._data_loader(X) + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=int(self.num_workers), + collate_fn=self.collate_fn, + ) return self._predict(data_loader).softmax(axis=1).cpu().numpy() - def _score(self, data_loader) -> Tuple[float, float]: + def _score(self, data_loader: DataLoader) -> Tuple[float, float]: """ Internal method to compute loss and accuracy for the data provided through a DataLoader. @@ -313,16 +331,10 @@ class BasicNN: if data_loader is None: data_loader = self._data_loader(X, y) mean_loss, accuracy = self._score(data_loader) - print_log( - f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" - ) + print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") return accuracy - def _data_loader( - self, - X: List[Any], - y: List[int] = None, - ) -> DataLoader: + def _data_loader(self, X: List[Any], y: List[int] = None, shuffle: bool = True) -> DataLoader: """ Generate a DataLoader for user-provided input and target data. @@ -346,11 +358,11 @@ class BasicNN: if not (len(y) == len(X)): raise ValueError("X and y should have equal length.") - dataset = ClassificationDataset(X, y, transform=self.transform) + dataset = ClassificationDataset(X, y, transform=self.train_transform) data_loader = DataLoader( dataset, batch_size=self.batch_size, - shuffle=True, + shuffle=shuffle, num_workers=int(self.num_workers), collate_fn=self.collate_fn, ) @@ -368,14 +380,13 @@ class BasicNN: The path to save the model, by default None. """ if self.save_dir is None and save_path is None: - raise ValueError( - "'save_dir' and 'save_path' should not be None simultaneously." 
- ) + raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.") - if save_path is None: - save_path = os.path.join( - self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth" - ) + if save_path is not None: + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + else: + save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth") if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) diff --git a/abl/reasoning/__init__.py b/abl/reasoning/__init__.py index 8930758..f2ff627 100644 --- a/abl/reasoning/__init__.py +++ b/abl/reasoning/__init__.py @@ -1,2 +1,2 @@ -from .reasoner import ReasonerBase -from .kb import KBBase, prolog_KB \ No newline at end of file +from .kb import KBBase, GroundKB, PrologKB +from .reasoner import ReasonerBase \ No newline at end of file diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 37ba5b6..8ef6a25 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -9,7 +9,8 @@ from functools import lru_cache import numpy as np import pyswip -from abl.utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable +from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable, restore_from_hashable +from ..utils.cache import abl_cache class KBBase(ABC): @@ -21,35 +22,46 @@ class KBBase(ABC): pseudo_label_list : list List of possible pseudo labels. max_err : float, optional - The upper tolerance limit when comparing the similarity between a candidate's logical - result. This is only applicable when the logical result is of a numerical type. - This is particularly relevant for regression problems where exact matches might not be - feasible. Defaults to 1e-10. + The upper tolerance limit when comparing the similarity between a candidate's logical + result. This is only applicable when the logical result is of a numerical type. 
+ This is particularly relevant for regression problems where exact matches might not be + feasible. Defaults to 1e-10. use_cache : bool, optional - Whether to use a cache for previously abduced candidates to speed up subsequent + Whether to use a cache for previously abduced candidates to speed up subsequent operations. Defaults to True. - + Notes ----- - Users should inherit from this base class to build their own knowledge base. For the - user-build KB (an inherited subclass), it's only required for the user to provide the - `pseudo_label_list` and override the `logic_forward` function (specifying how to - perform logical reasoning). After that, other operations (e.g. how to perform abductive - reasoning) will be automatically set up. + Users should inherit from this base class to build their own knowledge base. For the + user-build KB (an inherited subclass), it's only required for the user to provide the + `pseudo_label_list` and override the `logic_forward` function (specifying how to + perform logical reasoning). After that, other operations (e.g. how to perform abductive + reasoning) will be automatically set up. """ - def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): + + def __init__( + self, + pseudo_label_list, + max_err=1e-10, + use_cache=True, + key_func=to_hashable, + max_cache_size=4096, + ): if not isinstance(pseudo_label_list, list): raise TypeError("pseudo_label_list should be list") self.pseudo_label_list = pseudo_label_list self.max_err = max_err - self.use_cache = use_cache + + self.use_cache = use_cache + self.key_func = key_func + self.max_cache_size = max_cache_size @abstractmethod def logic_forward(self, pseudo_label): """ - How to perform (deductive) logical reasoning, i.e. matching each pseudo label to + How to perform (deductive) logical reasoning, i.e. matching each pseudo label to their logical result. Users are required to provide this. 
- + Parameters ---------- pred_pseudo_label : List[Any] @@ -70,23 +82,22 @@ class KBBase(ABC): max_revision_num : int The upper limit on the number of revisions. require_more_revision : int, optional - Specifies additional number of revisions permitted beyond the minimum required. + Specifies additional number of revisions permitted beyond the minimum required. Defaults to 0. Returns ------- List[List[Any]] - A list of candidates, i.e. revised pseudo labels that are consistent with the + A list of candidates, i.e. revised pseudo labels that are consistent with the knowledge base. """ - if self.use_cache: - return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), - to_hashable(y), - max_revision_num, require_more_revision) - else: - return self._abduce_by_search(pred_pseudo_label, y, - max_revision_num, require_more_revision) - + # if self.use_cache: + # return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), + # to_hashable(y), + # max_revision_num, require_more_revision) + # else: + return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) + def _check_equal(self, logic_result, y): """ Check whether the logical result of a candidate is equal to the ground truth @@ -94,12 +105,12 @@ class KBBase(ABC): """ if logic_result == None: return False - + if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): return abs(logic_result - y) <= self.max_err else: return logic_result == y - + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """ Revise the predicted pseudo label at specified index positions. @@ -125,7 +136,7 @@ class KBBase(ABC): def _revision(self, revision_num, pred_pseudo_label, y): """ - For a specified number of pseudo label to revise, iterate through all possible + For a specified number of pseudo label to revise, iterate through all possible indices to find any candidates that are consistent with the knowledge base. 
""" new_candidates = [] @@ -136,12 +147,13 @@ class KBBase(ABC): new_candidates.extend(candidates) return new_candidates - def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + @abl_cache() + def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): """ - Perform abductive reasoning by exhastive search. Specifically, begin with 0 and - continuously increase the number of pseudo labels to revise, until candidates + Perform abductive reasoning by exhastive search. Specifically, begin with 0 and + continuously increase the number of pseudo labels to revise, until candidates that are consistent with the knowledge base are found. - + Parameters ---------- pred_pseudo_label : List[Any] @@ -151,16 +163,16 @@ class KBBase(ABC): max_revision_num : int The upper limit on the number of revisions. require_more_revision : int - If larger than 0, then after having found any candidates consistent with the - knowledge base, continue to increase the number pseudo labels to revise to + If larger than 0, then after having found any candidates consistent with the + knowledge base, continue to increase the number pseudo labels to revise to get more possible consistent candidates. Returns ------- List[List[Any]] - A list of candidates, i.e. revised pseudo label that are consistent with the + A list of candidates, i.e. revised pseudo label that are consistent with the knowledge base. 
- """ + """ candidates = [] for revision_num in range(len(pred_pseudo_label) + 1): if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y): @@ -173,20 +185,22 @@ class KBBase(ABC): if revision_num >= max_revision_num: return [] - for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): + for revision_num in range( + min_revision_num + 1, min_revision_num + require_more_revision + 1 + ): if revision_num > max_revision_num: return candidates candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) return candidates - - @lru_cache(maxsize=4096) - def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): - """ - `_abduce_by_search` with cache. - """ - pred_pseudo_label = restore_from_hashable(pred_pseudo_label) - y = restore_from_hashable(y) - return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) + + # @abl_cache(max_size=4096) + # def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + # """ + # `_abduce_by_search` with cache. + # """ + # pred_pseudo_label = restore_from_hashable(pred_pseudo_label) + # y = restore_from_hashable(y) + # return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) def __repr__(self): return ( @@ -195,13 +209,13 @@ class KBBase(ABC): f"max_err={self.max_err!r}, " f"use_cache={self.use_cache!r}." ) - - + + class GroundKB(KBBase): """ - Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon - class initialization, storing all potential candidates along with their respective - logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. + Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon + class initialization, storing all potential candidates along with their respective + logical result. 
Ground KB can accelerate abductive reasoning in `abduce_candidates`. Parameters ---------- @@ -211,15 +225,16 @@ class GroundKB(KBBase): List of possible lengths of pseudo label. max_err : float, optional Refer to class `KBBase`. - + Notes ----- - Users can also inherit from this class to build their own knowledge base. Similar - to `KBBase`, users are only required to provide the `pseudo_label_list` and override + Users can also inherit from this class to build their own knowledge base. Similar + to `KBBase`, users are only required to provide the `pseudo_label_list` and override the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. - After that, other operations (e.g. auto-construction of GKB, and how to perform + After that, other operations (e.g. auto-construction of GKB, and how to perform abductive reasoning) will be automatically set up. """ + def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10): super().__init__(pseudo_label_list, max_err) if not isinstance(GKB_len_list, list): @@ -229,7 +244,6 @@ class GroundKB(KBBase): X, Y = self._get_GKB() for x, y in zip(X, Y): self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) - def _get_XY_list(self, args): pre_x, post_x_it = args[0], args[1] @@ -259,21 +273,21 @@ class GroundKB(KBBase): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - if Y and isinstance(Y[0], (int, float)): + if Y and isinstance(Y[0], (int, float)): X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) return X, Y - + def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): """ - Perform abductive reasoning by directly retrieving consistent candidates from - the prebuilt GKB. In this way, the time-consuming exhaustive search can be + Perform abductive reasoning by directly retrieving consistent candidates from + the prebuilt GKB. In this way, the time-consuming exhaustive search can be avoided. - This is an overridden function. 
For more information about the parameters and + This is an overridden function. For more information about the parameters and returns, refer to the function of the same name in class `KBBase`. """ if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: return [] - + all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) if len(all_candidates) == 0: return [] @@ -284,29 +298,30 @@ class GroundKB(KBBase): idxs = np.where(cost_list <= revision_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates - + def _find_candidate_GKB(self, pred_pseudo_label, y): """ - Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, - return all candidates whose logical results fall within the + Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, + return all candidates whose logical results fall within the [y - max_err, y + max_err] range. """ if isinstance(y, (int, float)): potential_candidates = self.GKB[len(pred_pseudo_label)] key_list = list(potential_candidates.keys()) - + low_key = bisect.bisect_left(key_list, y - self.max_err) high_key = bisect.bisect_right(key_list, y + self.max_err) - all_candidates = [candidate - for key in key_list[low_key:high_key] - for candidate in potential_candidates[key]] + all_candidates = [ + candidate + for key in key_list[low_key:high_key] + for candidate in potential_candidates[key] + ] return all_candidates - + else: return self.GKB[len(pred_pseudo_label)][y] - - + def __repr__(self): return ( f"{self.__class__.__name__} is a KB with " @@ -321,78 +336,80 @@ class GroundKB(KBBase): class PrologKB(KBBase): """ Knowledge base provided by a Prolog (.pl) file. - + Parameters ---------- pseudo_label_list : list Refer to class `KBBase`. - pl_file : - Prolog file containing the KB. + pl_file : + Prolog file containing the KB. max_err : float, optional Refer to class `KBBase`. 
- + Notes ----- - Users can instantiate this class to build their own knowledge base. During the + Users can instantiate this class to build their own knowledge base. During the instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`. - To use the default logic forward and abductive reasoning methods in this class, in the - Prolog (.pl) file, there needs to be a rule which is strictly formatted as + To use the default logic forward and abductive reasoning methods in this class, in the + Prolog (.pl) file, there needs to be a rule which is strictly formatted as `logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`. - For specifics, refer to the `logic_forward` and `get_query_string` functions in this + For specifics, refer to the `logic_forward` and `get_query_string` functions in this class. Users are also welcome to override related functions for more flexible support. """ + def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list) self.pl_file = pl_file self.prolog = pyswip.Prolog() - + if not os.path.exists(self.pl_file): raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") self.prolog.consult(self.pl_file) def logic_forward(self, pseudo_labels): """ - Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the - returned `Res` as the logical results. To use this default function, there must be - a Prolog `log_forward` method in the pl file to perform logical. reasoning. + Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the + returned `Res` as the logical results. To use this default function, there must be + a Prolog `log_forward` method in the pl file to perform logical. reasoning. Otherwise, users would override this function. """ - result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] - if result == 'true': + result = list(self.prolog.query("logic_forward(%s, Res)." 
% pseudo_labels))[0]["Res"] + if result == "true": return True - elif result == 'false': + elif result == "false": return False return result - + def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx): import re + revision_pred_pseudo_label = pred_pseudo_label.copy() revision_pred_pseudo_label = flatten(revision_pred_pseudo_label) - + for idx in revision_idx: - revision_pred_pseudo_label[idx] = 'P' + str(idx) - revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) - + revision_pred_pseudo_label[idx] = "P" + str(idx) + revision_pred_pseudo_label = reform_list(revision_pred_pseudo_label, pred_pseudo_label) + regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) - + def get_query_string(self, pred_pseudo_label, y, revision_idx): """ - Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set - the returned `Revise_labels` together with the kept labels as the candidates. This is + Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set + the returned `Revise_labels` together with the kept labels as the candidates. This is a default fuction for demo, users would override this function to adapt to their own - Prolog file. + Prolog file. """ query_string = "logic_forward(" query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) key_is_none_flag = y is None or (type(y) == list and y[0] is None) query_string += ",%s)." % y if not key_is_none_flag else ")." return query_string - + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """ Revise the predicted pseudo label at specified index positions by querying Prolog. - This is an overridden function. For more information about the parameters, refer to + This is an overridden function. For more information about the parameters, refer to the function of the same name in class `KBBase`. 
""" candidates = [] @@ -404,7 +421,7 @@ class PrologKB(KBBase): candidate = pred_pseudo_label.copy() for i, idx in enumerate(revision_idx): candidate[idx] = c[i] - candidate = reform_idx(candidate, save_pred_pseudo_label) + candidate = reform_list(candidate, save_pred_pseudo_label) candidates.append(candidate) return candidates @@ -414,4 +431,4 @@ class PrologKB(KBBase): f"pseudo_label_list={self.pseudo_label_list!r}, " f"defined by " f"Prolog file {self.pl_file!r}." - ) \ No newline at end of file + ) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 9cc24f0..0786d06 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,9 +1,9 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from abl.utils.utils import ( +from ..utils.utils import ( confidence_dist, flatten, - reform_idx, + reform_list, hamming_dist, ) @@ -191,7 +191,7 @@ class ReasonerBase: return max_revision def abduce( - self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 + self, data_sample, max_revision=-1, require_more_revision=0 ): """ Perform abductive reasoning on the given prediction data. @@ -219,9 +219,13 @@ class ReasonerBase: A revised pseudo label through abductive reasoning, which is consistent with the knowledge base. 
""" - symbol_num = len(flatten(pred_pseudo_label)) + symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - + + pred_pseudo_label = data_sample.pred_pseudo_label + pred_prob = data_sample.pred_prob + y = data_sample.Y + if self.use_zoopt: solution = self.zoopt_get_solution( symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num @@ -237,20 +241,18 @@ class ReasonerBase: return candidate def batch_abduce( - self, pred_probs, pred_pseudo_labels, Ys, max_revision=-1, require_more_revision=0 + self, data_samples, max_revision=-1, require_more_revision=0 ): """ Perform abductive reasoning on the given prediction data in batches. For detailed information, refer to `abduce`. """ - return [ - self.abduce( - pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision - ) - for pred_prob, pred_pseudo_label, Y in zip( - pred_probs, pred_pseudo_labels, Ys - ) + abduced_pseudo_label = [ + self.abduce(data_sample, max_revision, require_more_revision) + for data_sample in data_samples ] + data_samples.abduced_pseudo_label = abduced_pseudo_label + return abduced_pseudo_label # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args @@ -273,12 +275,11 @@ class ReasonerBase: if __name__ == "__main__": from kb import KBBase, GroundKB, PrologKB - - prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + from abl.structures import ListData - prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + ################################ + # Test for MNIST Add reasoning # + ################################ class AddKB(KBBase): def __init__(self, pseudo_label_list=list(range(10)), @@ -288,38 +289,54 @@ if __name__ == "__main__": def logic_forward(self, nums): return sum(nums) - class AddGroundKB(GroundKB): + class AddGroundKB(GroundKB, AddKB): def __init__(self, 
pseudo_label_list=list(range(10)), GKB_len_list=[2]): super().__init__(pseudo_label_list, GKB_len_list) + + def logic_forward(self, nums): + return sum(nums) + + def logic_forward(self, nums): return sum(nums) def test_add(reasoner): - res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + # favor 1 in first one + prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + # favor 7 in first one + prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + data_samples_add = ListData() + data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] + data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] + data_samples_add.Y = [8, 8, 17, 10] + + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) + print(res) + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) print(res) + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) + print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] print() - print("AddKB with GKB:") + print("AddGroundKB:") kb = AddGroundKB() reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) - print("AddKB without GKB:") + print("AddKB:") kb = AddKB() reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) - print("AddKB without GKB, no cache") + 
print("AddKB, no cache") kb = AddKB(use_cache=False) reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) @@ -337,45 +354,20 @@ if __name__ == "__main__": ) reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) test_add(reasoner) - - print("AddKB with multiple inputs at once:") - multiple_prob = [[ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ], - [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ]] - - kb = AddKB() - reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=1, - ) - print(res) - print() - + + ################################ + #### Test for HWF reasoning #### + ################################ + class HwfKB(KBBase): def __init__( self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"], max_err=1e-3, + use_cache=False, ): - super().__init__(pseudo_label_list, max_err) + super().__init__(pseudo_label_list, max_err, use_cache) def _valid_candidate(self, formula): if len(formula) % 2 == 0: @@ -395,7 +387,7 @@ if __name__ == "__main__": formula = [mapping[f] for f in formula] return eval("".join(formula)) - class HwfGroundKB(GroundKB): + class HwfGroundKB(GroundKB, HwfKB): def __init__( self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", @@ -405,6 +397,17 @@ if __name__ == "__main__": ): super().__init__(pseudo_label_list, GKB_len_list, max_err) + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: + return False + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: + return False + return 
True + def _valid_candidate(self, formula): if len(formula) % 2 == 0: return False @@ -415,6 +418,16 @@ if __name__ == "__main__": return False return True + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return None + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) + formula = [mapping[f] for f in formula] + return eval("".join(formula)) + + def logic_forward(self, formula): if not self._valid_candidate(formula): return None @@ -424,87 +437,46 @@ if __name__ == "__main__": return eval("".join(formula)) def test_hwf(reasoner): - res = reasoner.batch_abduce( - [None], - [["5", "+", "2"]], - [3], - max_revision=2, - require_more_revision=0, - ) + data_samples_hwf = ListData() + data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] + data_samples_hwf.pred_prob = [None, None, None, None] + data_samples_hwf.Y = [3, 64, 65, 3.17] + + res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) print(res) - res = reasoner.batch_abduce( - [None], - [["5", "+", "9"]], - [65], - max_revision=3, - require_more_revision=0, - ) + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) print(res) - res = reasoner.batch_abduce( - [None], - [["5", "8", "8", "8", "8"]], - [3.17], - max_revision=5, - require_more_revision=3, - ) + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) print(res) print() - def test_hwf_multiple(reasoner, max_revisions): - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[0], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[1], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, 
None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 65], - max_revision=max_revisions[2], - require_more_revision=0, - ) - print(res) - print() - print("HwfKB with GKB, max_err=0.1") + print("HwfGroundKB, max_err=0.1:") kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB without GKB, max_err=0.1") + print("HwfKB, max_err=0.1:") kb = HwfKB(max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB with GKB, max_err=1") + print("HwfGroundKB, max_err=1:") kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB without GKB, max_err=1") + print("HwfKB, max_err=1:") kb = HwfKB(max_err=1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - - print("HwfKB with multiple inputs at once:") - kb = HwfKB(max_err=0.1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf_multiple(reasoner, max_revisions=[1,3,3]) - print("max_revision is float") - test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) - + + ################################ + #### Test for HED reasoning #### + ################################ + + class HedKB(PrologKB): def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list, pl_file) @@ -540,7 +512,7 @@ if __name__ == "__main__": return candidate def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): - all_revision_flag = reform_idx(sol.get_x(), pred_res) + all_revision_flag = reform_list(sol.get_x(), pred_res) lefted_idxs = [i for i in range(len(pred_res))] candidate_size = [] while lefted_idxs: @@ -597,28 +569,24 @@ if __name__ == "__main__": inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] - print("HedKB logic forward") - print(kb.logic_forward(consist_exs)) + print("HedKB logic forward:") + print(kb.logic_forward(consist_exs), end=" ") 
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) print() - print("HedKB consist rule") - print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) + print("HedKB consist rule:") + print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) print() + data_sample_hed = ListData() + data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2] + data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] + data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] + print("HedReasoner abduce") - res = reasoner.abduce( - [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) - ) - print(res) + res = reasoner.batch_abduce(data_sample_hed) + for r in res: + print(r) print() print("HedReasoner abduce rules") diff --git a/abl/structures/__init__.py b/abl/structures/__init__.py new file mode 100644 index 0000000..52b5af3 --- /dev/null +++ b/abl/structures/__init__.py @@ -0,0 +1,2 @@ +from .base_data_element import BaseDataElement +from .list_data import ListData \ No newline at end of file diff --git a/abl/structures/base_data_element.py b/abl/structures/base_data_element.py new file mode 100644 index 0000000..03176d1 --- /dev/null +++ b/abl/structures/base_data_element.py @@ -0,0 +1,629 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Any, Iterator, Optional, Tuple, Type, Union + +import numpy as np +import torch + + +class BaseDataElement: + """A base data interface that supports Tensor-like and dict-like + operations. 
+ + A typical data elements refer to predicted results or ground truth labels + on a task, such as predicted bboxes, instance masks, semantic + segmentation masks, etc. Because groundtruth labels and predicted results + often have similar properties (for example, the predicted bboxes and the + groundtruth bboxes), MMEngine uses the same abstract data interface to + encapsulate predicted results and groundtruth labels, and it is recommended + to use different name conventions to distinguish them, such as using + ``gt_instances`` and ``pred_instances`` to distinguish between labels and + predicted results. Additionally, we distinguish data elements at instance + level, pixel level, and label level. Each of these types has its own + characteristics. Therefore, MMEngine defines the base class + ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and + ``LabelData`` inheriting from ``BaseDataElement`` to represent different + types of ground truth labels or predictions. + + Another common data element is sample data. A sample data consists of input + data (such as an image) and its annotations and predictions. In general, + an image can have multiple types of annotations and/or predictions at the + same time (for example, both pixel-level semantic segmentation annotations + and instance-level detection bboxes annotations). All labels and + predictions of a training sample are often passed between Dataset, Model, + Visualizer, and Evaluator components. In order to simplify the interface + between components, we can treat them as a large data element and + encapsulate them. Such data elements are generally called XXDataSample in + the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` + allows `BaseDataElement` as its attribute. Such a class generally + encapsulates all the data of a sample in the algorithm library, and its + attributes generally are various types of data elements. 
For example, + MMDetection is assigned by the BaseDataElement to encapsulate all the data + elements of the sample labeling and prediction of a sample in the + algorithm library. + + The attributes in ``BaseDataElement`` are divided into two parts, + the ``metainfo`` and the ``data`` respectively. + + - ``metainfo``: Usually contains the + information about the image such as filename, + image_shape, pad_shape, etc. The attributes can be accessed or + modified by dict-like or object-like operations, such as + ``.`` (for data access and modification), ``in``, ``del``, + ``pop(str)``, ``get(str)``, ``metainfo_keys()``, + ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for + set or change key-value pairs in metainfo). + + - ``data``: Annotations or model predictions are + stored. The attributes can be accessed or modified by + dict-like or object-like operations, such as + ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, + ``values()``, ``items()``. Users can also apply tensor-like + methods to all :obj:`torch.Tensor` in the ``data_fields``, + such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, + ``to_tensor()``, ``.detach()``. + + Args: + metainfo (dict, optional): A dict contains the meta information + of single image, such as ``dict(img_shape=(512, 512, 3), + scale_factor=(1, 1, 1, 1))``. Defaults to None. + kwargs (dict, optional): A dict contains annotations of single image or + model predictions. Defaults to None. + + Examples: + >>> import torch + >>> from mmengine.structures import BaseDataElement + >>> gt_instances = BaseDataElement() + >>> bboxes = torch.rand((5, 4)) + >>> scores = torch.rand((5,)) + >>> img_id = 0 + >>> img_shape = (800, 1333) + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=img_id, img_shape=img_shape), + ... bboxes=bboxes, scores=scores) + >>> gt_instances = BaseDataElement( + ... 
metainfo=dict(img_id=img_id, img_shape=(640, 640))) + + >>> # new + >>> gt_instances1 = gt_instances.new( + ... metainfo=dict(img_id=1, img_shape=(640, 640)), + ... bboxes=torch.rand((5, 4)), + ... scores=torch.rand((5,))) + >>> gt_instances2 = gt_instances1.new() + + >>> # add and process property + >>> gt_instances = BaseDataElement() + >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) + >>> assert 'img_shape' in gt_instances.metainfo_keys() + >>> assert 'img_shape' in gt_instances + >>> assert 'img_shape' not in gt_instances.keys() + >>> assert 'img_shape' in gt_instances.all_keys() + >>> print(gt_instances.img_shape) + (100, 100) + >>> gt_instances.scores = torch.rand((5,)) + >>> assert 'scores' in gt_instances.keys() + >>> assert 'scores' in gt_instances + >>> assert 'scores' in gt_instances.all_keys() + >>> assert 'scores' not in gt_instances.metainfo_keys() + >>> print(gt_instances.scores) + tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) + >>> gt_instances.bboxes = torch.rand((5, 4)) + >>> assert 'bboxes' in gt_instances.keys() + >>> assert 'bboxes' in gt_instances + >>> assert 'bboxes' in gt_instances.all_keys() + >>> assert 'bboxes' not in gt_instances.metainfo_keys() + >>> print(gt_instances.bboxes) + tensor([[0.0900, 0.0424, 0.1755, 0.4469], + [0.8648, 0.0592, 0.3484, 0.0913], + [0.5808, 0.1909, 0.6165, 0.7088], + [0.5490, 0.4209, 0.9416, 0.2374], + [0.3652, 0.1218, 0.8805, 0.7523]]) + + >>> # delete and change property + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=0, img_shape=(640, 640)), + ... 
bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) + >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) + >>> gt_instances.img_shape # (1280, 1280) + >>> gt_instances.bboxes = gt_instances.bboxes * 2 + >>> gt_instances.get('img_shape', None) # (1280, 1280) + >>> gt_instances.get('bboxes', None) # 6x4 tensor + >>> del gt_instances.img_shape + >>> del gt_instances.bboxes + >>> assert 'img_shape' not in gt_instances + >>> assert 'bboxes' not in gt_instances + >>> gt_instances.pop('img_shape', None) # None + >>> gt_instances.pop('bboxes', None) # None + + >>> # Tensor-like + >>> cuda_instances = gt_instances.cuda() + >>> cuda_instances = gt_instances.to('cuda:0') + >>> cpu_instances = cuda_instances.cpu() + >>> cpu_instances = cuda_instances.to('cpu') + >>> fp16_instances = cuda_instances.to( + ... device=None, dtype=torch.float16, non_blocking=False, + ... copy=False, memory_format=torch.preserve_format) + >>> cpu_instances = cuda_instances.detach() + >>> np_instances = cpu_instances.numpy() + + >>> # print + >>> metainfo = dict(img_shape=(800, 1196, 3)) + >>> gt_instances = BaseDataElement( + ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + >>> sample = BaseDataElement(metainfo=metainfo, + ... gt_instances=gt_instances) + >>> print(sample) + + ) at 0x7f0fea49e130> + + >>> # inheritance + >>> class DetDataSample(BaseDataElement): + ... @property + ... def proposals(self): + ... return self._proposals + ... @proposals.setter + ... def proposals(self, value): + ... self.set_field(value, '_proposals', dtype=BaseDataElement) + ... @proposals.deleter + ... def proposals(self): + ... del self._proposals + ... @property + ... def gt_instances(self): + ... return self._gt_instances + ... @gt_instances.setter + ... def gt_instances(self, value): + ... self.set_field(value, '_gt_instances', + ... dtype=BaseDataElement) + ... @gt_instances.deleter + ... def gt_instances(self): + ... del self._gt_instances + ... @property + ... 
def set_metainfo(self, metainfo: dict) -> None:
    """Register every entry of ``metainfo`` as a metainfo field.

    The mapping is deep-copied before anything is stored, so later changes
    to the caller's dict (or to objects nested inside it) do not leak into
    this element. Each entry is stored through ``set_field`` with
    ``field_type="metainfo"`` and no dtype restriction.

    Args:
        metainfo (dict): Meta information keyed by field name, e.g.
            ``img_shape`` or ``scale_factor``.

    Raises:
        AssertionError: If ``metainfo`` is not a ``dict``.
    """
    assert isinstance(
        metainfo, dict
    ), f"metainfo should be a ``dict`` but got {type(metainfo)}"
    snapshot = copy.deepcopy(metainfo)
    for key in snapshot:
        self.set_field(
            name=key, value=snapshot[key], field_type="metainfo", dtype=None
        )
def clone(self):
    """Create a new element of the same class carrying this element's fields.

    Metainfo is handed to ``set_metainfo`` and data to ``set_data`` of the
    freshly constructed object. This method itself makes no copy of the
    individual values: it only builds new dicts around them, so any copying
    that happens is done by those two setters.

    Returns:
        BaseDataElement: The newly created element.
    """
    twin = self.__class__()
    twin.set_metainfo({key: value for key, value in self.metainfo_items()})
    twin.set_data({key: value for key, value in self.items()})
    return twin
def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
    """Lazily iterate over the ``metainfo`` fields.

    Keys come from ``metainfo_keys()``; each value is looked up with
    ``getattr`` only when the corresponding pair is requested.

    Returns:
        iterator: An iterator whose elements are ``(key, value)`` tuples
        for ``metainfo``.
    """
    yield from ((name, getattr(self, name)) for name in self.metainfo_keys())
def pop(self, *args) -> Any:
    """Remove a metainfo or data field and return its value, like ``dict.pop``.

    Args:
        *args: ``(key,)`` or ``(key, default)``.

    Returns:
        Any: The stored value when ``key`` is a registered metainfo or data
        field; otherwise ``default`` if one was supplied.

    Raises:
        AssertionError: If more than two positional arguments are given.
        KeyError: If ``key`` is not a registered field and no default was
            supplied.
    """
    assert len(args) < 3, "``pop`` get more than 2 arguments"
    key = args[0]
    # Metainfo is consulted first, then data, matching the lookup order of
    # the two field registries.
    for registry_name in ("_metainfo_fields", "_data_fields"):
        registry = getattr(self, registry_name)
        if key in registry:
            registry.remove(key)
            # Forward ``*args`` untouched so an optional default still
            # reaches ``dict.pop``.
            return self.__dict__.pop(*args)

    if len(args) == 2:
        return args[1]
    # Only registered fields may be popped; other instance attributes are
    # deliberately not reachable through ``self.__dict__.pop``.
    raise KeyError(f"{key} is not contained in metainfo or data")
def numpy(self) -> "BaseDataElement":
    """Return a new element whose tensor-like data is converted to NumPy.

    The result starts as ``self.new()``. Every data value that is a
    ``torch.Tensor`` or a ``BaseDataElement`` is then replaced by
    ``value.detach().cpu().numpy()``; all other values are written back
    unchanged.

    Returns:
        BaseDataElement: The converted element; ``self`` is not modified by
        this method.
    """
    converted = self.new()
    for key, value in self.items():
        needs_conversion = isinstance(value, (torch.Tensor, BaseDataElement))
        converted.set_data(
            {key: value.detach().cpu().numpy() if needs_conversion else value}
        )
    return converted
def to_dict(self) -> dict:
    """Convert this element to a plain ``dict``.

    All metainfo and data entries (as yielded by ``all_items()``) are
    included. Values that are themselves ``BaseDataElement`` instances are
    converted recursively through their own ``to_dict()``; every other value
    is stored as-is.

    Returns:
        dict: Mapping from field name to value.
    """
    result = {}
    for key, value in self.all_items():
        if isinstance(value, BaseDataElement):
            value = value.to_dict()
        result[key] = value
    return result
class ListData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in ``data_fields`` are
    expected to have the same length; note that this is currently NOT
    enforced when a field is assigned (see ``__setattr__``). The design
    refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501

    ``ListData`` also supports extra functions: ``index``, ``slice`` and
    ``cat`` for data fields. The value of a data field can be a base data
    structure such as ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str``
    or ``tuple``, or a customized data structure that has ``__len__``,
    ``__getitem__`` and ``cat`` attributes.

    Examples:
        >>> # custom data structure
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if isinstance(item, int):
        ...             if item >= len(self) or item < -len(self):  # type:ignore
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> from abl.structures import ListData
        >>> import numpy as np
        >>> import torch
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = ListData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)  # output omitted
        >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])  # output omitted
        >>> print(instance_data[instance_data.det_scores > 1])  # output omitted
        >>> print(instance_data.cat([instance_data, instance_data]))  # output omitted
    """

    def __setattr__(self, name: str, value: list):
        """Set a data field; ``setattr`` is only used to set data.

        The two bookkeeping attributes ``_metainfo_fields`` and
        ``_data_fields`` can be assigned exactly once (when they do not exist
        yet); any later assignment raises ``AttributeError``. Every other
        name is forwarded to the parent class's ``__setattr__``.

        Values are intended to be sized and to match the length of this
        ``ListData``, but no such validation is performed at the moment.

        Args:
            name (str): Attribute name.
            value (list): Value to store.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute already exists.
        """
        if name in ("_metainfo_fields", "_data_fields"):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"{name} has been used as a " "private attribute, which is immutable."
                )

        else:
            # NOTE: validation of ``value`` (must be a ``list`` whose length
            # equals ``len(self)`` once the container is non-empty) is
            # currently disabled, so fields of differing lengths are accepted.
            super().__setattr__(name, value)

    # ``obj[key] = value`` behaves exactly like ``obj.key = value``.
    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> "ListData":
        """Look up a field by name, or index every data field at once.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item.

        Returns:
            :obj:`ListData`: Corresponding values. For a ``str`` item the
            attribute of that name is returned directly instead of a
            ``ListData``.

        Raises:
            AssertionError: If ``item`` is not one of the supported index
                types, or a tensor index is not 1-dimensional.
            ValueError: If a tensor index is applied to a field whose type
                supports neither built-in indexing nor ``cat``.
        """
        assert isinstance(item, IndexType.__args__)
        # Lists and arrays are normalised to a torch tensor so that a single
        # code path below handles all "fancy" indices.
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)

        # A string is a field name, not an index.
        if isinstance(item, str):
            return getattr(self, item)

        # The result keeps this element's metainfo; data fields are filled in
        # below, one by one.
        new_data = self.__class__(metainfo=self.metainfo)

        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, "Only support to get the" " values along the first dimension."

            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                elif isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(v, (str, list, tuple)) or (
                    hasattr(v, "__getitem__") and hasattr(v, "cat")
                ):
                    # convert to indexes from BoolTensor
                    if isinstance(item, BoolTypeTensor.__args__):
                        indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    # Each selected position becomes a slice with step
                    # ``len(v)``, which yields a one-element container of the
                    # same type as ``v`` rather than a bare element.
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        # Nothing selected: an empty slice keeps the type.
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    # Re-assemble the pieces: ``+`` for built-in sequences,
                    # the value's own ``cat`` for custom structures.
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f"The type of `{k}` is `{type(v)}`, which has no "
                        "attribute of `cat`, so it does not "
                        "support slice with `bool`"
                    )

        else:
            # item is a slice or int
            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                else:
                    new_data[k] = v[item]
        return new_data  # type:ignore

    @staticmethod
    def cat(instances_list: List["ListData"]) -> "ListData":
        """Concat the instances of all :obj:`ListData` in the list.

        Note: To ensure that cat returns as expected, make sure that
        all elements in the list must have exactly the same keys.

        A list with a single element is returned as that element itself, not
        as a copy. Otherwise the result is built with the class and metainfo
        of the first element.

        Args:
            instances_list (list[:obj:`ListData`]): A list
                of :obj:`ListData`.

        Returns:
            :obj:`ListData`

        Raises:
            AssertionError: If the list is empty, contains a non-``ListData``
                element, or the elements do not share the same keys.
            ValueError: If a field's type is not a tensor, array, ``str``,
                ``list`` or ``tuple`` and has no ``cat`` attribute.
        """
        assert all(isinstance(results, ListData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # metainfo and data_fields must be exactly the
        # same for each element to avoid exceptions.
        field_keys_list = [instances.all_keys() for instances in instances_list]
        # Same number of keys everywhere AND the union of all keys is no
        # larger than the first element's key list.
        assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len(
            set(itertools.chain(*field_keys_list))
        ) == len(field_keys_list[0]), (
            "There are different keys in "
            "`instances_list`, which may "
            "cause the cat operation "
            "to fail. Please make sure all "
            "elements in `instances_list` "
            "have the exact same key."
        )

        new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            # The first element's type decides how the field is joined.
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Slice-copy first so that ``+=`` below does not extend the
                # first element's own list in place.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, "cat"):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`"
                )
            new_data[k] = new_values
        return new_data  # type:ignore

    def flatten(self, item: IndexType) -> List:
        """Flatten ``self[item]``.

        ``self[item]`` is passed to the ``flatten`` helper imported from
        ``..utils`` (aliased as ``flatten_list``).

        Args:
            item (IndexType): Field name or index, as accepted by
                ``__getitem__``.

        Returns:
            list: Flattened data fields.
        """
        return flatten_list(self[item])

    def elements_num(self, item: IndexType) -> int:
        """int: The number of elements in the flattened ``self[item]``."""
        return len(self.flatten(item))

    def to_tuple(self, item: IndexType) -> tuple:
        """tuple: ``self[item]`` converted with the ``to_hashable`` helper."""
        return to_hashable(self[item])

    def __len__(self) -> int:
        """int: The length of ListData.

        This is the length of a single data field, taken from iterating the
        ``_data_fields`` set (so which field is used is not fixed), or ``0``
        when there are no data fields.
        """
        if len(self._data_fields) > 0:
            one_element = next(iter(self._data_fields))
            return len(getattr(self, one_element))
        else:
            return 0
def clear_cache(self):
    """Invalidate the entire cache and reset its bookkeeping.

    Besides emptying ``cache_dict`` this re-creates the empty circular
    linked list behind ``root``, clears the ``full`` flag and resets the
    ``hits``/``misses`` counters, i.e. it restores the state that
    ``_init_cache`` sets up.

    Clearing only the dict is not enough: ``full`` would stay set and the
    old links would stay in the list, so the next miss on a full cache would
    try to evict a key that is no longer in ``cache_dict`` and raise
    ``KeyError``.

    Calling this before the cache has been initialised is a no-op.
    """
    if not self.has_init:
        # ``_init_cache`` has not created any cache state yet.
        return

    self.cache_dict.clear()
    self.hits, self.misses = 0, 0
    self.full = False
    # Fresh sentinel node laid out as [PREV, NEXT, KEY, RESULT], pointing at
    # itself in both directions (an empty circular doubly linked list).
    self.root = []
    self.root[:] = [self.root, self.root, None, None]
def abl_cache():
    """Return a decorator that adds optional result caching to a method.

    The decorated callable is invoked as ``wrapper(obj, *args)``. One
    ``Cache`` instance is created per decorated function, at decoration
    time, and is shared by every ``obj`` the function is called on.

    On each call:

    * if ``obj.use_cache`` is truthy, the cache is initialised from ``obj``
      (``_init_cache`` is a no-op after its first successful call) and the
      result is served through ``Cache.get_from_dict``;
    * otherwise the original function is called directly.

    The wrapper accepts positional arguments only, as they are forwarded to
    the cache to build its key.

    Returns:
        Callable: A decorator taking the function to wrap.
    """
    # Local import keeps this block self-contained.
    import functools

    def decorator(func):
        cache_instance = Cache(func)

        # Preserve ``__name__``, ``__doc__`` etc. of the wrapped function so
        # that introspection and documentation tools still see the original.
        @functools.wraps(func)
        def wrapper(obj, *args):
            if obj.use_cache:
                cache_instance._init_cache(obj)
                return cache_instance.get_from_dict(obj, *args)
            else:
                return func(obj, *args)

        return wrapper

    return decorator
def calculate_revision_num(parameter, total_length):
    """
    Resolve a revision-count parameter to a concrete number.

    Parameters
    ----------
    parameter : int or float
        ``-1`` (compared with ``==``, so ``-1.0`` also matches) selects
        ``total_length``. Any other float is treated as a fraction and must
        lie in [0, 1]. Any other int is used as-is and must be non-negative.
    total_length : int
        The total length the fraction refers to.

    Returns
    -------
    int
        ``total_length`` for ``-1``, ``round(total_length * parameter)`` for
        a float, or ``parameter`` itself for an int.

    Raises
    ------
    TypeError
        If parameter is not an int or a float.
    ValueError
        If parameter is a float not in [0, 1] or an int below 0.
    """
    if not isinstance(parameter, (int, float)):
        raise TypeError("Parameter must be of type int or float.")

    # The sentinel is checked before the range validation below, which would
    # otherwise reject it as negative.
    if parameter == -1:
        return total_length

    if isinstance(parameter, float):
        in_unit_interval = 0 <= parameter <= 1
        if not in_unit_interval:
            raise ValueError("If parameter is a float, it must be between 0 and 1.")
        return round(total_length * parameter)

    if parameter < 0:
        raise ValueError("If parameter is an int, it must be non-negative.")
    return parameter
self.reasoner.set_remapping() pred_pseudo_label = self.label_to_pseudo_label(pred_label) abduced_pseudo_label = self.abduce_pseudo_label( pred_label, pred_prob, pred_pseudo_label, Y, 20 @@ -100,8 +100,8 @@ class HEDBridge(SimpleBridge): max_revisible_instances = max(mapping_score) return_idx = mapping_score.index(max_revisible_instances) - self.abducer.mapping = candidate_mappings[return_idx] - self.abducer.set_remapping() + self.reasoner.mapping = candidate_mappings[return_idx] + self.reasoner.set_remapping() return abduced_pseudo_label_list[return_idx] def check_training_impact(self, filtered_X, filtered_abduced_label, X): @@ -137,7 +137,7 @@ class HEDBridge(SimpleBridge): pred_pseudo_label = self.label_to_pseudo_label(pred_label) consistent_num = sum( [ - self.abducer.kb.consist_rule(instance, rule) + self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label ] ) @@ -159,11 +159,11 @@ class HEDBridge(SimpleBridge): pred_pseudo_label = self.label_to_pseudo_label(pred_label) consistent_instance = [] for instance in pred_pseudo_label: - if self.abducer.kb.logic_forward([instance]): + if self.reasoner.kb.logic_forward([instance]): consistent_instance.append(instance) if len(consistent_instance) != 0: - rule = self.abducer.abduce_rules(consistent_instance) + rule = self.reasoner.abduce_rules(consistent_instance) if rule != None: rules.append(rule) break @@ -280,7 +280,7 @@ class HEDBridge(SimpleBridge): else: if equation_len == min_len: print_log( - "Learned mapping is: " + str(self.abducer.mapping), + "Learned mapping is: " + str(self.reasoner.mapping), logger="current", ) self.model.load(load_path="./weights/pretrain_weights.pth") diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index 932a25e..482fbd1 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -2,10 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], 
"source": [ + "import sys\n", + "\n", + "sys.setrecursionlimit(10000)\n", + "\n", "import torch\n", "import numpy as np\n", "import torch.nn as nn\n", @@ -23,9 +27,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/16 20:43:38 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the HWF example.\n" + ] + } + ], "source": [ "# Initialize logger and print basic information\n", "print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", @@ -45,21 +57,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "# Initialize knowledge base and abducer\n", + "# Initialize knowledge base and reasoner\n", "class HWF_KB(KBBase):\n", - " def __init__(\n", - " self, \n", - " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", - " prebuild_GKB=False,\n", - " GKB_len_list=[1, 3, 5, 7],\n", - " max_err=1e-3,\n", - " use_cache=True\n", - " ):\n", - " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", "\n", " def _valid_candidate(self, formula):\n", " if len(formula) % 2 == 0:\n", @@ -79,8 +82,8 @@ " formula = [mapping[f] for f in formula]\n", " return eval(''.join(formula))\n", "\n", - "kb = HWF_KB(prebuild_GKB=True)\n", - "abducer = ReasonerBase(kb, dist_func='confidence')" + "kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n", + "reasoner = ReasonerBase(kb, dist_func='confidence')" ] }, { @@ -93,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -106,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ 
-126,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -146,12 +149,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Add metric\n", - "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" + "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]" ] }, { @@ -164,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -183,11 +186,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)" + "bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)" ] }, { @@ -200,11 +203,123 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [1/10] model loss is 0.16911\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [2/10] model loss is 0.17734\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [3/10] model loss is 0.01907\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [4/10] model loss is 0.01403\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [5/10] model loss is 0.00509\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [6/10] model loss is 0.00713\n", + "11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [7/10] model loss is 0.00455\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [8/10] model loss is 0.00946\n", + "11/16 
20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [9/10] model loss is 0.00957\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [10/10] model loss is 0.00323\n", + "11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.997 hwf/semantics_accuracy: 0.985 \n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_1.pth\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [1/10] model loss is 0.00666\n", + "11/16 20:44:10 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [2/10] model loss is 0.01438\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [3/10] model loss is 0.00450\n", + "11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [4/10] model loss is 0.00764\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [5/10] model loss is 0.00644\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to 
results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [6/10] model loss is 0.00189\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [7/10] model loss is 0.00397\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [8/10] model loss is 0.00936\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [9/10] model loss is 0.00960\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [10/10] model loss is 0.00572\n", + "11/16 20:44:15 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.999 hwf/semantics_accuracy: 0.995 \n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_2.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [1/10] model loss is 0.00180\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [2/10] model loss is 0.00615\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [3/10] model loss is 0.01000\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n", + "11/16 20:44:17 - abl - 
\u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [4/10] model loss is 0.00415\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [5/10] model loss is 0.00960\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [6/10] model loss is 0.00697\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [7/10] model loss is 0.00977\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [8/10] model loss is 0.00734\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [9/10] model loss is 0.00922\n", + "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n", + "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [10/10] model loss is 0.00982\n", + "11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, 
hwf/character_accuracy: 0.998 hwf/semantics_accuracy: 0.986 \n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_3.pth\n", + "11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.994 hwf/semantics_accuracy: 0.970 \n" + ] + } + ], "source": [ - "bridge.train(train_data, epochs=3, batch_size=1000)\n", + "bridge.train(train_data, loops=3, segment_size=1000, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] } diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 146bd88..845424b 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -2,10 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "import os.path as osp\n", + "\n", "import torch.nn as nn\n", "import torch\n", "\n", @@ -13,21 +15,25 @@ "\n", "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", - "from abl.evaluation import SymbolMetric, ABLMetric\n", - "from abl.utils import ABLLogger\n", + "from abl.evaluation import SymbolMetric, SemanticsMetric\n", + "from abl.utils import ABLLogger, print_log\n", "\n", - "from models.nn import LeNet5\n", + "from examples.models.nn import LeNet5\n", "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize logger\n", - "logger = ABLLogger.get_instance(\"abl\")" + "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", + "\n", + "# Retrieve the directory of the Log file and define the directory for saving the model 
weights.\n", + "log_dir = ABLLogger.get_current_instance().log_dir\n", + "weights_dir = osp.join(log_dir, \"weights\")" ] }, { @@ -40,22 +46,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# Initialize knowledge base and abducer\n", + "# Initialize knowledge base and reasoner\n", "class add_KB(KBBase):\n", - " def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", - " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", - "\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = add_KB(prebuild_GKB=True)\n", + "kb = add_KB(pseudo_label_list=list(range(10)))\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", - "abducer = ReasonerBase(kb, dist_func=\"confidence\")" + "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" ] }, { @@ -68,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -81,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -92,8 +95,6 @@ " criterion,\n", " optimizer,\n", " device,\n", - " save_interval=1,\n", - " save_dir=logger.save_dir,\n", " batch_size=32,\n", " num_epochs=1,\n", ")" @@ -109,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -129,12 +130,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Add metric\n", - "metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" + "metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]" ] }, { @@ -147,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 10, 
+ "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -170,7 +171,7 @@ "metadata": {}, "outputs": [], "source": [ - "bridge = SimpleBridge(model, abducer, metric)" + "bridge = SimpleBridge(model, reasoner, metric)" ] }, { @@ -187,7 +188,7 @@ "metadata": {}, "outputs": [], "source": [ - "bridge.train(train_data, epochs=5, batch_size=10000)\n", + "bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] } @@ -208,7 +209,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" }, "orig_nbformat": 4, "vscode": {