| @@ -21,4 +21,4 @@ jobs: | |||
| uses: py-actions/flake8@v2 | |||
| with: | |||
| max-line-length: "100" | |||
| args: --ignore=E203,W503 | |||
| args: --ignore=E203,W503,F821,E266 | |||
| @@ -164,7 +164,8 @@ class SimpleBridge(BaseBridge): | |||
| self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData] | |||
| ) -> ListData: | |||
| """ | |||
| Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model. | |||
| Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data | |||
| examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model. | |||
| Parameters | |||
| ---------- | |||
| @@ -212,18 +213,19 @@ class SimpleBridge(BaseBridge): | |||
| Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` | |||
| object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. | |||
| - ``X`` is a list of sublists representing the input data. | |||
| - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but not | |||
| to train. ``gt_pseudo_label`` can be ``None``. | |||
| - ``Y`` is a list representing the ground truth reasoning result for each sublist in ``X``. | |||
| - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but | |||
| not to train. ``gt_pseudo_label`` can be ``None``. | |||
| - ``Y`` is a list representing the ground truth reasoning result for each sublist | |||
| in ``X``. | |||
| label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional | |||
| Labeled data should be in the same format as ``train_data``. The only difference is | |||
| that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be | |||
| utilized to train the model. Defaults to None. | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 | |||
| Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label`` | |||
| and ``Y`` can be either None or not, which depends on the evaluation metrics in | |||
| ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate the | |||
| model during training time. Defaults to None. | |||
| ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate | |||
| the model during training time. Defaults to None. | |||
| loops : int | |||
| The Machine Learning part and the Reasoning part will be iteratively optimized | |||
| ``loops`` times, by default 50. | |||
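To make the expected input concrete: a minimal hedged sketch of assembling ``train_data`` in the ``(X, gt_pseudo_label, Y)`` form described above and handing it to the bridge. The ``SimpleBridge`` construction is assumed rather than taken from this diff.

```python
# Hedged sketch: two training examples, each a sublist of raw inputs.
X = [[1, 5], [3, 4]]    # stand-ins for raw inputs (e.g. images in a real task)
gt_pseudo_label = None  # only used for evaluation; may be None during training
Y = [6, 7]              # one ground-truth reasoning result per sublist of X
train_data = (X, gt_pseudo_label, Y)

# bridge = SimpleBridge(model, reasoner, metric_list)  # assumed setup
# bridge.train(train_data, loops=50)
```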
| @@ -325,7 +327,7 @@ class SimpleBridge(BaseBridge): | |||
| Parameters | |||
| ---------- | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 | |||
| Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object | |||
| with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be | |||
| either None or not, which depends on the evaluation metrics in ``self.metric_list``. | |||
| @@ -344,10 +346,10 @@ class SimpleBridge(BaseBridge): | |||
| Parameters | |||
| ---------- | |||
| test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] | |||
| Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object with ``X``, | |||
| ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be either None or | |||
| not, which depends on the evaluation metircs in ``self.metric_list``. | |||
| test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 | |||
| Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object | |||
| with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` | |||
| can be either None or not, which depends on the evaluation metrics in ``self.metric_list``. | |||
| """ | |||
| print_log("Test start:", logger="current") | |||
| test_data_examples = self.data_preprocess("test", test_data) | |||
| @@ -4,21 +4,22 @@ from abl.utils import tab_data_to_tuple | |||
| from .structures.list_data import ListData | |||
| from lambdaLearn.Base.TabularMixin import TabularMixin | |||
| class DataConverter: | |||
| ''' | |||
| """ | |||
| This class provides functionality to convert LambdaLearn data to ABL-Package data. | |||
| ''' | |||
| """ | |||
| def __init__(self) -> None: | |||
| pass | |||
| def convert_lambdalearn_to_tuple( | |||
| self, | |||
| dataset: TabularMixin, | |||
| reasoning_result: Any | |||
| self, dataset: TabularMixin, reasoning_result: Any | |||
| ) -> Tuple[Tuple, Tuple, Tuple, Tuple]: | |||
| ''' | |||
| Convert a lambdalearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), each containing (data, label, reasoning_result). | |||
| """ | |||
| Convert a lambdalearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), # noqa: E501 | |||
| each containing (data, label, reasoning_result). | |||
| Parameters | |||
| ---------- | |||
| dataset : TabularMixin | |||
| @@ -28,27 +29,38 @@ class DataConverter: | |||
| Returns | |||
| ------- | |||
| Tuple[Tuple, Tuple, Tuple, Tuple] | |||
| A tuple of (label_data, train_data, valid_data, test_data), where each element is a tuple of (data, label, reasoning_result). | |||
| ''' | |||
| A tuple of (label_data, train_data, valid_data, test_data), where each element is | |||
| a tuple of (data, label, reasoning_result). | |||
| """ | |||
| if not isinstance(dataset, TabularMixin): | |||
| raise NotImplementedError("Only support converting the datasets that are instances of TabularMixin. Please refer to the documentation and manually convert the dataset into a tuple. ") | |||
| label_data = tab_data_to_tuple(dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result) | |||
| train_data = tab_data_to_tuple(dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result) | |||
| valid_data = tab_data_to_tuple(dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result) | |||
| test_data = tab_data_to_tuple(dataset.test_X, dataset.test_y, reasoning_result=reasoning_result) | |||
| raise NotImplementedError( | |||
| "Only support converting the datasets that are instances of TabularMixin. " | |||
| + "Please refer to the documentation and manually convert the dataset into a tuple." | |||
| ) | |||
| label_data = tab_data_to_tuple( | |||
| dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result | |||
| ) | |||
| train_data = tab_data_to_tuple( | |||
| dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result | |||
| ) | |||
| valid_data = tab_data_to_tuple( | |||
| dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result | |||
| ) | |||
| test_data = tab_data_to_tuple( | |||
| dataset.test_X, dataset.test_y, reasoning_result=reasoning_result | |||
| ) | |||
| return label_data, train_data, valid_data, test_data | |||
| def convert_lambdalearn_to_listdata( | |||
| self, | |||
| dataset: TabularMixin, | |||
| reasoning_result: Any | |||
| self, dataset: TabularMixin, reasoning_result: Any | |||
| ) -> Tuple[ListData, ListData, ListData, ListData]: | |||
| ''' | |||
| Convert a lambdalearn dataset to a tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples). | |||
| """ | |||
| Convert a lambdalearn dataset to a tuple of ListData | |||
| (label_data_examples, train_data_examples, valid_data_examples, test_data_examples). | |||
| Parameters | |||
| ---------- | |||
| dataset : TabularMixin | |||
| @@ -58,14 +70,20 @@ class DataConverter: | |||
| Returns | |||
| ------- | |||
| Tuple[ListData, ListData, ListData, ListData] | |||
| A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples) | |||
| ''' | |||
| A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples) # noqa: E501 | |||
| """ | |||
| if not isinstance(dataset, TabularMixin): | |||
| raise NotImplementedError("Only support converting the datasets that are instances of TabularMixin. Please refer to the documentation and manually convert the dataset into a ListData. ") | |||
| label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple(dataset, reasoning_result) | |||
| raise NotImplementedError( | |||
| "Only support converting the datasets that are instances of TabularMixin. " | |||
| + "Please refer to the documentation and manually convert the dataset " | |||
| + "into a ListData." | |||
| ) | |||
| label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple( | |||
| dataset, reasoning_result | |||
| ) | |||
| if label_data is not None: | |||
| X, gt_pseudo_label, Y = label_data | |||
| label_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y) | |||
| @@ -78,23 +96,46 @@ class DataConverter: | |||
| if test_data is not None: | |||
| X, gt_pseudo_label, Y = test_data | |||
| test_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y) | |||
| return label_data_examples, train_data_examples, valid_data_examples, test_data_examples | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| from lambdaLearn.Dataset.Tabular.BreastCancer import BreastCancer | |||
| breast_dataset=BreastCancer(labeled_size=0.1, stratified=True, shuffle=True) | |||
| breast_dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True) | |||
| dataconverter = DataConverter() | |||
| label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple(breast_dataset, 0) | |||
| print(type(label_data).__name__, type(train_data).__name__, type(valid_data).__name__, type(test_data).__name__) | |||
| label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple( | |||
| breast_dataset, 0 | |||
| ) | |||
| print( | |||
| type(label_data).__name__, | |||
| type(train_data).__name__, | |||
| type(valid_data).__name__, | |||
| type(test_data).__name__, | |||
| ) | |||
| print(len(label_data)) | |||
| print(len(label_data[0]), len(label_data[1]), len(label_data[2])) | |||
| print(label_data[0][0], label_data[1][0], label_data[2][0]) | |||
| print() | |||
| label_data_examples, train_data_examples, valid_data_examples, test_data_examples = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0) | |||
| print(type(label_data_examples).__name__, type(train_data_examples).__name__, type(valid_data_examples).__name__, type(test_data_examples).__name__) | |||
| print(len(label_data_examples.X), len(label_data_examples.gt_pseudo_label), len(label_data_examples.Y)) | |||
| ( | |||
| label_data_examples, | |||
| train_data_examples, | |||
| valid_data_examples, | |||
| test_data_examples, | |||
| ) = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0) | |||
| print( | |||
| type(label_data_examples).__name__, | |||
| type(train_data_examples).__name__, | |||
| type(valid_data_examples).__name__, | |||
| type(test_data_examples).__name__, | |||
| ) | |||
| print( | |||
| len(label_data_examples.X), | |||
| len(label_data_examples.gt_pseudo_label), | |||
| len(label_data_examples.Y), | |||
| ) | |||
| label_data_example = label_data_examples[0] | |||
| print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y) | |||
| print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y) | |||
| @@ -38,7 +38,8 @@ class ReasoningMetric(BaseMetric): | |||
| """ | |||
| Process a batch of data examples. | |||
| This method takes in a batch of data examples, each containing predicted pseudo-labels(pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It | |||
| This method takes in a batch of data examples, each containing predicted pseudo-labels | |||
| (pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It | |||
| evaluates the reasoning accuracy of each example by comparing the logical reasoning | |||
| result (derived using the knowledge base) of the predicted pseudo-labels against Y. | |||
| The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended | |||
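A rough sketch of the per-example check this paragraph describes; the helper name and attribute access are illustrative, not the actual method body.

```python
# Illustrative only: compare the KB's reasoning over predictions against Y.
def process_sketch(kb, data_examples, results):
    for pred, y in zip(data_examples.pred_pseudo_label, data_examples.Y):
        # 1 for a correct reasoning result, 0 for an incorrect one
        results.append(int(kb.logic_forward(pred) == y))
```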
| @@ -53,7 +53,7 @@ class ListData(BaseDataElement): | |||
| ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``. | |||
| This design is inspired by and extends the functionalities of the ``BaseDataElement`` | |||
| class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. | |||
| class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 | |||
| Examples: | |||
| >>> from abl.data.structures import ListData | |||
| @@ -71,7 +71,7 @@ class ListData(BaseDataElement): | |||
| DATA FIELDS | |||
| Y: [1, 2, 3] | |||
| gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] | |||
| X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] | |||
| X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 | |||
| ) at 0x7f3bbf1991c0> | |||
| >>> print(data_examples[:1]) | |||
| <ListData( | |||
| @@ -81,7 +81,8 @@ class BasicNN: | |||
| if not isinstance(device, torch.device): | |||
| if not isinstance(device, str): | |||
| raise TypeError( | |||
| "device must be an instance of torch.device or a str indicates the target device" | |||
| "device must be an instance of torch.device or a str indicating " | |||
| + "the target device" | |||
| ) | |||
| else: | |||
| device = torch.device(device) | |||
| @@ -163,9 +164,9 @@ class BasicNN: | |||
| return self | |||
| def fit( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| y: Optional[List[int]] = None, | |||
| ) -> BasicNN: | |||
| """ | |||
| @@ -271,8 +272,8 @@ class BasicNN: | |||
| return torch.cat(results, axis=0) | |||
| def predict( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| ) -> numpy.ndarray: | |||
| """ | |||
| @@ -312,8 +313,8 @@ class BasicNN: | |||
| return self._predict(data_loader).argmax(axis=1).cpu().numpy() | |||
| def predict_proba( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| ) -> numpy.ndarray: | |||
| """ | |||
| @@ -403,9 +404,9 @@ class BasicNN: | |||
| return mean_loss, accuracy | |||
| def score( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| y: Optional[List[int]] = None, | |||
| ) -> float: | |||
| """ | |||
| @@ -447,8 +448,8 @@ class BasicNN: | |||
| def _data_loader( | |||
| self, | |||
| X: Optional[List[Any]], | |||
| y: Optional[List[int]] = None, | |||
| X: Optional[List[Any]], | |||
| y: Optional[List[int]] = None, | |||
| shuffle: Optional[bool] = True, | |||
| ) -> DataLoader: | |||
| """ | |||
| @@ -6,16 +6,18 @@ from .abl_model import ABLModel | |||
| from .basic_nn import BasicNN | |||
| from lambdaLearn.Base.DeepModelMixin import DeepModelMixin | |||
| class ModelConverter: | |||
| ''' | |||
| """ | |||
| This class provides functionality to convert LambdaLearn models to ABL-Package models. | |||
| ''' | |||
| """ | |||
| def __init__(self) -> None: | |||
| pass | |||
| def convert_lambdalearn_to_ablmodel( | |||
| self, | |||
| lambdalearn_model, | |||
| lambdalearn_model, | |||
| loss_fn: torch.nn.Module, | |||
| optimizer_dict: dict, | |||
| scheduler_dict: Optional[dict] = None, | |||
| @@ -28,11 +30,13 @@ class ModelConverter: | |||
| save_dir: Optional[str] = None, | |||
| train_transform: Callable[..., Any] = None, | |||
| test_transform: Callable[..., Any] = None, | |||
| collate_fn: Callable[[List[Any]], Any] = None | |||
| collate_fn: Callable[[List[Any]], Any] = None, | |||
| ): | |||
| ''' | |||
| Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn model should implement fit and predict methods. | |||
| """ | |||
| Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of | |||
| DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn | |||
| model should implement ``fit`` and ``predict`` methods. | |||
| Parameters | |||
| ---------- | |||
| lambdalearn_model : Union[DeepModelMixin, Any] | |||
| @@ -75,17 +79,34 @@ class ModelConverter: | |||
| ------- | |||
| ABLModel | |||
| The converted ABLModel instance. | |||
| ''' | |||
| """ | |||
| if isinstance(lambdalearn_model, DeepModelMixin): | |||
| base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer_dict, scheduler_dict, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn) | |||
| base_model = self.convert_lambdalearn_to_basicnn( | |||
| lambdalearn_model, | |||
| loss_fn, | |||
| optimizer_dict, | |||
| scheduler_dict, | |||
| device, | |||
| batch_size, | |||
| num_epochs, | |||
| stop_loss, | |||
| num_workers, | |||
| save_interval, | |||
| save_dir, | |||
| train_transform, | |||
| test_transform, | |||
| collate_fn, | |||
| ) | |||
| return ABLModel(base_model) | |||
| if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")): | |||
| raise NotImplementedError("The lambdalearn_model should be an instance of DeepModelMixin, or implement fit and predict methods.") | |||
| raise NotImplementedError( | |||
| "The lambdalearn_model should be an instance of DeepModelMixin, or implement " | |||
| + "fit and predict methods." | |||
| ) | |||
| return ABLModel(lambdalearn_model) | |||
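A hedged sketch of calling the converter above. The ``lambdalearn_model`` is assumed to be a ``DeepModelMixin`` instance, and only the ``optimizer`` key of ``optimizer_dict`` is read in the code shown in this diff; any further keys are assumptions.

```python
import torch

converter = ModelConverter()
abl_model = converter.convert_lambdalearn_to_ablmodel(
    lambdalearn_model,  # hypothetical DeepModelMixin instance
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer_dict={"optimizer": torch.optim.RMSprop},
)
```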
| def convert_lambdalearn_to_basicnn( | |||
| self, | |||
| lambdalearn_model: DeepModelMixin, | |||
| @@ -103,9 +124,10 @@ class ModelConverter: | |||
| test_transform: Callable[..., Any] = None, | |||
| collate_fn: Callable[[List[Any]], Any] = None, | |||
| ): | |||
| ''' | |||
| Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of DeepModelMixin, its network will be used as the model of BasicNN. | |||
| """ | |||
| Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of | |||
| DeepModelMixin, its network will be used as the model of BasicNN. | |||
| Parameters | |||
| ---------- | |||
| lambdalearn_model : Union[DeepModelMixin, Any] | |||
| @@ -147,10 +169,13 @@ class ModelConverter: | |||
| ------- | |||
| BasicNN | |||
| The converted BasicNN instance. | |||
| ''' | |||
| """ | |||
| if isinstance(lambdalearn_model, DeepModelMixin): | |||
| if not isinstance(lambdalearn_model.network, torch.nn.Module): | |||
| raise NotImplementedError(f"Expected lambdalearn_model.network to be a torch.nn.Module, but got {type(lambdalearn_model.network)}") | |||
| raise NotImplementedError( | |||
| "Expected lambdalearn_model.network to be a torch.nn.Module, " | |||
| + f"but got {type(lambdalearn_model.network)}" | |||
| ) | |||
| # Only use the network part and device of the lambdalearn model | |||
| network = copy.deepcopy(lambdalearn_model.network) | |||
| optimizer_class = optimizer_dict["optimizer"] | |||
| @@ -163,7 +188,24 @@ class ModelConverter: | |||
| else: | |||
| scheduler = None | |||
| device = lambdalearn_model.device if device is None else device | |||
| base_model = BasicNN(model=network, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, device=device, batch_size=batch_size, num_epochs=num_epochs, stop_loss=stop_loss, num_workers=num_workers, save_interval=save_interval, save_dir=save_dir, train_transform=train_transform, test_transform=test_transform, collate_fn=collate_fn) | |||
| base_model = BasicNN( | |||
| model=network, | |||
| loss_fn=loss_fn, | |||
| optimizer=optimizer, | |||
| scheduler=scheduler, | |||
| device=device, | |||
| batch_size=batch_size, | |||
| num_epochs=num_epochs, | |||
| stop_loss=stop_loss, | |||
| num_workers=num_workers, | |||
| save_interval=save_interval, | |||
| save_dir=save_dir, | |||
| train_transform=train_transform, | |||
| test_transform=test_transform, | |||
| collate_fn=collate_fn, | |||
| ) | |||
| return base_model | |||
| else: | |||
| raise NotImplementedError("The lambdalearn_model should be an instance of DeepModelMixin.") | |||
| raise NotImplementedError( | |||
| "The lambdalearn_model should be an instance of DeepModelMixin." | |||
| ) | |||
| @@ -26,7 +26,7 @@ class KBBase(ABC): | |||
| list so that each aligns with its corresponding index in the base model: the first with | |||
| the 0th index, the second with the 1st, and so forth. | |||
| max_err : float, optional | |||
| The upper tolerance limit when comparing the similarity between the reasoning result of | |||
| The upper tolerance limit when comparing the similarity between the reasoning result of | |||
| pseudo-labels and the ground truth. This is only applicable when the reasoning | |||
| result is of a numerical type. This is particularly relevant for regression problems where | |||
| exact matches might not be feasible. Defaults to 1e-10. | |||
| @@ -65,10 +65,12 @@ class KBBase(ABC): | |||
| self.use_cache = use_cache | |||
| self.key_func = key_func | |||
| self.cache_size = cache_size | |||
| argspec = inspect.getfullargspec(self.logic_forward) | |||
| self._num_args = len(argspec.args) - 1 | |||
| if self._num_args==2 and self.use_cache: # If the logic_forward function has 2 arguments, then disable cache | |||
| if ( | |||
| self._num_args == 2 and self.use_cache | |||
| ): # If the logic_forward function has 2 arguments, then disable cache | |||
| self.use_cache = False | |||
| print_log( | |||
| "The logic_forward function has 2 arguments, so the cache is disabled. ", | |||
| @@ -89,10 +91,10 @@ class KBBase(ABC): | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example. | |||
| x : List[Any], optional | |||
| The example. If deductive logical reasoning does not require any | |||
| information from the example, the overridden function provided by the user can omit | |||
| The example. If deductive logical reasoning does not require any | |||
| information from the example, the overridden function provided by the user can omit | |||
| this parameter. | |||
| Returns | |||
| ------- | |||
| Any | |||
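For illustration, a hypothetical minimal subclass overriding ``logic_forward`` as the docstring allows (digit pseudo-labels summed to the reasoning result; not part of this diff):

```python
# Hypothetical toy knowledge base: pseudo-labels are digits, reasoning sums them.
class AddKB(KBBase):
    def __init__(self):
        super().__init__(pseudo_label_list=list(range(10)))

    def logic_forward(self, pseudo_label):
        # Deduction needs no information from the example, so `x` is omitted.
        return sum(pseudo_label)
```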
| @@ -100,11 +102,11 @@ class KBBase(ABC): | |||
| """ | |||
| def abduce_candidates( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| require_more_revision: int, | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| @@ -118,7 +120,7 @@ class KBBase(ABC): | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example. If the information from the example | |||
| is not required in the reasoning process, then this parameter will not have | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| max_revision_num : int | |||
| The upper limit on the number of revised labels for each example. | |||
| @@ -129,9 +131,9 @@ class KBBase(ABC): | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example. that are compatible with the knowledge base. The second element is | |||
| a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
| logic_forward function. | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision) | |||
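With such a KB, a call to the method above could look like this (a sketch using the hypothetical ``AddKB`` from the previous example):

```python
kb = AddKB()
candidates, reasoning_results = kb.abduce_candidates(
    pseudo_label=[1, 3], y=8, x=None, max_revision_num=1, require_more_revision=0
)
# Revising one label yields e.g. [5, 3] and [1, 7]; reasoning_results == [8, 8]
```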
| @@ -154,10 +156,10 @@ class KBBase(ABC): | |||
| return reasoning_result == y | |||
| def revise_at_idx( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| @@ -171,7 +173,7 @@ class KBBase(ABC): | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example. If the information from the example | |||
| is not required in the reasoning process, then this parameter will not have | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| revision_idx : List[int] | |||
| A list specifying indices of where revisions should be made to the pseudo-labels. | |||
| @@ -180,9 +182,9 @@ class KBBase(ABC): | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second element is | |||
| a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
| logic_forward function. | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| candidates, reasoning_results = [], [] | |||
| abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||
| @@ -192,14 +194,15 @@ class KBBase(ABC): | |||
| candidate[idx] = c[i] | |||
| reasoning_result = self.logic_forward(candidate, *(x,) if self._num_args == 2 else ()) | |||
| if self._check_equal(reasoning_result, y): | |||
| candidates.append(candidate); reasoning_results.append(reasoning_result) | |||
| candidates.append(candidate) | |||
| reasoning_results.append(reasoning_result) | |||
| return candidates, reasoning_results | |||
| def _revision( | |||
| self, | |||
| revision_num: int, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| self, | |||
| revision_num: int, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| @@ -210,16 +213,17 @@ class KBBase(ABC): | |||
| revision_idx_list = combinations(range(len(pseudo_label)), revision_num) | |||
| for revision_idx in revision_idx_list: | |||
| candidates, reasoning_results = self.revise_at_idx(pseudo_label, y, x, revision_idx) | |||
| new_candidates.extend(candidates); new_reasoning_results.extend(reasoning_results) | |||
| new_candidates.extend(candidates) | |||
| new_reasoning_results.extend(reasoning_results) | |||
| return new_candidates, new_reasoning_results | |||
| @abl_cache() | |||
| def _abduce_by_search( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| require_more_revision: int, | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| @@ -235,7 +239,7 @@ class KBBase(ABC): | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example. If the information from the example | |||
| is not required in the reasoning process, then this parameter will not have | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| max_revision_num : int | |||
| The upper limit on the number of revisions. | |||
| @@ -248,14 +252,15 @@ class KBBase(ABC): | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second element is | |||
| a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
| logic_forward function. | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| candidates, reasoning_results = [], [] | |||
| for revision_num in range(len(pseudo_label) + 1): | |||
| new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x) | |||
| candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results) | |||
| candidates.extend(new_candidates) | |||
| reasoning_results.extend(new_reasoning_results) | |||
| if len(candidates) > 0: | |||
| min_revision_num = revision_num | |||
| break | |||
| @@ -268,7 +273,8 @@ class KBBase(ABC): | |||
| if revision_num > max_revision_num: | |||
| return candidates, reasoning_results | |||
| new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x) | |||
| candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results) | |||
| candidates.extend(new_candidates) | |||
| reasoning_results.extend(new_reasoning_results) | |||
| return candidates, reasoning_results | |||
| def __repr__(self): | |||
| @@ -305,16 +311,19 @@ class GroundKB(KBBase): | |||
| """ | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list: List[Any], | |||
| GKB_len_list: List[int], | |||
| self, | |||
| pseudo_label_list: List[Any], | |||
| GKB_len_list: List[int], | |||
| max_err: float = 1e-10, | |||
| ): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| if not isinstance(GKB_len_list, list): | |||
| raise TypeError(f"GKB_len_list should be a list, but got {type(GKB_len_list)}") | |||
| if self._num_args==2: | |||
| raise NotImplementedError(f"GroundKB only supports 1-argument logic_forward, but got {self._num_args}-argument logic_forward") | |||
| if self._num_args == 2: | |||
| raise NotImplementedError( | |||
| "GroundKB only supports 1-argument logic_forward, but got " | |||
| + f"{self._num_args}-argument logic_forward" | |||
| ) | |||
| self.GKB_len_list = GKB_len_list | |||
| self.GKB = {} | |||
| X, Y = self._get_GKB() | |||
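For context, a hypothetical ``GroundKB`` subclass matching this constructor; it precomputes the ground knowledge base for every length-2 digit sequence:

```python
# Hypothetical subclass (illustration only, not part of this diff).
class AddGroundKB(GroundKB):
    def __init__(self):
        super().__init__(pseudo_label_list=list(range(10)), GKB_len_list=[2])

    def logic_forward(self, pseudo_label):
        return sum(pseudo_label)
```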
| @@ -354,11 +363,11 @@ class GroundKB(KBBase): | |||
| return X, Y | |||
| def abduce_candidates( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| require_more_revision: int, | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| @@ -383,9 +392,9 @@ class GroundKB(KBBase): | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two element. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of THE example that are compatible with the knowledge base. The second element is | |||
| a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
| logic_forward function. | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list: | |||
| return [], [] | |||
| @@ -418,7 +427,8 @@ class GroundKB(KBBase): | |||
| all_candidates, all_reasoning_results = [], [] | |||
| for key in key_list[low_key:high_key]: | |||
| for candidate in potential_candidates[key]: | |||
| all_candidates.append(candidate); all_reasoning_results.append(key) | |||
| all_candidates.append(candidate) | |||
| all_reasoning_results.append(key) | |||
| else: | |||
| all_candidates = self.GKB[len(pseudo_label)][y] | |||
| all_reasoning_results = [y] * len(all_candidates) | |||
| @@ -468,14 +478,17 @@ class PrologKB(KBBase): | |||
| def __init__(self, pseudo_label_list: List[Any], pl_file: str): | |||
| super().__init__(pseudo_label_list) | |||
| try: | |||
| import pyswip | |||
| except (IndexError, ImportError): | |||
| print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \ | |||
| using the command 'sudo apt-get install swi-prolog' for Linux users, \ | |||
| or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.") | |||
| print( | |||
| "A Prolog-based knowledge base is in use. Please install Swi-Prolog using the" | |||
| + "command 'sudo apt-get install swi-prolog' for Linux users, or download it " | |||
| + "following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md " | |||
| + "for Windows and Mac users." | |||
| ) | |||
| self.prolog = pyswip.Prolog() | |||
| self.pl_file = pl_file | |||
| if not os.path.exists(self.pl_file): | |||
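A usage sketch for this constructor, mirroring the ``HedKB`` hunk later in this diff; the standalone file path here is an assumption:

```python
# Assumes learn_add.pl exists and defines the predicates the KB queries.
kb = PrologKB(pseudo_label_list=[1, 0, "+", "="], pl_file="learn_add.pl")
```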
| @@ -519,9 +532,9 @@ class PrologKB(KBBase): | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label)) | |||
| def get_query_string( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> str: | |||
| @@ -538,8 +551,8 @@ class PrologKB(KBBase): | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| revision_idx : List[int] | |||
| A list specifying indices of where revisions should be made to the pseudo-labels. | |||
| @@ -556,10 +569,10 @@ class PrologKB(KBBase): | |||
| return query_string | |||
| def revise_at_idx( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| @@ -572,8 +585,8 @@ class PrologKB(KBBase): | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| revision_idx : List[int] | |||
| A list specifying indices of where revisions should be made to the pseudo-labels. | |||
| @@ -581,12 +594,10 @@ class PrologKB(KBBase): | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A list of candidates, i.e. revised pseudo-labels of the example that are compatible with the | |||
| knowledge base. | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second element is | |||
| a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
| logic_forward function. | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| candidates, reasoning_results = [], [] | |||
| query_string = self.get_query_string(pseudo_label, y, x, revision_idx) | |||
| @@ -598,7 +609,8 @@ class PrologKB(KBBase): | |||
| for i, idx in enumerate(revision_idx): | |||
| candidate[idx] = c[i] | |||
| candidate = reform_list(candidate, save_pseudo_label) | |||
| candidates.append(candidate); reasoning_results.append(y) | |||
| candidates.append(candidate) | |||
| reasoning_results.append(y) | |||
| return candidates, reasoning_results | |||
| def __repr__(self): | |||
| @@ -28,10 +28,10 @@ class Reasoner: | |||
| candidate, 'confidence': calculates the distance between the prediction | |||
| and each candidate based on confidence derived from the predicted probability | |||
| in the data example. The callable function should have the signature | |||
| dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element | |||
| in this cost list should be a numerical value representing the cost for each | |||
| candidate, and the list should have the same length as candidates. | |||
| Defaults to 'confidence'. | |||
| dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must | |||
| return a cost list. Each element in this cost list should be a numerical value | |||
| representing the cost for each candidate, and the list should have the same length | |||
| as candidates. Defaults to 'confidence'. | |||
| idx_to_label : dict, optional | |||
| A mapping from index in the base model to label. If not provided, a default | |||
| order-based index to label mapping is created. Defaults to None. | |||
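To make the documented four-parameter signature concrete, a minimal custom ``dist_func``; the uniform-cost behavior is purely illustrative:

```python
def my_dist(data_example, candidates, candidate_idxs, reasoning_results):
    pred = data_example.pred_pseudo_label
    # One numerical cost per candidate, same length as `candidates`.
    return [0 if candidate == pred else 1 for candidate in candidates]

# reasoner = Reasoner(kb, dist_func=my_dist)  # assumed construction
```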
| @@ -76,14 +76,16 @@ class Reasoner: | |||
| if isinstance(dist_func, str): | |||
| if dist_func not in ["hamming", "confidence"]: | |||
| raise NotImplementedError( | |||
| f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.' | |||
| 'Valid options for predefined dist_func include "hamming" ' | |||
| + f'and "confidence", but got {dist_func}.' | |||
| ) | |||
| return | |||
| elif callable(dist_func): | |||
| params = inspect.signature(dist_func).parameters.values() | |||
| if len(params) != 4: | |||
| raise ValueError( | |||
| f"User-defined dist_func must have exactly four parameters, but got {len(params)}." | |||
| "User-defined dist_func must have exactly four parameters, " | |||
| + f"but got {len(params)}." | |||
| ) | |||
| return | |||
| else: | |||
| @@ -99,7 +101,8 @@ class Reasoner: | |||
| raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.") | |||
| if value not in self.kb.pseudo_label_list: | |||
| raise ValueError( | |||
| f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}." | |||
| "All values in the idx_to_label must be in the pseudo_label_list, " | |||
| + f"but got {value}." | |||
| ) | |||
| def _get_one_candidate( | |||
| @@ -169,8 +172,8 @@ class Reasoner: | |||
| cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) | |||
| if len(cost_list) != len(candidates): | |||
| raise ValueError( | |||
| f"The length of the array returned by dist_func must be equal to the number of candidates. " | |||
| f"Expected length {len(candidates)}, but got {len(cost_list)}." | |||
| "The length of the array returned by dist_func must be equal to the number " | |||
| + f"of candidates. Expected length {len(candidates)}, but got {len(cost_list)}." | |||
| ) | |||
| return cost_list | |||
| @@ -204,7 +207,9 @@ class Reasoner: | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
| ) | |||
| parameter = Parameter(budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True) | |||
| parameter = Parameter( | |||
| budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True | |||
| ) | |||
| solution = Opt.min(objective, parameter) | |||
| return solution | |||
| @@ -240,29 +245,28 @@ class Reasoner: | |||
| return np.min(self._get_cost_list(data_example, candidates, reasoning_results)) | |||
| else: | |||
| return symbol_num | |||
| def zoopt_budget(self, symbol_num: int) -> int: | |||
| """ | |||
| Set the budget for ZOOpt optimization. The function, in its default implementation, | |||
| returns a fixed budget value of 100. However, it can be adjusted to return other fixed | |||
| values, or a dynamic budget based on the number of symbols, if desired. For example, one might choose to | |||
| set the budget as 100 times symbol_num. | |||
| Set the budget for ZOOpt optimization. The function, in its default implementation, | |||
| returns a fixed budget value of 100. However, it can be adjusted to return other fixed | |||
| values, or a dynamic budget based on the number of symbols, if desired. For example, | |||
| one might choose to set the budget as 100 times ``symbol_num``. | |||
| Parameters | |||
| ---------- | |||
| symbol_num : int | |||
| The number of symbols to be considered in the ZOOpt optimization process. Although this parameter | |||
| can be used to compute a dynamic optimization budget, by default it is not utilized in the | |||
| calculation. | |||
| The number of symbols to be considered in the ZOOpt optimization process. Although this | |||
| parameter can be used to compute a dynamic optimization budget, by default it is not | |||
| utilized in the calculation. | |||
| Returns | |||
| ------- | |||
| int | |||
| The budget for ZOOpt optimization. By default, this is a fixed value of 100, | |||
| The budget for ZOOpt optimization. By default, this is a fixed value of 100, | |||
| irrespective of the symbol_num value. | |||
| """ | |||
| return 100 | |||
| def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: | |||
| """ | |||
| @@ -284,7 +288,8 @@ class Reasoner: | |||
| elif isinstance(max_revision, float): | |||
| if not (0 <= max_revision <= 1): | |||
| raise ValueError( | |||
| f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}" | |||
| "If max_revision is a float, it must be between 0 and 1, " | |||
| + f"but got {max_revision}" | |||
| ) | |||
| return round(symbol_num * max_revision) | |||
| else: | |||
| @@ -1,6 +1,13 @@ | |||
| from .cache import Cache, abl_cache | |||
| from .logger import ABLLogger, print_log | |||
| from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple | |||
| from .utils import ( | |||
| confidence_dist, | |||
| flatten, | |||
| hamming_dist, | |||
| reform_list, | |||
| to_hashable, | |||
| tab_data_to_tuple, | |||
| ) | |||
| __all__ = [ | |||
| "Cache", | |||
| @@ -43,7 +43,7 @@ class Cache(Generic[K, T]): | |||
| def get_from_dict(self, obj, *args) -> T: | |||
| """Implements dict based cache.""" | |||
| # x is not used in cache key | |||
| pred_pseudo_label, y, x, *res_args = args | |||
| pred_pseudo_label, y, x, *res_args = args | |||
| cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) | |||
| link = self.cache_dict.get(cache_key) | |||
| if link is not None: | |||
| @@ -168,9 +168,11 @@ class ABLLogger(Logger, ManagerMixin): | |||
| Notes | |||
| ----- | |||
| - The ``name`` of the logger and the ``instance_name`` of ``ABLLogger`` could be different. | |||
| ``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not ``logging.getLogger``. | |||
| This ensures ``ABLLogger`` is not influenced by third-party logging configurations. | |||
| - Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without ``Handler``. | |||
| ``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not | |||
| ``logging.getLogger``. This ensures ``ABLLogger`` is not influenced by third-party logging | |||
| configurations. | |||
| - Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without | |||
| ``Handler``. | |||
| Examples | |||
| -------- | |||
| @@ -288,15 +290,16 @@ class ABLLogger(Logger, ManagerMixin): | |||
| def print_log( | |||
| msg, | |||
| logger: Optional[Union[Logger, str]] = None, | |||
| msg, | |||
| logger: Optional[Union[Logger, str]] = None, | |||
| level: Optional[int] = logging.INFO, | |||
| ) -> None: | |||
| """ | |||
| Print a log message using the specified logger or a default method. | |||
| This function logs a message with a given logger, if provided, or prints it using | |||
| the standard ``print`` function. It supports special logger types such as 'silent' and 'current'. | |||
| the standard ``print`` function. It supports special logger types such as 'silent' | |||
| and 'current'. | |||
| Parameters | |||
| ---------- | |||
| @@ -308,8 +311,8 @@ def print_log( | |||
| method is used. | |||
| - 'silent': No message will be printed. | |||
| - 'current': Use the latest created logger to log the message. | |||
| - other str: The instance name of the logger. A ``ValueError`` is raised if the logger has not | |||
| been created. | |||
| - other str: The instance name of the logger. A ``ValueError`` is raised if the logger has | |||
| not been created. | |||
| - None: The ``print()`` method is used for logging. | |||
| level : int, optional | |||
| The logging level. This is only applicable when ``logger`` is a Logger object, 'current', | |||
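The documented options in action, as a small sketch:

```python
print_log("Loop 1 started.", logger="current")  # use the latest created logger
print_log("Fallback message.")                  # logger=None: falls back to print()
print_log("Never shown.", logger="silent")      # suppressed entirely
```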
| @@ -15,7 +15,7 @@ def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[A | |||
| Returns | |||
| ------- | |||
| List[Any] | |||
| A flattened version of the input list, where only the first | |||
| A flattened version of the input list, where only the first | |||
| level of sublists and tuples are reduced. | |||
| """ | |||
| if not isinstance(nested_list, list): | |||
| @@ -24,15 +24,15 @@ def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[A | |||
| flattened_list = [] | |||
| for item in nested_list: | |||
| if isinstance(item, (list, tuple)): | |||
| flattened_list.extend(item) | |||
| flattened_list.extend(item) | |||
| else: | |||
| flattened_list.append(item) | |||
| return flattened_list | |||
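A quick check of the documented behavior (only the first level of nesting is reduced); the import path follows the ``abl.utils`` exports shown elsewhere in this diff:

```python
from abl.utils import flatten

assert flatten([[1, 2], (3, 4), 5]) == [1, 2, 3, 4, 5]
assert flatten([[1, [2, 3]], 4]) == [1, [2, 3], 4]  # deeper nesting left intact
```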
| def reform_list( | |||
| flattened_list: List[Any], | |||
| structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]] | |||
| flattened_list: List[Any], structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]] | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| Reform the list based on the structure of ``structured_list``. | |||
| @@ -148,16 +148,15 @@ def restore_from_hashable(x): | |||
| return [restore_from_hashable(item) for item in x] | |||
| return x | |||
| def tab_data_to_tuple( | |||
| X: Union[List[Any], Any], | |||
| y: Union[List[Any], Any], | |||
| reasoning_result: Optional[Any] = 0 | |||
| X: Union[List[Any], Any], y: Union[List[Any], Any], reasoning_result: Optional[Any] = 0 | |||
| ) -> Tuple[List[List[Any]], List[List[Any]], List[Any]]: | |||
| ''' | |||
| Convert a tabular data to a tuple by adding a dimension to each element of | |||
| """ | |||
| Convert tabular data to a tuple by adding a dimension to each element of | |||
| X and y. The tuple contains three elements: data, label, and reasoning result. | |||
| If X is None, return None. | |||
| Parameters | |||
| ---------- | |||
| X : Union[List[Any], Any] | |||
| @@ -166,14 +165,16 @@ def tab_data_to_tuple( | |||
| The label. | |||
| reasoning_result : Any, optional | |||
| The reasoning result, by default 0. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[List[Any]], List[Any]] | |||
| A tuple of (data, label, reasoning_result). | |||
| ''' | |||
| """ | |||
| if X is None: | |||
| return None | |||
| if len(X) != len(y): | |||
| raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) | |||
| return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) | |||
| raise ValueError( | |||
| "The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)) | |||
| ) | |||
| return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) | |||
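A small worked example of the conversion implemented above; the import path follows the ``abl.utils`` exports shown elsewhere in this diff:

```python
from abl.utils import tab_data_to_tuple

data, label, reasoning_result = tab_data_to_tuple([10, 20], [0, 1], reasoning_result=0)
# data == [[10], [20]]; label == [[0], [1]]; reasoning_result == [0, 0]
```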
| @@ -5,27 +5,28 @@ import sys | |||
| from docutils import nodes | |||
| from docutils.parsers.rst import roles | |||
| import re | |||
| from sphinx.application import Sphinx | |||
| def remove_noqa(app: Sphinx, what: str, name: str, obj, options, lines): | |||
| new_lines = [] | |||
| for line in lines: | |||
| new_line = re.sub(r'\s*#\s*noqa.*$', '', line) | |||
| new_line = re.sub(r"\s*#\s*noqa.*$", "", line) | |||
| new_lines.append(new_line) | |||
| lines[:] = new_lines | |||
| def colored_text_role(role, rawtext, text, lineno, inliner, options={}, content=[]): | |||
| node = nodes.inline(rawtext, text, classes=[role]) | |||
| return [node], [] | |||
| roles.register_local_role('green-bold', colored_text_role) | |||
| roles.register_local_role('blue-bold', colored_text_role) | |||
| roles.register_local_role('yellow-bold', colored_text_role) | |||
| roles.register_local_role('green', colored_text_role) | |||
| roles.register_local_role('blue', colored_text_role) | |||
| roles.register_local_role('yellow', colored_text_role) | |||
| roles.register_local_role("green-bold", colored_text_role) | |||
| roles.register_local_role("blue-bold", colored_text_role) | |||
| roles.register_local_role("yellow-bold", colored_text_role) | |||
| roles.register_local_role("green", colored_text_role) | |||
| roles.register_local_role("blue", colored_text_role) | |||
| roles.register_local_role("yellow", colored_text_role) | |||
| if "READTHEDOCS" not in os.environ: | |||
| @@ -45,7 +46,7 @@ author = "Author" | |||
| extensions = [ | |||
| "sphinx.ext.intersphinx", | |||
| "sphinx.ext.autodoc", | |||
| 'sphinx.ext.autosummary', | |||
| "sphinx.ext.autosummary", | |||
| "sphinx.ext.mathjax", | |||
| "sphinx.ext.viewcode", | |||
| "sphinx_rtd_theme", | |||
| @@ -95,7 +96,8 @@ texinfo_documents = [ | |||
| def setup(app): | |||
| from sphinx.domains.python import PyField | |||
| from sphinx.util.docfields import Field | |||
| app.connect('autodoc-process-docstring', remove_noqa) | |||
| app.connect("autodoc-process-docstring", remove_noqa) | |||
| app.add_object_type( | |||
| "confval", | |||
| "confval", | |||
| @@ -247,7 +247,7 @@ class HedBridge(SimpleBridge): | |||
| logger="current", | |||
| ) | |||
| self.model.load( | |||
| load_path=os.path.join(save_dir, f"pretrain_weights.pth") | |||
| load_path=os.path.join(save_dir, "pretrain_weights.pth") | |||
| ) | |||
| else: | |||
| self.model.load( | |||
| @@ -12,16 +12,20 @@ from torchvision.transforms import transforms | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| def download_and_unzip(url, zip_file_name): | |||
| try: | |||
| gdown.download(url, zip_file_name) | |||
| with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: | |||
| with zipfile.ZipFile(zip_file_name, "r") as zip_ref: | |||
| zip_ref.extractall(CURRENT_DIR) | |||
| os.remove(zip_file_name) | |||
| except Exception as e: | |||
| if os.path.exists(zip_file_name): | |||
| os.remove(zip_file_name) | |||
| raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hed/datasets' folder") | |||
| raise Exception( | |||
| f"An error occurred during download or unzip: {e}. Instead, you can download " | |||
| + f"the dataset from {url} and unzip it in 'examples/hed/datasets' folder" | |||
| ) | |||
| def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||
| @@ -33,7 +37,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||
| img_path_list = os.listdir(label_path) | |||
| for img_path in img_path_list: | |||
| with Image.open(osp.join(label_path, img_path)) as img: | |||
| img = img.convert('L') | |||
| img = img.convert("L") | |||
| img = img.resize((image_size[1], image_size[0])) | |||
| img_array = np.array(img, dtype=np.float32) | |||
| normalized_img = (img_array - 127) / 128.0 | |||
| @@ -72,19 +76,19 @@ def split_equation(equations_by_len, prop_train, prop_val): | |||
| def get_dataset(dataset="mnist", train=True): | |||
| data_dir = CURRENT_DIR + '/mnist_images' | |||
| data_dir = CURRENT_DIR + "/mnist_images" | |||
| if not os.path.exists(data_dir): | |||
| print("Dataset not exist, downloading it...") | |||
| url = 'https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download' | |||
| url = "https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download" | |||
| download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip")) | |||
| print("Download and extraction complete.") | |||
| if train: | |||
| file = os.path.join(data_dir, "expr_train.json") | |||
| else: | |||
| file = os.path.join(data_dir, "expr_test.json") | |||
| if dataset == "mnist": | |||
| file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | |||
| elif dataset == "random": | |||
| @@ -94,7 +98,7 @@ def get_dataset(dataset="mnist", train=True): | |||
| with open(file, "rb") as f: | |||
| img_dataset = pickle.load(f) | |||
| X, Y = [], [] | |||
| if train: | |||
| positive = img_dataset["train:positive"] | |||
| @@ -117,4 +121,3 @@ def get_dataset(dataset="mnist", train=True): | |||
| equations_by_len = divide_equations_by_len(X, Y) | |||
| return equations_by_len | |||
| @@ -86,31 +86,39 @@ | |||
| "source": [ | |||
| "true_train_equation = train_data[1]\n", | |||
| "false_train_equation = train_data[0]\n", | |||
| "print(f\"Equations in the dataset is organized by equation length, \" +\n", | |||
| " f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\")\n", | |||
| "print(\n", | |||
| " f\"Equations in the dataset is organized by equation length, \"\n", | |||
| " + f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\"\n", | |||
| ")\n", | |||
| "print()\n", | |||
| "\n", | |||
| "true_train_equation_with_length_5 = true_train_equation[5]\n", | |||
| "false_train_equation_with_length_5 = false_train_equation[5]\n", | |||
| "print(f\"For each euqation length, there are {len(true_train_equation_with_length_5)} \" +\n", | |||
| " f\"true equation and {len(false_train_equation_with_length_5)} false equation \" +\n", | |||
| " f\"in the training set\")\n", | |||
| "print(\n", | |||
| " f\"For each euqation length, there are {len(true_train_equation_with_length_5)} \"\n", | |||
| " + f\"true equation and {len(false_train_equation_with_length_5)} false equation \"\n", | |||
| " + f\"in the training set\"\n", | |||
| ")\n", | |||
| "\n", | |||
| "true_val_equation = val_data[1]\n", | |||
| "false_val_equation = val_data[0]\n", | |||
| "true_val_equation_with_length_5 = true_val_equation[5]\n", | |||
| "false_val_equation_with_length_5 = false_val_equation[5]\n", | |||
| "print(f\"For each euqation length, there are {len(true_val_equation_with_length_5)} \" +\n", | |||
| " f\"true equation and {len(false_val_equation_with_length_5)} false equation \" +\n", | |||
| " f\"in the validation set\")\n", | |||
| "print(\n", | |||
| " f\"For each euqation length, there are {len(true_val_equation_with_length_5)} \"\n", | |||
| " + f\"true equation and {len(false_val_equation_with_length_5)} false equation \"\n", | |||
| " + f\"in the validation set\"\n", | |||
| ")\n", | |||
| "\n", | |||
| "true_test_equation = test_data[1]\n", | |||
| "false_test_equation = test_data[0]\n", | |||
| "true_test_equation_with_length_5 = true_test_equation[5]\n", | |||
| "false_test_equation_with_length_5 = false_test_equation[5]\n", | |||
| "print(f\"For each euqation length, there are {len(true_test_equation_with_length_5)} \" +\n", | |||
| " f\"true equation and {len(false_test_equation_with_length_5)} false equation \" +\n", | |||
| " f\"in the test set\")" | |||
| "print(\n", | |||
| " f\"For each euqation length, there are {len(true_test_equation_with_length_5)} \"\n", | |||
| " + f\"true equation and {len(false_test_equation_with_length_5)} false equation \"\n", | |||
| " + f\"in the test set\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -199,30 +207,30 @@ | |||
| "true_train_equation_with_length_8 = true_train_equation[8]\n", | |||
| "print(f\"First true equation with length 5 in the training dataset:\")\n", | |||
| "for i, x in enumerate(true_train_equation_with_length_5[0]):\n", | |||
| " plt.subplot(1, 5, i+1)\n", | |||
| " plt.axis('off') \n", | |||
| " plt.imshow(x.squeeze(), cmap='gray')\n", | |||
| " plt.subplot(1, 5, i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(f\"First true equation with length 8 in the training dataset:\")\n", | |||
| "for i, x in enumerate(true_train_equation_with_length_8[0]):\n", | |||
| " plt.subplot(1, 8, i+1)\n", | |||
| " plt.axis('off') \n", | |||
| " plt.imshow(x.squeeze(), cmap='gray')\n", | |||
| " plt.subplot(1, 8, i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "\n", | |||
| "false_train_equation_with_length_5 = false_train_equation[5]\n", | |||
| "false_train_equation_with_length_8 = false_train_equation[8]\n", | |||
| "print(f\"First false equation with length 5 in the training dataset:\")\n", | |||
| "for i, x in enumerate(false_train_equation_with_length_5[0]):\n", | |||
| " plt.subplot(1, 5, i+1)\n", | |||
| " plt.axis('off') \n", | |||
| " plt.imshow(x.squeeze(), cmap='gray')\n", | |||
| " plt.subplot(1, 5, i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(f\"First false equation with length 8 in the training dataset:\")\n", | |||
| "for i, x in enumerate(false_train_equation_with_length_8[0]):\n", | |||
| " plt.subplot(1, 8, i+1)\n", | |||
| " plt.axis('off') \n", | |||
| " plt.imshow(x.squeeze(), cmap='gray')\n", | |||
| " plt.subplot(1, 8, i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()" | |||
| ] | |||
| }, | |||
| @@ -46,20 +46,20 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| total_train_data = get_dataset(train=True) | |||
| train_data, val_data = split_equation(total_train_data, 3, 1) | |||
| test_data = get_dataset(train=False) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=4) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| @@ -83,7 +83,7 @@ def main(): | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = HedKB() | |||
| @@ -1,3 +1,3 @@ | |||
| from .reasoning import HedKB, HedReasoner | |||
| __all__ = ["HedKB", "HedReasoner"] | |||
| __all__ = ["HedKB", "HedReasoner"] | |||
| @@ -8,8 +8,11 @@ from abl.utils import reform_list | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| class HedKB(PrologKB): | |||
| def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): | |||
| def __init__( | |||
| self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl") | |||
| ): | |||
| pl_file = pl_file.replace("\\", "/") | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| self.learned_rules = {} | |||
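The backslash replacement above matters because SWI-Prolog expects forward slashes in file paths, even on Windows. A hedged usage sketch, assuming SWI-Prolog is installed and `learn_add.pl` sits next to the module:

```python
# Sketch only: requires a working SWI-Prolog installation.
kb = HedKB()
print(kb.pseudo_label_list)  # [1, 0, '+', '=']
print(kb.pl_file)            # path to learn_add.pl with '/' separators
```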
| @@ -34,7 +37,7 @@ class HedReasoner(Reasoner): | |||
| data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx | |||
| ) | |||
| return candidate | |||
| def zoopt_budget(self, symbol_num): | |||
| return 200 | |||
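`zoopt_budget` pins the zeroth-order optimizer's evaluation budget at 200 regardless of how many symbols are up for revision. If a size-dependent budget were wanted instead, an override along these lines would fit the same hook; the scaling rule here is purely illustrative:

```python
class HedReasonerScaledBudget(HedReasoner):
    def zoopt_budget(self, symbol_num):
        # Illustrative alternative: spend more search budget on longer equations.
        return max(100, 20 * symbol_num)
```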
| @@ -53,7 +56,7 @@ class HedReasoner(Reasoner): | |||
| max_candidate_idxs = [] | |||
| found = False | |||
| for idx in range(-1, len(data_example.pred_idx)): | |||
| if (not idx in idxs) and (idx >= 0): | |||
| if (idx not in idxs) and (idx >= 0): | |||
| idxs.append(idx) | |||
| candidates, _ = self.revise_at_idx(data_example[idxs]) | |||
| if len(candidates) == 0: | |||
| @@ -96,4 +99,4 @@ class HedReasoner(Reasoner): | |||
| return abduced_pseudo_label | |||
| def abduce_rules(self, pred_res): | |||
| return self.kb.abduce_rules(pred_res) | |||
| return self.kb.abduce_rules(pred_res) | |||
| @@ -1,3 +1,3 @@ | |||
| from .get_dataset import get_dataset | |||
| __all__ = ["get_dataset"] | |||
| __all__ = ["get_dataset"] | |||
| @@ -10,26 +10,31 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]) | |||
| def download_and_unzip(url, zip_file_name): | |||
| try: | |||
| gdown.download(url, zip_file_name) | |||
| with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: | |||
| with zipfile.ZipFile(zip_file_name, "r") as zip_ref: | |||
| zip_ref.extractall(CURRENT_DIR) | |||
| os.remove(zip_file_name) | |||
| except Exception as e: | |||
| if os.path.exists(zip_file_name): | |||
| os.remove(zip_file_name) | |||
| raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hwf/datasets' folder") | |||
| raise Exception( | |||
| f"An error occurred during download or unzip: {e}. Instead, you can download " | |||
| + f"the dataset from {url} and unzip it in 'examples/hwf/datasets' folder" | |||
| ) | |||
| def get_dataset(train=True, get_pseudo_label=False): | |||
| data_dir = CURRENT_DIR + '/data' | |||
| data_dir = CURRENT_DIR + "/data" | |||
| if not os.path.exists(data_dir): | |||
| print("Dataset not exist, downloading it...") | |||
| url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download' | |||
| url = "https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download" | |||
| download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip")) | |||
| print("Download and extraction complete.") | |||
| if train: | |||
| file = os.path.join(data_dir, "expr_train.json") | |||
| else: | |||
| @@ -59,4 +64,4 @@ def get_dataset(train=True, get_pseudo_label=False): | |||
| pseudo_label.append(imgs_pseudo_label) | |||
| Y.append(data[idx]["res"]) | |||
| return X, pseudo_label, Y | |||
| return X, pseudo_label, Y | |||
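A hedged usage sketch of the loader above; the returned triple follows the `(X, gt_pseudo_label, Y)` convention used throughout the examples:

```python
# Sketch only: triggers the dataset download on first use.
X, pseudo_label, Y = get_dataset(train=True, get_pseudo_label=True)
print(len(X), len(pseudo_label), len(Y))  # one entry per handwritten formula
print(pseudo_label[0], Y[0])  # the symbol labels and the evaluated result
```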
| @@ -72,21 +72,28 @@ | |||
| "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", | |||
| "print()\n", | |||
| "train_X, train_gt_pseudo_label, train_Y = train_data\n", | |||
| "print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n", | |||
| " f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n", | |||
| "print(\n", | |||
| " f\"Length of X, gt_pseudo_label, Y in train_data: \"\n", | |||
| " + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n", | |||
| ")\n", | |||
| "test_X, test_gt_pseudo_label, test_Y = test_data\n", | |||
| "print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n", | |||
| " f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n", | |||
| "print(\n", | |||
| " f\"Length of X, gt_pseudo_label, Y in test_data: \"\n", | |||
| " + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n", | |||
| ")\n", | |||
| "print()\n", | |||
| "\n", | |||
| "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", | |||
| "print(f\"X is a {type(train_X).__name__}, \" +\n", | |||
| " f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\")\n", | |||
| "print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n", | |||
| " f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n", | |||
| " f\"of {type(gt_pseudo_label_0[0]).__name__}.\")\n", | |||
| "print(f\"Y is a {type(train_Y).__name__}, \" +\n", | |||
| " f\"with each element being a {type(Y_0).__name__}.\")" | |||
| "print(\n", | |||
| " f\"X is a {type(train_X).__name__}, \"\n", | |||
| " + f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\"\n", | |||
| ")\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n", | |||
| " + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n", | |||
| " + f\"of {type(gt_pseudo_label_0[0]).__name__}.\"\n", | |||
| ")\n", | |||
| "print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -105,21 +112,25 @@ | |||
| "X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]\n", | |||
| "print(f\"X in the 1001st data example (a list of images):\")\n", | |||
| "for i, x in enumerate(X_1000):\n", | |||
| " plt.subplot(1, len(X_1000), i+1)\n", | |||
| " plt.axis('off') \n", | |||
| " plt.imshow(x.squeeze(), cmap='gray')\n", | |||
| " plt.subplot(1, len(X_1000), i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\")\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\"\n", | |||
| ")\n", | |||
| "print(f\"Y in the 1001st data example (the computed result): {Y_1000}\")\n", | |||
| "print()\n", | |||
| "X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000]\n", | |||
| "print(f\"X in the 3001st data example (a list of images):\")\n", | |||
| "for i, x in enumerate(X_3000):\n", | |||
| " plt.subplot(1, len(X_3000), i+1)\n", | |||
| " plt.axis('off') \n", | |||
| " plt.imshow(x.squeeze(), cmap='gray')\n", | |||
| " plt.subplot(1, len(X_3000), i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\")\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\"\n", | |||
| ")\n", | |||
| "print(f\"Y in the 3001st data example (the computed result): {Y_3000}\")" | |||
| ] | |||
| }, | |||
| @@ -184,11 +195,15 @@ | |||
| "source": [ | |||
| "data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]\n", | |||
| "pred_idx = base_model.predict(X=data_instances)\n", | |||
| "print(f\"Predicted class index for a batch of 32 instances: \" +\n", | |||
| " f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n", | |||
| "print(\n", | |||
| " f\"Predicted class index for a batch of 32 instances: \"\n", | |||
| " + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n", | |||
| ")\n", | |||
| "pred_prob = base_model.predict_proba(X=data_instances)\n", | |||
| "print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n", | |||
| " f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")" | |||
| "print(\n", | |||
| " f\"Predicted class probabilities for a batch of 32 instances: \"\n", | |||
| " + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -221,6 +236,7 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "from abl.data.structures import ListData\n", | |||
| "\n", | |||
| "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", | |||
| "data_examples = ListData()\n", | |||
| "# We use the first 1001st and 3001st data examples in the training set as an illustration\n", | |||
| @@ -229,15 +245,19 @@ | |||
| "data_examples.Y = [Y_1000, Y_3000]\n", | |||
| "\n", | |||
| "# Perform prediction on the two data examples\n", | |||
| "# Remind that, in the 1001st data example, the length of the formula is 3, \n", | |||
| "# Remind that, in the 1001st data example, the length of the formula is 3,\n", | |||
| "# while in the 3001st data example, the length of the formula is 5.\n", | |||
| "pred_label, pred_prob = model.predict(data_examples)['label'], model.predict(data_examples)['prob']\n", | |||
| "print(f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\" +\n", | |||
| " f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"+\n", | |||
| " f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\")\n", | |||
| "print(f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n", | |||
| " f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \" +\n", | |||
| " f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\")" | |||
| "pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n", | |||
| "print(\n", | |||
| " f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\"\n", | |||
| " + f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"\n", | |||
| " + f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\"\n", | |||
| ")\n", | |||
| "print(\n", | |||
| " f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n", | |||
| " f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \"\n", | |||
| " + f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
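Because the two examples hold formulas of length 3 and 5, `model.predict` returns ragged per-example arrays rather than one stacked tensor. A short sketch of iterating them; the 13 in the probability shape assumes the 13-symbol HWF label set:

```python
# Sketch: call predict once and reuse the result.
out = model.predict(data_examples)
for label, prob in zip(out["label"], out["prob"]):
    print(label.shape, prob.shape)  # e.g. (3,), (3, 13) then (5,), (5, 13)
```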
| { | |||
| @@ -261,7 +281,9 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "class HwfKB(KBBase):\n", | |||
| " def __init__(self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]):\n", | |||
| " def __init__(\n", | |||
| " self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]\n", | |||
| " ):\n", | |||
| " super().__init__(pseudo_label_list)\n", | |||
| "\n", | |||
| " def _valid_candidate(self, formula):\n", | |||
| @@ -273,13 +295,14 @@ | |||
| " if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"*\", \"/\"]:\n", | |||
| " return False\n", | |||
| " return True\n", | |||
| " \n", | |||
| "\n", | |||
| " # Implement the deduction function\n", | |||
| " def logic_forward(self, formula):\n", | |||
| " if not self._valid_candidate(formula):\n", | |||
| " return np.inf\n", | |||
| " return eval(\"\".join(formula))\n", | |||
| "\n", | |||
| "\n", | |||
| "kb = HwfKB()" | |||
| ] | |||
| }, | |||
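To make `logic_forward` concrete: a structurally valid formula is evaluated with Python's `eval`, while anything failing `_valid_candidate` maps to `np.inf`, so it can never match a finite reasoning result. A hedged check:

```python
kb = HwfKB()
print(kb.logic_forward(["3", "+", "5"]))  # 8: digits and operators alternate
print(kb.logic_forward(["3", "5", "+"]))  # inf: position 1 must be an operator
```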
| @@ -113,19 +113,19 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| @@ -148,7 +148,7 @@ def main(): | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.ground: | |||
| kb = HwfGroundKB() | |||
| @@ -1,3 +1,3 @@ | |||
| from .get_dataset import get_dataset | |||
| __all__ = ["get_dataset"] | |||
| __all__ = ["get_dataset"] | |||
| @@ -5,6 +5,7 @@ from torchvision.transforms import transforms | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| def get_dataset(train=True, get_pseudo_label=True): | |||
| transform = transforms.Compose( | |||
| [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | |||
| @@ -88,7 +88,7 @@ def main(): | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = LeNet5(num_classes=10) | |||
| loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) | |||
| @@ -119,7 +119,7 @@ def main(): | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.prolog: | |||
| kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
| @@ -89,23 +89,37 @@ def main(): | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| model=FixMatch(network=LeNet5(), threshold=0.95,lambda_u=1.0,mu=7,T=0.5,epoch=1,num_it_epoch=2**20,num_it_total=2**20,device='cuda') | |||
| model = FixMatch( | |||
| network=LeNet5(), | |||
| threshold=0.95, | |||
| lambda_u=1.0, | |||
| mu=7, | |||
| T=0.5, | |||
| epoch=1, | |||
| num_it_epoch=2**20, | |||
| num_it_total=2**20, | |||
| device="cuda", | |||
| ) | |||
| loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) | |||
| optimizer_dict = dict(optimizer=RMSprop, lr=0.0003, alpha=0.9) | |||
| scheduler_dict = dict(scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200) | |||
| scheduler_dict = dict( | |||
| scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200 | |||
| ) | |||
| converter = ModelConverter() | |||
| base_model = converter.convert_lambdalearn_to_basicnn(model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict) | |||
| base_model = converter.convert_lambdalearn_to_basicnn( | |||
| model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict | |||
| ) | |||
| # Build ABLModel | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.prolog: | |||
| kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
| @@ -87,22 +87,29 @@ | |||
| "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", | |||
| "print()\n", | |||
| "train_X, train_gt_pseudo_label, train_Y = train_data\n", | |||
| "print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n", | |||
| " f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n", | |||
| "print(\n", | |||
| " f\"Length of X, gt_pseudo_label, Y in train_data: \"\n", | |||
| " + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n", | |||
| ")\n", | |||
| "test_X, test_gt_pseudo_label, test_Y = test_data\n", | |||
| "print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n", | |||
| " f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n", | |||
| "print(\n", | |||
| " f\"Length of X, gt_pseudo_label, Y in test_data: \"\n", | |||
| " + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n", | |||
| ")\n", | |||
| "print()\n", | |||
| "\n", | |||
| "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", | |||
| "print(f\"X is a {type(train_X).__name__}, \" +\n", | |||
| " f\"with each element being a {type(X_0).__name__} \" +\n", | |||
| " f\"of {len(X_0)} {type(X_0[0]).__name__}.\")\n", | |||
| "print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n", | |||
| " f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n", | |||
| " f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\")\n", | |||
| "print(f\"Y is a {type(train_Y).__name__}, \" +\n", | |||
| " f\"with each element being a {type(Y_0).__name__}.\")" | |||
| "print(\n", | |||
| " f\"X is a {type(train_X).__name__}, \"\n", | |||
| " + f\"with each element being a {type(X_0).__name__} \"\n", | |||
| " + f\"of {len(X_0)} {type(X_0[0]).__name__}.\"\n", | |||
| ")\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n", | |||
| " + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n", | |||
| " + f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\"\n", | |||
| ")\n", | |||
| "print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -146,14 +153,16 @@ | |||
| "source": [ | |||
| "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", | |||
| "print(f\"X in the first data example (a list of two images):\")\n", | |||
| "plt.subplot(1,2,1)\n", | |||
| "plt.axis('off') \n", | |||
| "plt.imshow(X_0[0].squeeze(), cmap='gray')\n", | |||
| "plt.subplot(1,2,2)\n", | |||
| "plt.axis('off') \n", | |||
| "plt.imshow(X_0[1].squeeze(), cmap='gray')\n", | |||
| "plt.subplot(1, 2, 1)\n", | |||
| "plt.axis(\"off\")\n", | |||
| "plt.imshow(X_0[0].squeeze(), cmap=\"gray\")\n", | |||
| "plt.subplot(1, 2, 2)\n", | |||
| "plt.axis(\"off\")\n", | |||
| "plt.imshow(X_0[1].squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\")\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\"\n", | |||
| ")\n", | |||
| "print(f\"Y in the first data example (their sum result): {Y_0}\")" | |||
| ] | |||
| }, | |||
| @@ -219,11 +228,15 @@ | |||
| "source": [ | |||
| "data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]\n", | |||
| "pred_idx = base_model.predict(X=data_instances)\n", | |||
| "print(f\"Predicted class index for a batch of 32 instances: \" +\n", | |||
| " f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n", | |||
| "print(\n", | |||
| " f\"Predicted class index for a batch of 32 instances: \"\n", | |||
| " + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n", | |||
| ")\n", | |||
| "pred_prob = base_model.predict_proba(X=data_instances)\n", | |||
| "print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n", | |||
| " f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")" | |||
| "print(\n", | |||
| " f\"Predicted class probabilities for a batch of 32 instances: \"\n", | |||
| " + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -268,6 +281,7 @@ | |||
| ], | |||
| "source": [ | |||
| "from abl.data.structures import ListData\n", | |||
| "\n", | |||
| "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", | |||
| "data_examples = ListData()\n", | |||
| "# We use the first 100 data examples in the training set as an illustration\n", | |||
| @@ -276,13 +290,17 @@ | |||
| "data_examples.Y = train_Y[:100]\n", | |||
| "\n", | |||
| "# Perform prediction on the 100 data examples\n", | |||
| "pred_label, pred_prob = model.predict(data_examples)['label'], model.predict(data_examples)['prob']\n", | |||
| "print(f\"Predicted class labels for the 100 data examples: \\n\" +\n", | |||
| " f\"a list of length {len(pred_label)}, and each element is \" +\n", | |||
| " f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\")\n", | |||
| "print(f\"Predicted class probabilities for the 100 data examples: \\n\" +\n", | |||
| " f\"a list of length {len(pred_prob)}, and each element is \" +\n", | |||
| " f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\")" | |||
| "pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n", | |||
| "print(\n", | |||
| " f\"Predicted class labels for the 100 data examples: \\n\"\n", | |||
| " + f\"a list of length {len(pred_label)}, and each element is \"\n", | |||
| " + f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\"\n", | |||
| ")\n", | |||
| "print(\n", | |||
| " f\"Predicted class probabilities for the 100 data examples: \\n\"\n", | |||
| " + f\"a list of length {len(pred_prob)}, and each element is \"\n", | |||
| " + f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -313,6 +331,7 @@ | |||
| " def logic_forward(self, nums):\n", | |||
| " return sum(nums)\n", | |||
| "\n", | |||
| "\n", | |||
| "kb = AddKB()" | |||
| ] | |||
| }, | |||
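A quick sanity check of the deduction function defined above: the reasoning result of an example is simply the sum of its pseudo-labels.

```python
kb = AddKB()
assert kb.logic_forward([1, 2]) == 3
assert kb.logic_forward([0, 7]) == 7
```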
| @@ -4,27 +4,30 @@ import openml | |||
| # Function to load and preprocess the dataset | |||
| def load_and_preprocess_dataset(dataset_id): | |||
| dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False) | |||
| dataset = openml.datasets.get_dataset( | |||
| dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False | |||
| ) | |||
| X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||
| # Convert data types | |||
| for col in X.select_dtypes(include='bool').columns: | |||
| for col in X.select_dtypes(include="bool").columns: | |||
| X[col] = X[col].astype(int) | |||
| y = y.cat.codes.astype(int) | |||
| X, y = X.to_numpy(), y.to_numpy() | |||
| return X, y | |||
| # Function to split data (one shot) | |||
| def split_dataset(X, y, test_size = 0.3): | |||
| def split_dataset(X, y, test_size=0.3): | |||
| # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1) | |||
| label_indices, unlabel_indices, test_indices = [], [], [] | |||
| for class_label in np.unique(y): | |||
| idxs = np.where(y == class_label)[0] | |||
| np.random.shuffle(idxs) | |||
| n_train_unlabel = int((1-test_size)*(len(idxs)-1)) | |||
| n_train_unlabel = int((1 - test_size) * (len(idxs) - 1)) | |||
| label_indices.append(idxs[0]) | |||
| unlabel_indices.extend(idxs[1:1+n_train_unlabel]) | |||
| test_indices.extend(idxs[1+n_train_unlabel:]) | |||
| unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel]) | |||
| test_indices.extend(idxs[1 + n_train_unlabel :]) | |||
| X_label, y_label = X[label_indices], y[label_indices] | |||
| X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices] | |||
| X_test, y_test = X[test_indices], y[test_indices] | |||
| return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test | |||
| return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test | |||
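A hedged usage sketch of the two helpers above (dataset id 62 is the OpenML Zoo dataset used in this example): exactly one labeled example is kept per class, and the rest of each class is split between the unlabeled pool and the test set.

```python
import numpy as np

np.random.seed(0)  # split_dataset shuffles, so seed for reproducibility
X, y = load_and_preprocess_dataset(dataset_id=62)
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
print(len(X_label), len(X_unlabel), len(X_test))  # one per class, then roughly 70/30
```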
| @@ -11,18 +11,27 @@ class ZooKB(KBBase): | |||
| self.solver = Solver() | |||
| # Load information of Zoo dataset | |||
| dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False) | |||
| X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||
| dataset = openml.datasets.get_dataset( | |||
| dataset_id=62, | |||
| download_data=False, | |||
| download_qualities=False, | |||
| download_features_meta_data=False, | |||
| ) | |||
| X, y, categorical_indicator, attribute_names = dataset.get_data( | |||
| target=dataset.default_target_attribute | |||
| ) | |||
| self.attribute_names = attribute_names | |||
| self.target_names = y.cat.categories.tolist() | |||
| # print("Attribute names are: ", self.attribute_names) | |||
| # print("Target names are: ", self.target_names) | |||
| # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] | |||
| # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] | |||
| # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] # noqa: E501 | |||
| # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] # noqa: E501 | |||
| # Define variables | |||
| for name in self.attribute_names+self.target_names: | |||
| exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules | |||
| for name in self.attribute_names + self.target_names: | |||
| exec( | |||
| f"globals()['{name}'] = Int('{name}')" | |||
| ) # or use dict to create var and modify rules | |||
| # Define rules | |||
| rules = [ | |||
| Implies(milk == 1, mammal == 1), | |||
| @@ -54,11 +63,13 @@ class ZooKB(KBBase): | |||
| Implies(insect == 1, eggs == 1), | |||
| Implies(insect == 1, Not(backbone == 1)), | |||
| Implies(insect == 1, legs == 6), | |||
| Implies(invertebrate == 1, Not(backbone == 1)) | |||
| Implies(invertebrate == 1, Not(backbone == 1)), | |||
| ] | |||
| # Define weights and sum of violated weights | |||
| self.weights = {rule: 1 for rule in rules} | |||
| self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights]) | |||
| self.total_violation_weight = Sum( | |||
| [If(Not(rule), self.weights[rule], 0) for rule in self.weights] | |||
| ) | |||
| def logic_forward(self, pseudo_label, data_point): | |||
| attribute_names, target_names = self.attribute_names, self.target_names | |||
| @@ -69,7 +80,7 @@ class ZooKB(KBBase): | |||
| self.solver.reset() | |||
| for name, value in zip(attribute_names, data_point): | |||
| solver.add(eval(f"{name} == {value}")) | |||
| for cate, name in zip(self.pseudo_label_list,target_names): | |||
| for cate, name in zip(self.pseudo_label_list, target_names): | |||
| value = 1 if (cate == pseudo_label) else 0 | |||
| solver.add(eval(f"{name} == {value}")) | |||
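The rules in `ZooKB` act as soft constraints: `total_violation_weight` sums, via z3's `If`, the weights of whichever rules an assignment breaks, and `logic_forward` then pins the attributes and candidate label before measuring that sum. A minimal self-contained sketch of the weighted-violation pattern with a single rule:

```python
from z3 import If, Implies, Int, Not, Optimize, Sum

milk, mammal = Int("milk"), Int("mammal")
rules = [Implies(milk == 1, mammal == 1)]
weights = {rule: 1 for rule in rules}
violation = Sum([If(Not(rule), weights[rule], 0) for rule in weights])

opt = Optimize()
opt.add(milk == 1, mammal == 0)  # an assignment that breaks the rule
opt.minimize(violation)
opt.check()
print(opt.model().eval(violation))  # 1: the lone rule is violated
```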
| @@ -14,7 +14,6 @@ from get_dataset import load_and_preprocess_dataset, split_dataset | |||
| from kb import ZooKB | |||
| def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
| pred_prob = data_example.pred_prob | |||
| model_scores = confidence_dist(pred_prob, candidate_idxs) | |||
| @@ -22,19 +21,20 @@ def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
| scores = model_scores + rule_scores | |||
| return scores | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="Zoo example") | |||
| parser.add_argument( | |||
| "--loops", type=int, default=3, help="number of loop iterations (default : 3)" | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the ZOO example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| X, y = load_and_preprocess_dataset(dataset_id=62) | |||
| X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
| label_data = tab_data_to_tuple(X_label, y_label) | |||
| @@ -43,7 +43,7 @@ def main(): | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build base model | |||
| base_model = RandomForestClassifier() | |||
| @@ -52,32 +52,38 @@ def main(): | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = ZooKB() | |||
| # Create reasoner | |||
| reasoner = Reasoner(kb, dist_func=consitency) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] | |||
| ### Bridging learning and reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| # Performing training and testing | |||
| print_log("------- Use labeled data to pretrain the model -----------", logger="current") | |||
| base_model.fit(X_label, y_label) | |||
| print_log("------- Test the initial model -----------", logger="current") | |||
| bridge.test(test_data) | |||
| print_log("------- Use ABL to train the model -----------", logger="current") | |||
| bridge.train(train_data=train_data, label_data=label_data, loops=args.loops, segment_size=len(X_unlabel), save_dir=weights_dir) | |||
| bridge.train( | |||
| train_data=train_data, | |||
| label_data=label_data, | |||
| loops=args.loops, | |||
| segment_size=len(X_unlabel), | |||
| save_dir=weights_dir, | |||
| ) | |||
| print_log("------- Test the final model -----------", logger="current") | |||
| bridge.test(test_data) | |||
| @@ -106,9 +106,9 @@ | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", | |||
| "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", | |||
| "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" | |||
| "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result=0)\n", | |||
| "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result=0)\n", | |||
| "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result=0)" | |||
| ] | |||
| }, | |||
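The body of `tab_data_to_tuple` isn't shown in this diff; judging from how its output is consumed, it presumably packs each tabular row and its label into the `(X, gt_pseudo_label, Y)` convention, with `reasoning_result` supplying a constant `Y`. A hypothetical sketch of that packing, not the library's implementation:

```python
def tab_data_to_tuple_sketch(X, y, reasoning_result=0):
    # Hypothetical: treat each row as a single-instance example.
    X_list = [[x] for x in X]
    gt_pseudo_label = [[label] for label in y]
    Y = [reasoning_result] * len(y)
    return X_list, gt_pseudo_label, Y
```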
| { | |||
| @@ -240,6 +240,7 @@ | |||
| " scores = model_scores + rule_scores\n", | |||
| " return scores\n", | |||
| "\n", | |||
| "\n", | |||
| "reasoner = Reasoner(kb, dist_func=consitency)" | |||
| ] | |||
| }, | |||
| @@ -338,7 +339,13 @@ | |||
| "print_log(\"------- Test the initial model -----------\", logger=\"current\")\n", | |||
| "bridge.test(test_data)\n", | |||
| "print_log(\"------- Use ABL to train the model -----------\", logger=\"current\")\n", | |||
| "bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n", | |||
| "bridge.train(\n", | |||
| " train_data=train_data,\n", | |||
| " label_data=label_data,\n", | |||
| " loops=3,\n", | |||
| " segment_size=len(X_unlabel),\n", | |||
| " save_dir=weights_dir,\n", | |||
| ")\n", | |||
| "print_log(\"------- Test the final model -----------\", logger=\"current\")\n", | |||
| "bridge.test(test_data)" | |||
| ] | |||
| @@ -200,7 +200,7 @@ def kb_add_ground(): | |||
| @pytest.fixture | |||
| def kb_add_prolog(): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl") | |||
| return kb | |||
| @@ -218,7 +218,7 @@ def kb_hwf2(): | |||
| @pytest.fixture | |||
| def kb_hed(): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| kb = HedKB( | |||
| pseudo_label_list=[1, 0, "+", "="], | |||
| @@ -28,9 +28,13 @@ class TestKBBase(object): | |||
| assert result == ([[0, 2], [1, 1], [2, 0]], [2, 2, 2]) | |||
| def test_abduce_candidates(self, kb_add): | |||
| result = kb_add.abduce_candidates([0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0) | |||
| result = kb_add.abduce_candidates( | |||
| [0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0 | |||
| ) | |||
| assert result == ([[0, 1]], [1]) | |||
| result = kb_add.abduce_candidates([1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0) | |||
| result = kb_add.abduce_candidates( | |||
| [1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0 | |||
| ) | |||
| assert result == ([[1, 0]], [1]) | |||
| @@ -53,19 +57,19 @@ class TestGroundKB(object): | |||
| class TestPrologKB(object): | |||
| def test_init_pl1(self, kb_add_prolog): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| assert kb_add_prolog.pseudo_label_list == list(range(10)) | |||
| assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl" | |||
| def test_init_pl2(self, kb_hed): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| assert kb_hed.pseudo_label_list == [1, 0, "+", "="] | |||
| assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl" | |||
| def test_prolog_file_not_exist(self): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| pseudo_label_list = [1, 2] | |||
| non_existing_file = "path/to/non_existing_file.pl" | |||
| @@ -74,13 +78,13 @@ class TestPrologKB(object): | |||
| assert non_existing_file in str(excinfo.value) | |||
| def test_logic_forward_pl1(self, kb_add_prolog): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| result = kb_add_prolog.logic_forward([1, 2]) | |||
| assert result == 3 | |||
| def test_logic_forward_pl2(self, kb_hed): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| consist_exs = [ | |||
| [1, 1, "+", 0, "=", 1, 1], | |||
| @@ -97,7 +101,7 @@ class TestPrologKB(object): | |||
| assert kb_hed.logic_forward(inconsist_exs) is False | |||
| def test_revise_at_idx(self, kb_add_prolog): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| result = kb_add_prolog.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0]) | |||
| assert result == ([[0, 2]], [2]) | |||
| @@ -113,34 +117,34 @@ class TestReaonser(object): | |||
| assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in str( | |||
| excinfo.value | |||
| ) | |||
| def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results): | |||
| cost_list = [np.random.rand() for _ in candidates] | |||
| return cost_list | |||
| def test_user_defined_dist_func(self, kb_add): | |||
| reasoner = Reasoner(kb_add, self.random_dist) | |||
| assert reasoner.dist_func == self.random_dist | |||
| def invalid_dist1(self, candidates): | |||
| cost_list = np.array([np.random.rand() for _ in candidates]) | |||
| return cost_list | |||
| def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results): | |||
| cost_list = np.array([np.random.rand() for _ in candidates]) | |||
| return np.append(cost_list, np.random.rand()) | |||
| def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add): | |||
| with pytest.raises(ValueError) as excinfo: | |||
| Reasoner(kb_add, self.invalid_dist1) | |||
| assert 'User-defined dist_func must have exactly four parameters' in str( | |||
| excinfo.value | |||
| ) | |||
| assert "User-defined dist_func must have exactly four parameters" in str(excinfo.value) | |||
| with pytest.raises(ValueError) as excinfo: | |||
| reasoner = Reasoner(kb_add, self.invalid_dist2) | |||
| reasoner.batch_abduce(data_examples_add) | |||
| assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str( | |||
| excinfo.value | |||
| assert ( | |||
| "The length of the array returned by dist_func must be " | |||
| + "equal to the number of candidates" | |||
| in str(excinfo.value) | |||
| ) | |||
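Taken together, these tests pin down the `dist_func` contract: exactly four parameters, and one cost per candidate. A conforming sketch, with uniform costs as a placeholder:

```python
import numpy as np

def uniform_dist(data_example, candidates, candidate_idxs, reasoning_results):
    # One non-negative cost per candidate; lower means preferred.
    return np.zeros(len(candidates))

# Reasoner(kb_add, uniform_dist) would pass both validation checks above.
```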
| @@ -186,7 +190,7 @@ class TestBatchAbduce(object): | |||
| ] | |||
| def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) | |||
| reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1) | |||
| @@ -208,7 +212,7 @@ class TestBatchAbduce(object): | |||
| ] | |||
| def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add): | |||
| if platform.system() == 'Darwin': | |||
| if platform.system() == "Darwin": | |||
| return | |||
| reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) | |||
| reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2) | |||