diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 97c7349..7386a88 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -21,4 +21,4 @@ jobs: uses: py-actions/flake8@v2 with: max-line-length: "100" - args: --ignore=E203,W503 \ No newline at end of file + args: --ignore=E203,W503,F821,E266 \ No newline at end of file diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 9fb22d4..f8a1c9a 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -164,7 +164,8 @@ class SimpleBridge(BaseBridge): self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData] ) -> ListData: """ - Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model. + Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data + examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model. Parameters ---------- @@ -212,18 +213,19 @@ class SimpleBridge(BaseBridge): Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. - ``X`` is a list of sublists representing the input data. - - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but not - to train. ``gt_pseudo_label`` can be ``None``. - - ``Y`` is a list representing the ground truth reasoning result for each sublist in ``X``. + - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but + not to train. ``gt_pseudo_label`` can be ``None``. + - ``Y`` is a list representing the ground truth reasoning result for each sublist + in ``X``. label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional Labeled data should be in the same format as ``train_data``. The only difference is that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be utilized to train the model. Defaults to None. - val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional + val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label`` and ``Y`` can be either None or not, which depends on the evaluation metircs in - ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate the - model during training time. Defaults to None. + ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate + the model during training time. Defaults to None. loops : int Machine Learning part and Reasoning part will be iteratively optimized for ``loops`` times, by default 50. @@ -325,7 +327,7 @@ class SimpleBridge(BaseBridge): Parameters ---------- - val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] + val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. @@ -344,10 +346,10 @@ class SimpleBridge(BaseBridge): Parameters ---------- - test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] - Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object with ``X``, - ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be either None or - not, which depends on the evaluation metircs in ``self.metric_list``. + test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 + Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object + with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` + can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. """ print_log("Test start:", logger="current") test_data_examples = self.data_preprocess("test", test_data) diff --git a/abl/data/data_converter.py b/abl/data/data_converter.py index 4673b16..b4e495d 100644 --- a/abl/data/data_converter.py +++ b/abl/data/data_converter.py @@ -4,21 +4,22 @@ from abl.utils import tab_data_to_tuple from .structures.list_data import ListData from lambdaLearn.Base.TabularMixin import TabularMixin + class DataConverter: - ''' + """ This class provides functionality to convert LambdaLearn data to ABL-Package data. - ''' + """ + def __init__(self) -> None: pass - + def convert_lambdalearn_to_tuple( - self, - dataset: TabularMixin, - reasoning_result: Any + self, dataset: TabularMixin, reasoning_result: Any ) -> Tuple[Tuple, Tuple, Tuple, Tuple]: - ''' - Convert a lambdalearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), each containing (data, label, reasoning_result). - + """ + Convert a lambdalearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), # noqa: E501 + each containing (data, label, reasoning_result). + Parameters ---------- dataset : TabularMixin @@ -28,27 +29,38 @@ class DataConverter: Returns ------- Tuple[Tuple, Tuple, Tuple, Tuple] - A tuple of (label_data, train_data, valid_data, test_data), where each element is a tuple of (data, label, reasoning_result). - ''' - + A tuple of (label_data, train_data, valid_data, test_data), where each element is + a tuple of (data, label, reasoning_result). + """ + if not isinstance(dataset, TabularMixin): - raise NotImplementedError("Only support converting the datasets that are instances of TabularMixin. Please refer to the documentation and manually convert the dataset into a tuple. ") - - label_data = tab_data_to_tuple(dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result) - train_data = tab_data_to_tuple(dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result) - valid_data = tab_data_to_tuple(dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result) - test_data = tab_data_to_tuple(dataset.test_X, dataset.test_y, reasoning_result=reasoning_result) - + raise NotImplementedError( + "Only support converting the datasets that are instances of TabularMixin. " + + "Please refer to the documentation and manually convert the dataset into a tuple." + ) + + label_data = tab_data_to_tuple( + dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result + ) + train_data = tab_data_to_tuple( + dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result + ) + valid_data = tab_data_to_tuple( + dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result + ) + test_data = tab_data_to_tuple( + dataset.test_X, dataset.test_y, reasoning_result=reasoning_result + ) + return label_data, train_data, valid_data, test_data - + def convert_lambdalearn_to_listdata( - self, - dataset: TabularMixin, - reasoning_result: Any + self, dataset: TabularMixin, reasoning_result: Any ) -> Tuple[ListData, ListData, ListData, ListData]: - ''' - Convert a lambdalearn dataset to a tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples). - + """ + Convert a lambdalearn dataset to a tuple of ListData + (label_data_examples, train_data_examples, valid_data_examples, test_data_examples). + Parameters ---------- dataset : TabularMixin @@ -58,14 +70,20 @@ class DataConverter: Returns ------- Tuple[ListData, ListData, ListData, ListData] - A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples) - ''' - + A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples) # noqa: E501 + """ + if not isinstance(dataset, TabularMixin): - raise NotImplementedError("Only support converting the datasets that are instances of TabularMixin. Please refer to the documentation and manually convert the dataset into a ListData. ") - - label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple(dataset, reasoning_result) - + raise NotImplementedError( + "Only support converting the datasets that are instances of TabularMixin. " + + "Please refer to the documentation and manually convert the dataset " + + "into a ListData." + ) + + label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple( + dataset, reasoning_result + ) + if label_data is not None: X, gt_pseudo_label, Y = label_data label_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y) @@ -78,23 +96,46 @@ class DataConverter: if test_data is not None: X, gt_pseudo_label, Y = test_data test_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y) - + return label_data_examples, train_data_examples, valid_data_examples, test_data_examples - -if __name__ == '__main__': + + +if __name__ == "__main__": from lambdaLearn.Dataset.Tabular.BreastCancer import BreastCancer - breast_dataset=BreastCancer(labeled_size=0.1, stratified=True, shuffle=True) + + breast_dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True) dataconverter = DataConverter() - - label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple(breast_dataset, 0) - print(type(label_data).__name__, type(train_data).__name__, type(valid_data).__name__, type(test_data).__name__) + + label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple( + breast_dataset, 0 + ) + print( + type(label_data).__name__, + type(train_data).__name__, + type(valid_data).__name__, + type(test_data).__name__, + ) print(len(label_data)) print(len(label_data[0]), len(label_data[1]), len(label_data[2])) print(label_data[0][0], label_data[1][0], label_data[2][0]) print() - - label_data_examples, train_data_examples, valid_data_examples, test_data_examples = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0) - print(type(label_data_examples).__name__, type(train_data_examples).__name__, type(valid_data_examples).__name__, type(test_data_examples).__name__) - print(len(label_data_examples.X), len(label_data_examples.gt_pseudo_label), len(label_data_examples.Y)) + + ( + label_data_examples, + train_data_examples, + valid_data_examples, + test_data_examples, + ) = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0) + print( + type(label_data_examples).__name__, + type(train_data_examples).__name__, + type(valid_data_examples).__name__, + type(test_data_examples).__name__, + ) + print( + len(label_data_examples.X), + len(label_data_examples.gt_pseudo_label), + len(label_data_examples.Y), + ) label_data_example = label_data_examples[0] - print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y) \ No newline at end of file + print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y) diff --git a/abl/data/evaluation/reasoning_metric.py b/abl/data/evaluation/reasoning_metric.py index 3368bd3..9a010bb 100644 --- a/abl/data/evaluation/reasoning_metric.py +++ b/abl/data/evaluation/reasoning_metric.py @@ -38,7 +38,8 @@ class ReasoningMetric(BaseMetric): """ Process a batch of data examples. - This method takes in a batch of data examples, each containing predicted pseudo-labels(pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It + This method takes in a batch of data examples, each containing predicted pseudo-labels + (pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It evaluates the reasoning accuracy of each example by comparing the logical reasoning result (derived using the knowledge base) of the predicted pseudo-labels against Y The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended diff --git a/abl/data/structures/list_data.py b/abl/data/structures/list_data.py index 3849246..e3c6fa1 100644 --- a/abl/data/structures/list_data.py +++ b/abl/data/structures/list_data.py @@ -53,7 +53,7 @@ class ListData(BaseDataElement): ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``. This design is inspired by and extends the functionalities of the ``BaseDataElement`` - class implemented in `MMEngine `_. + class implemented in `MMEngine `_. # noqa: E501 Examples: >>> from abl.data.structures import ListData @@ -71,7 +71,7 @@ class ListData(BaseDataElement): DATA FIELDS Y: [1, 2, 3] gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] - X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] + X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 ) at 0x7f3bbf1991c0> >>> print(data_examples[:1]) BasicNN: """ @@ -271,8 +272,8 @@ class BasicNN: return torch.cat(results, axis=0) def predict( - self, - data_loader: Optional[DataLoader] = None, + self, + data_loader: Optional[DataLoader] = None, X: Optional[List[Any]] = None, ) -> numpy.ndarray: """ @@ -312,8 +313,8 @@ class BasicNN: return self._predict(data_loader).argmax(axis=1).cpu().numpy() def predict_proba( - self, - data_loader: Optional[DataLoader] = None, + self, + data_loader: Optional[DataLoader] = None, X: Optional[List[Any]] = None, ) -> numpy.ndarray: """ @@ -403,9 +404,9 @@ class BasicNN: return mean_loss, accuracy def score( - self, - data_loader: Optional[DataLoader] = None, - X: Optional[List[Any]] = None, + self, + data_loader: Optional[DataLoader] = None, + X: Optional[List[Any]] = None, y: Optional[List[int]] = None, ) -> float: """ @@ -447,8 +448,8 @@ class BasicNN: def _data_loader( self, - X: Optional[List[Any]], - y: Optional[List[int]] = None, + X: Optional[List[Any]], + y: Optional[List[int]] = None, shuffle: Optional[bool] = True, ) -> DataLoader: """ diff --git a/abl/learning/model_converter.py b/abl/learning/model_converter.py index 5e662df..0f79ce9 100644 --- a/abl/learning/model_converter.py +++ b/abl/learning/model_converter.py @@ -6,16 +6,18 @@ from .abl_model import ABLModel from .basic_nn import BasicNN from lambdaLearn.Base.DeepModelMixin import DeepModelMixin + class ModelConverter: - ''' + """ This class provides functionality to convert LambdaLearn models to ABL-Package models. - ''' + """ + def __init__(self) -> None: pass - + def convert_lambdalearn_to_ablmodel( self, - lambdalearn_model, + lambdalearn_model, loss_fn: torch.nn.Module, optimizer_dict: dict, scheduler_dict: Optional[dict] = None, @@ -28,11 +30,13 @@ class ModelConverter: save_dir: Optional[str] = None, train_transform: Callable[..., Any] = None, test_transform: Callable[..., Any] = None, - collate_fn: Callable[[List[Any]], Any] = None + collate_fn: Callable[[List[Any]], Any] = None, ): - ''' - Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn model should implement fit and predict methods. - + """ + Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of + DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn + model should implement ``fit`` and ``predict`` methods. + Parameters ---------- lambdalearn_model : Union[DeepModelMixin, Any] @@ -75,17 +79,34 @@ class ModelConverter: ------- ABLModel The converted ABLModel instance. - ''' + """ if isinstance(lambdalearn_model, DeepModelMixin): - base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer_dict, scheduler_dict, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn) + base_model = self.convert_lambdalearn_to_basicnn( + lambdalearn_model, + loss_fn, + optimizer_dict, + scheduler_dict, + device, + batch_size, + num_epochs, + stop_loss, + num_workers, + save_interval, + save_dir, + train_transform, + test_transform, + collate_fn, + ) return ABLModel(base_model) - + if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")): - raise NotImplementedError("The lambdalearn_model should be an instance of DeepModelMixin, or implement fit and predict methods.") - + raise NotImplementedError( + "The lambdalearn_model should be an instance of DeepModelMixin, or implement " + + "fit and predict methods." + ) + return ABLModel(lambdalearn_model) - - + def convert_lambdalearn_to_basicnn( self, lambdalearn_model: DeepModelMixin, @@ -103,9 +124,10 @@ class ModelConverter: test_transform: Callable[..., Any] = None, collate_fn: Callable[[List[Any]], Any] = None, ): - ''' - Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of DeepModelMixin, its network will be used as the model of BasicNN. - + """ + Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of + DeepModelMixin, its network will be used as the model of BasicNN. + Parameters ---------- lambdalearn_model : Union[DeepModelMixin, Any] @@ -147,10 +169,13 @@ class ModelConverter: ------- BasicNN The converted BasicNN instance. - ''' + """ if isinstance(lambdalearn_model, DeepModelMixin): if not isinstance(lambdalearn_model.network, torch.nn.Module): - raise NotImplementedError(f"Expected lambdalearn_model.network to be a torch.nn.Module, but got {type(lambdalearn_model.network)}") + raise NotImplementedError( + "Expected lambdalearn_model.network to be a torch.nn.Module, " + + f"but got {type(lambdalearn_model.network)}" + ) # Only use the network part and device of the lambdalearn model network = copy.deepcopy(lambdalearn_model.network) optimizer_class = optimizer_dict["optimizer"] @@ -163,7 +188,24 @@ class ModelConverter: else: scheduler = None device = lambdalearn_model.device if device is None else device - base_model = BasicNN(model=network, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, device=device, batch_size=batch_size, num_epochs=num_epochs, stop_loss=stop_loss, num_workers=num_workers, save_interval=save_interval, save_dir=save_dir, train_transform=train_transform, test_transform=test_transform, collate_fn=collate_fn) + base_model = BasicNN( + model=network, + loss_fn=loss_fn, + optimizer=optimizer, + scheduler=scheduler, + device=device, + batch_size=batch_size, + num_epochs=num_epochs, + stop_loss=stop_loss, + num_workers=num_workers, + save_interval=save_interval, + save_dir=save_dir, + train_transform=train_transform, + test_transform=test_transform, + collate_fn=collate_fn, + ) return base_model else: - raise NotImplementedError("The lambdalearn_model should be an instance of DeepModelMixin.") \ No newline at end of file + raise NotImplementedError( + "The lambdalearn_model should be an instance of DeepModelMixin." + ) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 0632781..8e31ab0 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -26,7 +26,7 @@ class KBBase(ABC): list so that each aligns with its corresponding index in the base model: the first with the 0th index, the second with the 1st, and so forth. max_err : float, optional - The upper tolerance limit when comparing the similarity between the reasoning result of + The upper tolerance limit when comparing the similarity between the reasoning result of pseudo-labels and the ground truth. This is only applicable when the reasoning result is of a numerical type. This is particularly relevant for regression problems where exact matches might not be feasible. Defaults to 1e-10. @@ -65,10 +65,12 @@ class KBBase(ABC): self.use_cache = use_cache self.key_func = key_func self.cache_size = cache_size - + argspec = inspect.getfullargspec(self.logic_forward) self._num_args = len(argspec.args) - 1 - if self._num_args==2 and self.use_cache: # If the logic_forward function has 2 arguments, then disable cache + if ( + self._num_args == 2 and self.use_cache + ): # If the logic_forward function has 2 arguments, then disable cache self.use_cache = False print_log( "The logic_forward function has 2 arguments, so the cache is disabled. ", @@ -89,10 +91,10 @@ class KBBase(ABC): pseudo_label : List[Any] Pseudo-labels of an example. x : List[Any], optional - The example. If deductive logical reasoning does not require any - information from the example, the overridden function provided by the user can omit + The example. If deductive logical reasoning does not require any + information from the example, the overridden function provided by the user can omit this parameter. - + Returns ------- Any @@ -100,11 +102,11 @@ class KBBase(ABC): """ def abduce_candidates( - self, - pseudo_label: List[Any], - y: Any, - x: List[Any], - max_revision_num: int, + self, + pseudo_label: List[Any], + y: Any, + x: List[Any], + max_revision_num: int, require_more_revision: int, ) -> List[List[Any]]: """ @@ -118,7 +120,7 @@ class KBBase(ABC): Ground truth of the reasoning result for the example. x : List[Any] The example. If the information from the example - is not required in the reasoning process, then this parameter will not have + is not required in the reasoning process, then this parameter will not have any effect. max_revision_num : int The upper limit on the number of revised labels for each example. @@ -129,9 +131,9 @@ class KBBase(ABC): ------- Tuple[List[List[Any]], List[Any]] A tuple of two element. The first element is a list of candidate revisions, i.e. revised - pseudo-labels of the example. that are compatible with the knowledge base. The second element is - a list of reasoning results corresponding to each candidate, i.e., the outcome of the - logic_forward function. + pseudo-labels of the example. that are compatible with the knowledge base. The second + element is a list of reasoning results corresponding to each candidate, i.e., the + outcome of the ``logic_forward`` function. """ return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision) @@ -154,10 +156,10 @@ class KBBase(ABC): return reasoning_result == y def revise_at_idx( - self, - pseudo_label: List[Any], - y: Any, - x: List[Any], + self, + pseudo_label: List[Any], + y: Any, + x: List[Any], revision_idx: List[int], ) -> List[List[Any]]: """ @@ -171,7 +173,7 @@ class KBBase(ABC): Ground truth of the reasoning result for the example. x : List[Any] The example. If the information from the example - is not required in the reasoning process, then this parameter will not have + is not required in the reasoning process, then this parameter will not have any effect. revision_idx : List[int] A list specifying indices of where revisions should be made to the pseudo-labels. @@ -180,9 +182,9 @@ class KBBase(ABC): ------- Tuple[List[List[Any]], List[Any]] A tuple of two element. The first element is a list of candidate revisions, i.e. revised - pseudo-labels of the example that are compatible with the knowledge base. The second element is - a list of reasoning results corresponding to each candidate, i.e., the outcome of the - logic_forward function. + pseudo-labels of the example that are compatible with the knowledge base. The second + element is a list of reasoning results corresponding to each candidate, i.e., the + outcome of the ``logic_forward`` function. """ candidates, reasoning_results = [], [] abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) @@ -192,14 +194,15 @@ class KBBase(ABC): candidate[idx] = c[i] reasoning_result = self.logic_forward(candidate, *(x,) if self._num_args == 2 else ()) if self._check_equal(reasoning_result, y): - candidates.append(candidate); reasoning_results.append(reasoning_result) + candidates.append(candidate) + reasoning_results.append(reasoning_result) return candidates, reasoning_results def _revision( - self, - revision_num: int, - pseudo_label: List[Any], - y: Any, + self, + revision_num: int, + pseudo_label: List[Any], + y: Any, x: List[Any], ) -> List[List[Any]]: """ @@ -210,16 +213,17 @@ class KBBase(ABC): revision_idx_list = combinations(range(len(pseudo_label)), revision_num) for revision_idx in revision_idx_list: candidates, reasoning_results = self.revise_at_idx(pseudo_label, y, x, revision_idx) - new_candidates.extend(candidates); new_reasoning_results.extend(reasoning_results) + new_candidates.extend(candidates) + new_reasoning_results.extend(reasoning_results) return new_candidates, new_reasoning_results @abl_cache() def _abduce_by_search( - self, - pseudo_label: List[Any], - y: Any, - x: List[Any], - max_revision_num: int, + self, + pseudo_label: List[Any], + y: Any, + x: List[Any], + max_revision_num: int, require_more_revision: int, ) -> List[List[Any]]: """ @@ -235,7 +239,7 @@ class KBBase(ABC): Ground truth of the reasoning result for the example. x : List[Any] The example. If the information from the example - is not required in the reasoning process, then this parameter will not have + is not required in the reasoning process, then this parameter will not have any effect. max_revision_num : int The upper limit on the number of revisions. @@ -248,14 +252,15 @@ class KBBase(ABC): ------- Tuple[List[List[Any]], List[Any]] A tuple of two element. The first element is a list of candidate revisions, i.e. revised - pseudo-labels of the example that are compatible with the knowledge base. The second element is - a list of reasoning results corresponding to each candidate, i.e., the outcome of the - logic_forward function. + pseudo-labels of the example that are compatible with the knowledge base. The second + element is a list of reasoning results corresponding to each candidate, i.e., the + outcome of the ``logic_forward`` function. """ candidates, reasoning_results = [], [] for revision_num in range(len(pseudo_label) + 1): new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x) - candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results) + candidates.extend(new_candidates) + reasoning_results.extend(new_reasoning_results) if len(candidates) > 0: min_revision_num = revision_num break @@ -268,7 +273,8 @@ class KBBase(ABC): if revision_num > max_revision_num: return candidates, reasoning_results new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x) - candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results) + candidates.extend(new_candidates) + reasoning_results.extend(new_reasoning_results) return candidates, reasoning_results def __repr__(self): @@ -305,16 +311,19 @@ class GroundKB(KBBase): """ def __init__( - self, - pseudo_label_list: List[Any], - GKB_len_list: List[int], + self, + pseudo_label_list: List[Any], + GKB_len_list: List[int], max_err: float = 1e-10, ): super().__init__(pseudo_label_list, max_err) if not isinstance(GKB_len_list, list): raise TypeError("GKB_len_list should be list, but got {type(GKB_len_list)}") - if self._num_args==2: - raise NotImplementedError(f"GroundKB only supports 1-argument logic_forward, but got {self._num_args}-argument logic_forward") + if self._num_args == 2: + raise NotImplementedError( + "GroundKB only supports 1-argument logic_forward, but got " + + f"{self._num_args}-argument logic_forward" + ) self.GKB_len_list = GKB_len_list self.GKB = {} X, Y = self._get_GKB() @@ -354,11 +363,11 @@ class GroundKB(KBBase): return X, Y def abduce_candidates( - self, - pseudo_label: List[Any], - y: Any, - x: List[Any], - max_revision_num: int, + self, + pseudo_label: List[Any], + y: Any, + x: List[Any], + max_revision_num: int, require_more_revision: int, ) -> List[List[Any]]: """ @@ -383,9 +392,9 @@ class GroundKB(KBBase): ------- Tuple[List[List[Any]], List[Any]] A tuple of two element. The first element is a list of candidate revisions, i.e. revised - pseudo-labels of THE example that are compatible with the knowledge base. The second element is - a list of reasoning results corresponding to each candidate, i.e., the outcome of the - logic_forward function. + pseudo-labels of the example that are compatible with the knowledge base. The second + element is a list of reasoning results corresponding to each candidate, i.e., the + outcome of the ``logic_forward`` function. """ if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list: return [], [] @@ -418,7 +427,8 @@ class GroundKB(KBBase): all_candidates, all_reasoning_results = [], [] for key in key_list[low_key:high_key]: for candidate in potential_candidates[key]: - all_candidates.append(candidate); all_reasoning_results.append(key) + all_candidates.append(candidate) + all_reasoning_results.append(key) else: all_candidates = self.GKB[len(pseudo_label)][y] all_reasoning_results = [y] * len(all_candidates) @@ -468,14 +478,17 @@ class PrologKB(KBBase): def __init__(self, pseudo_label_list: List[Any], pl_file: str): super().__init__(pseudo_label_list) - + try: import pyswip except (IndexError, ImportError): - print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \ - using the command 'sudo apt-get install swi-prolog' for Linux users, \ - or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.") - + print( + "A Prolog-based knowledge base is in use. Please install Swi-Prolog using the" + + "command 'sudo apt-get install swi-prolog' for Linux users, or download it " + + "following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md " + + "for Windows and Mac users." + ) + self.prolog = pyswip.Prolog() self.pl_file = pl_file if not os.path.exists(self.pl_file): @@ -519,9 +532,9 @@ class PrologKB(KBBase): return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label)) def get_query_string( - self, - pseudo_label: List[Any], - y: Any, + self, + pseudo_label: List[Any], + y: Any, x: List[Any], revision_idx: List[int], ) -> str: @@ -538,8 +551,8 @@ class PrologKB(KBBase): y : Any Ground truth of the reasoning result for the example. x : List[Any] - The corresponding input example. If the information from the input - is not required in the reasoning process, then this parameter will not have + The corresponding input example. If the information from the input + is not required in the reasoning process, then this parameter will not have any effect. revision_idx : List[int] A list specifying indices of where revisions should be made to the pseudo-labels. @@ -556,10 +569,10 @@ class PrologKB(KBBase): return query_string def revise_at_idx( - self, - pseudo_label: List[Any], - y: Any, - x: List[Any], + self, + pseudo_label: List[Any], + y: Any, + x: List[Any], revision_idx: List[int], ) -> List[List[Any]]: """ @@ -572,8 +585,8 @@ class PrologKB(KBBase): y : Any Ground truth of the reasoning result for the example. x : List[Any] - The corresponding input example. If the information from the input - is not required in the reasoning process, then this parameter will not have + The corresponding input example. If the information from the input + is not required in the reasoning process, then this parameter will not have any effect. revision_idx : List[int] A list specifying indices of where revisions should be made to the pseudo-labels. @@ -581,12 +594,10 @@ class PrologKB(KBBase): Returns ------- Tuple[List[List[Any]], List[Any]] - A list of candidates, i.e. revised pseudo-labels of the example that are compatible with the - knowledge base. A tuple of two element. The first element is a list of candidate revisions, i.e. revised - pseudo-labels of the example that are compatible with the knowledge base. The second element is - a list of reasoning results corresponding to each candidate, i.e., the outcome of the - logic_forward function. + pseudo-labels of the example that are compatible with the knowledge base. The second + element is a list of reasoning results corresponding to each candidate, i.e., the + outcome of the ``logic_forward`` function. """ candidates, reasoning_results = [], [] query_string = self.get_query_string(pseudo_label, y, x, revision_idx) @@ -598,7 +609,8 @@ class PrologKB(KBBase): for i, idx in enumerate(revision_idx): candidate[idx] = c[i] candidate = reform_list(candidate, save_pseudo_label) - candidates.append(candidate); reasoning_results.append(y) + candidates.append(candidate) + reasoning_results.append(y) return candidates, reasoning_results def __repr__(self): diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 082214d..22d1690 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -28,10 +28,10 @@ class Reasoner: candidate, 'confidence': calculates the distance between the prediction and each candidate based on confidence derived from the predicted probability in the data example. The callable function should have the signature - dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element - in this cost list should be a numerical value representing the cost for each - candidate, and the list should have the same length as candidates. - Defaults to 'confidence'. + dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must + return a cost list. Each element in this cost list should be a numerical value + representing the cost for each candidate, and the list should have the same length + as candidates. Defaults to 'confidence'. idx_to_label : dict, optional A mapping from index in the base model to label. If not provided, a default order-based index to label mapping is created. Defaults to None. @@ -76,14 +76,16 @@ class Reasoner: if isinstance(dist_func, str): if dist_func not in ["hamming", "confidence"]: raise NotImplementedError( - f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.' + 'Valid options for predefined dist_func include "hamming" ' + + f'and "confidence", but got {dist_func}.' ) return elif callable(dist_func): params = inspect.signature(dist_func).parameters.values() if len(params) != 4: raise ValueError( - f"User-defined dist_func must have exactly four parameters, but got {len(params)}." + "User-defined dist_func must have exactly four parameters, " + + f"but got {len(params)}." ) return else: @@ -99,7 +101,8 @@ class Reasoner: raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.") if value not in self.kb.pseudo_label_list: raise ValueError( - f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}." + "All values in the idx_to_label must be in the pseudo_label_list, " + + f"but got {value}." ) def _get_one_candidate( @@ -169,8 +172,8 @@ class Reasoner: cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) if len(cost_list) != len(candidates): raise ValueError( - f"The length of the array returned by dist_func must be equal to the number of candidates. " - f"Expected length {len(candidates)}, but got {len(cost_list)}." + "The length of the array returned by dist_func must be equal to the number " + + f"of candidates. Expected length {len(candidates)}, but got {len(cost_list)}." ) return cost_list @@ -204,7 +207,9 @@ class Reasoner: dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), ) - parameter = Parameter(budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True) + parameter = Parameter( + budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True + ) solution = Opt.min(objective, parameter) return solution @@ -240,29 +245,28 @@ class Reasoner: return np.min(self._get_cost_list(data_example, candidates, reasoning_results)) else: return symbol_num - + def zoopt_budget(self, symbol_num: int) -> int: """ - Set the budget for ZOOpt optimization. The function, in its default implementation, - returns a fixed budget value of 100. However, it can be adjusted to return other fixed - values, or a dynamic budget based on the number of symbols, if desired. For example, one might choose to - set the budget as 100 times symbol_num. + Set the budget for ZOOpt optimization. The function, in its default implementation, + returns a fixed budget value of 100. However, it can be adjusted to return other fixed + values, or a dynamic budget based on the number of symbols, if desired. For example, + one might choose to set the budget as 100 times ``symbol_num``. Parameters ---------- symbol_num : int - The number of symbols to be considered in the ZOOpt optimization process. Although this parameter - can be used to compute a dynamic optimization budget, by default it is not utilized in the - calculation. + The number of symbols to be considered in the ZOOpt optimization process. Although this + parameter can be used to compute a dynamic optimization budget, by default it is not + utilized in the calculation. Returns ------- int - The budget for ZOOpt optimization. By default, this is a fixed value of 100, + The budget for ZOOpt optimization. By default, this is a fixed value of 100, irrespective of the symbol_num value. """ return 100 - def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: """ @@ -284,7 +288,8 @@ class Reasoner: elif isinstance(max_revision, float): if not (0 <= max_revision <= 1): raise ValueError( - f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}" + "If max_revision is a float, it must be between 0 and 1, " + + f"but got {max_revision}" ) return round(symbol_num * max_revision) else: diff --git a/abl/utils/__init__.py b/abl/utils/__init__.py index 9cfd590..65c5337 100644 --- a/abl/utils/__init__.py +++ b/abl/utils/__init__.py @@ -1,6 +1,13 @@ from .cache import Cache, abl_cache from .logger import ABLLogger, print_log -from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple +from .utils import ( + confidence_dist, + flatten, + hamming_dist, + reform_list, + to_hashable, + tab_data_to_tuple, +) __all__ = [ "Cache", diff --git a/abl/utils/cache.py b/abl/utils/cache.py index 9804c5f..3687982 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -43,7 +43,7 @@ class Cache(Generic[K, T]): def get_from_dict(self, obj, *args) -> T: """Implements dict based cache.""" # x is not used in cache key - pred_pseudo_label, y, x, *res_args = args + pred_pseudo_label, y, x, *res_args = args cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) link = self.cache_dict.get(cache_key) if link is not None: diff --git a/abl/utils/logger.py b/abl/utils/logger.py index 4e9d8b6..298e9f6 100644 --- a/abl/utils/logger.py +++ b/abl/utils/logger.py @@ -168,9 +168,11 @@ class ABLLogger(Logger, ManagerMixin): Notes ----- - The ``name`` of the logger and the ``instance_name`` of ``ABLLogger`` could be different. - ``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not ``logging.getLogger``. - This ensures ``ABLLogger`` is not influenced by third-party logging configurations. - - Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without ``Handler``. + ``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not + ``logging.getLogger``. This ensures ``ABLLogger`` is not influenced by third-party logging + configurations. + - Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without + ``Handler``. Examples -------- @@ -288,15 +290,16 @@ class ABLLogger(Logger, ManagerMixin): def print_log( - msg, - logger: Optional[Union[Logger, str]] = None, + msg, + logger: Optional[Union[Logger, str]] = None, level: Optional[int] = logging.INFO, ) -> None: """ Print a log message using the specified logger or a default method. This function logs a message with a given logger, if provided, or prints it using - the standard ``print`` function. It supports special logger types such as 'silent' and 'current'. + the standard ``print`` function. It supports special logger types such as 'silent' + and 'current'. Parameters ---------- @@ -308,8 +311,8 @@ def print_log( method is used. - 'silent': No message will be printed. - 'current': Use the latest created logger to log the message. - - other str: The instance name of the logger. A ``ValueError`` is raised if the logger has not - been created. + - other str: The instance name of the logger. A ``ValueError`` is raised if the logger has + not been created. - None: The ``print()`` method is used for logging. level : int, optional The logging level. This is only applicable when ``logger`` is a Logger object, 'current', diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 72535dd..66e83f8 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -15,7 +15,7 @@ def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[A Returns ------- List[Any] - A flattened version of the input list, where only the first + A flattened version of the input list, where only the first level of sublists and tuples are reduced. """ if not isinstance(nested_list, list): @@ -24,15 +24,15 @@ def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[A flattened_list = [] for item in nested_list: if isinstance(item, (list, tuple)): - flattened_list.extend(item) + flattened_list.extend(item) else: flattened_list.append(item) return flattened_list + def reform_list( - flattened_list: List[Any], - structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]] + flattened_list: List[Any], structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]] ) -> List[List[Any]]: """ Reform the list based on the structure of ``structured_list``. @@ -148,16 +148,15 @@ def restore_from_hashable(x): return [restore_from_hashable(item) for item in x] return x + def tab_data_to_tuple( - X: Union[List[Any], Any], - y: Union[List[Any], Any], - reasoning_result: Optional[Any] = 0 + X: Union[List[Any], Any], y: Union[List[Any], Any], reasoning_result: Optional[Any] = 0 ) -> Tuple[List[List[Any]], List[List[Any]], List[Any]]: - ''' - Convert a tabular data to a tuple by adding a dimension to each element of + """ + Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result. If X is None, return None. - + Parameters ---------- X : Union[List[Any], Any] @@ -166,14 +165,16 @@ def tab_data_to_tuple( The label. reasoning_result : Any, optional The reasoning result, by default 0. - + Returns ------- Tuple[List[List[Any]], List[List[Any]], List[Any]] A tuple of (data, label, reasoning_result). - ''' + """ if X is None: return None if len(X) != len(y): - raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) - return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) \ No newline at end of file + raise ValueError( + "The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)) + ) + return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) diff --git a/docs/conf.py b/docs/conf.py index 60fe957..dcd522d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,27 +5,28 @@ import sys from docutils import nodes from docutils.parsers.rst import roles -import re from sphinx.application import Sphinx + def remove_noqa(app: Sphinx, what: str, name: str, obj, options, lines): new_lines = [] for line in lines: - new_line = re.sub(r'\s*#\s*noqa.*$', '', line) + new_line = re.sub(r"\s*#\s*noqa.*$", "", line) new_lines.append(new_line) lines[:] = new_lines + def colored_text_role(role, rawtext, text, lineno, inliner, options={}, content=[]): node = nodes.inline(rawtext, text, classes=[role]) return [node], [] -roles.register_local_role('green-bold', colored_text_role) -roles.register_local_role('blue-bold', colored_text_role) -roles.register_local_role('yellow-bold', colored_text_role) -roles.register_local_role('green', colored_text_role) -roles.register_local_role('blue', colored_text_role) -roles.register_local_role('yellow', colored_text_role) +roles.register_local_role("green-bold", colored_text_role) +roles.register_local_role("blue-bold", colored_text_role) +roles.register_local_role("yellow-bold", colored_text_role) +roles.register_local_role("green", colored_text_role) +roles.register_local_role("blue", colored_text_role) +roles.register_local_role("yellow", colored_text_role) if "READTHEDOCS" not in os.environ: @@ -45,7 +46,7 @@ author = "Author" extensions = [ "sphinx.ext.intersphinx", "sphinx.ext.autodoc", - 'sphinx.ext.autosummary', + "sphinx.ext.autosummary", "sphinx.ext.mathjax", "sphinx.ext.viewcode", "sphinx_rtd_theme", @@ -95,7 +96,8 @@ texinfo_documents = [ def setup(app): from sphinx.domains.python import PyField from sphinx.util.docfields import Field - app.connect('autodoc-process-docstring', remove_noqa) + + app.connect("autodoc-process-docstring", remove_noqa) app.add_object_type( "confval", "confval", diff --git a/examples/hed/bridge.py b/examples/hed/bridge.py index d81d829..0706786 100644 --- a/examples/hed/bridge.py +++ b/examples/hed/bridge.py @@ -247,7 +247,7 @@ class HedBridge(SimpleBridge): logger="current", ) self.model.load( - load_path=os.path.join(save_dir, f"pretrain_weights.pth") + load_path=os.path.join(save_dir, "pretrain_weights.pth") ) else: self.model.load( diff --git a/examples/hed/datasets/get_dataset.py b/examples/hed/datasets/get_dataset.py index 61c6df5..5d02b0c 100644 --- a/examples/hed/datasets/get_dataset.py +++ b/examples/hed/datasets/get_dataset.py @@ -12,16 +12,20 @@ from torchvision.transforms import transforms CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + def download_and_unzip(url, zip_file_name): try: gdown.download(url, zip_file_name) - with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: + with zipfile.ZipFile(zip_file_name, "r") as zip_ref: zip_ref.extractall(CURRENT_DIR) os.remove(zip_file_name) except Exception as e: if os.path.exists(zip_file_name): os.remove(zip_file_name) - raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hed/datasets' folder") + raise Exception( + f"An error occurred during download or unzip: {e}. Instead, you can download " + + f"the dataset from {url} and unzip it in 'examples/hed/datasets' folder" + ) def get_pretrain_data(labels, image_size=(28, 28, 1)): @@ -33,7 +37,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)): img_path_list = os.listdir(label_path) for img_path in img_path_list: with Image.open(osp.join(label_path, img_path)) as img: - img = img.convert('L') + img = img.convert("L") img = img.resize((image_size[1], image_size[0])) img_array = np.array(img, dtype=np.float32) normalized_img = (img_array - 127) / 128.0 @@ -72,19 +76,19 @@ def split_equation(equations_by_len, prop_train, prop_val): def get_dataset(dataset="mnist", train=True): - data_dir = CURRENT_DIR + '/mnist_images' - + data_dir = CURRENT_DIR + "/mnist_images" + if not os.path.exists(data_dir): print("Dataset not exist, downloading it...") - url = 'https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download' + url = "https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download" download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip")) print("Download and extraction complete.") - + if train: file = os.path.join(data_dir, "expr_train.json") else: file = os.path.join(data_dir, "expr_test.json") - + if dataset == "mnist": file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") elif dataset == "random": @@ -94,7 +98,7 @@ def get_dataset(dataset="mnist", train=True): with open(file, "rb") as f: img_dataset = pickle.load(f) - + X, Y = [], [] if train: positive = img_dataset["train:positive"] @@ -117,4 +121,3 @@ def get_dataset(dataset="mnist", train=True): equations_by_len = divide_equations_by_len(X, Y) return equations_by_len - diff --git a/examples/hed/hed.ipynb b/examples/hed/hed.ipynb index 5d8bb8e..10d2549 100644 --- a/examples/hed/hed.ipynb +++ b/examples/hed/hed.ipynb @@ -86,31 +86,39 @@ "source": [ "true_train_equation = train_data[1]\n", "false_train_equation = train_data[0]\n", - "print(f\"Equations in the dataset is organized by equation length, \" +\n", - " f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\")\n", + "print(\n", + " f\"Equations in the dataset is organized by equation length, \"\n", + " + f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\"\n", + ")\n", "print()\n", "\n", "true_train_equation_with_length_5 = true_train_equation[5]\n", "false_train_equation_with_length_5 = false_train_equation[5]\n", - "print(f\"For each euqation length, there are {len(true_train_equation_with_length_5)} \" +\n", - " f\"true equation and {len(false_train_equation_with_length_5)} false equation \" +\n", - " f\"in the training set\")\n", + "print(\n", + " f\"For each euqation length, there are {len(true_train_equation_with_length_5)} \"\n", + " + f\"true equation and {len(false_train_equation_with_length_5)} false equation \"\n", + " + f\"in the training set\"\n", + ")\n", "\n", "true_val_equation = val_data[1]\n", "false_val_equation = val_data[0]\n", "true_val_equation_with_length_5 = true_val_equation[5]\n", "false_val_equation_with_length_5 = false_val_equation[5]\n", - "print(f\"For each euqation length, there are {len(true_val_equation_with_length_5)} \" +\n", - " f\"true equation and {len(false_val_equation_with_length_5)} false equation \" +\n", - " f\"in the validation set\")\n", + "print(\n", + " f\"For each euqation length, there are {len(true_val_equation_with_length_5)} \"\n", + " + f\"true equation and {len(false_val_equation_with_length_5)} false equation \"\n", + " + f\"in the validation set\"\n", + ")\n", "\n", "true_test_equation = test_data[1]\n", "false_test_equation = test_data[0]\n", "true_test_equation_with_length_5 = true_test_equation[5]\n", "false_test_equation_with_length_5 = false_test_equation[5]\n", - "print(f\"For each euqation length, there are {len(true_test_equation_with_length_5)} \" +\n", - " f\"true equation and {len(false_test_equation_with_length_5)} false equation \" +\n", - " f\"in the test set\")" + "print(\n", + " f\"For each euqation length, there are {len(true_test_equation_with_length_5)} \"\n", + " + f\"true equation and {len(false_test_equation_with_length_5)} false equation \"\n", + " + f\"in the test set\"\n", + ")" ] }, { @@ -199,30 +207,30 @@ "true_train_equation_with_length_8 = true_train_equation[8]\n", "print(f\"First true equation with length 5 in the training dataset:\")\n", "for i, x in enumerate(true_train_equation_with_length_5[0]):\n", - " plt.subplot(1, 5, i+1)\n", - " plt.axis('off') \n", - " plt.imshow(x.squeeze(), cmap='gray')\n", + " plt.subplot(1, 5, i + 1)\n", + " plt.axis(\"off\")\n", + " plt.imshow(x.squeeze(), cmap=\"gray\")\n", "plt.show()\n", "print(f\"First true equation with length 8 in the training dataset:\")\n", "for i, x in enumerate(true_train_equation_with_length_8[0]):\n", - " plt.subplot(1, 8, i+1)\n", - " plt.axis('off') \n", - " plt.imshow(x.squeeze(), cmap='gray')\n", + " plt.subplot(1, 8, i + 1)\n", + " plt.axis(\"off\")\n", + " plt.imshow(x.squeeze(), cmap=\"gray\")\n", "plt.show()\n", "\n", "false_train_equation_with_length_5 = false_train_equation[5]\n", "false_train_equation_with_length_8 = false_train_equation[8]\n", "print(f\"First false equation with length 5 in the training dataset:\")\n", "for i, x in enumerate(false_train_equation_with_length_5[0]):\n", - " plt.subplot(1, 5, i+1)\n", - " plt.axis('off') \n", - " plt.imshow(x.squeeze(), cmap='gray')\n", + " plt.subplot(1, 5, i + 1)\n", + " plt.axis(\"off\")\n", + " plt.imshow(x.squeeze(), cmap=\"gray\")\n", "plt.show()\n", "print(f\"First false equation with length 8 in the training dataset:\")\n", "for i, x in enumerate(false_train_equation_with_length_8[0]):\n", - " plt.subplot(1, 8, i+1)\n", - " plt.axis('off') \n", - " plt.imshow(x.squeeze(), cmap='gray')\n", + " plt.subplot(1, 8, i + 1)\n", + " plt.axis(\"off\")\n", + " plt.imshow(x.squeeze(), cmap=\"gray\")\n", "plt.show()" ] }, diff --git a/examples/hed/main.py b/examples/hed/main.py index 984ff5c..d66d197 100644 --- a/examples/hed/main.py +++ b/examples/hed/main.py @@ -46,20 +46,20 @@ def main(): ) args = parser.parse_args() - + # Build logger print_log("Abductive Learning on the HED example.", logger="current") ### Working with Data print_log("Working with Data.", logger="current") - + total_train_data = get_dataset(train=True) train_data, val_data = split_equation(total_train_data, 3, 1) test_data = get_dataset(train=False) ### Building the Learning Part print_log("Building the Learning Part.", logger="current") - + # Build necessary components for BasicNN cls = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() @@ -83,7 +83,7 @@ def main(): ### Building the Reasoning Part print_log("Building the Reasoning Part.", logger="current") - + # Build knowledge base kb = HedKB() diff --git a/examples/hed/reasoning/__init__.py b/examples/hed/reasoning/__init__.py index 52b7e39..6fdae78 100644 --- a/examples/hed/reasoning/__init__.py +++ b/examples/hed/reasoning/__init__.py @@ -1,3 +1,3 @@ from .reasoning import HedKB, HedReasoner -__all__ = ["HedKB", "HedReasoner"] \ No newline at end of file +__all__ = ["HedKB", "HedReasoner"] diff --git a/examples/hed/reasoning/reasoning.py b/examples/hed/reasoning/reasoning.py index 43cb9d7..02910ef 100644 --- a/examples/hed/reasoning/reasoning.py +++ b/examples/hed/reasoning/reasoning.py @@ -8,8 +8,11 @@ from abl.utils import reform_list CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + class HedKB(PrologKB): - def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): + def __init__( + self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl") + ): pl_file = pl_file.replace("\\", "/") super().__init__(pseudo_label_list, pl_file) self.learned_rules = {} @@ -34,7 +37,7 @@ class HedReasoner(Reasoner): data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx ) return candidate - + def zoopt_budget(self, symbol_num): return 200 @@ -53,7 +56,7 @@ class HedReasoner(Reasoner): max_candidate_idxs = [] found = False for idx in range(-1, len(data_example.pred_idx)): - if (not idx in idxs) and (idx >= 0): + if (idx not in idxs) and (idx >= 0): idxs.append(idx) candidates, _ = self.revise_at_idx(data_example[idxs]) if len(candidates) == 0: @@ -96,4 +99,4 @@ class HedReasoner(Reasoner): return abduced_pseudo_label def abduce_rules(self, pred_res): - return self.kb.abduce_rules(pred_res) \ No newline at end of file + return self.kb.abduce_rules(pred_res) diff --git a/examples/hwf/datasets/__init__.py b/examples/hwf/datasets/__init__.py index a2e06bd..5d3ddab 100644 --- a/examples/hwf/datasets/__init__.py +++ b/examples/hwf/datasets/__init__.py @@ -1,3 +1,3 @@ from .get_dataset import get_dataset -__all__ = ["get_dataset"] \ No newline at end of file +__all__ = ["get_dataset"] diff --git a/examples/hwf/datasets/get_dataset.py b/examples/hwf/datasets/get_dataset.py index d89b1e3..f8ce374 100644 --- a/examples/hwf/datasets/get_dataset.py +++ b/examples/hwf/datasets/get_dataset.py @@ -10,26 +10,31 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]) + def download_and_unzip(url, zip_file_name): try: gdown.download(url, zip_file_name) - with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: + with zipfile.ZipFile(zip_file_name, "r") as zip_ref: zip_ref.extractall(CURRENT_DIR) os.remove(zip_file_name) except Exception as e: if os.path.exists(zip_file_name): os.remove(zip_file_name) - raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hwf/datasets' folder") + raise Exception( + f"An error occurred during download or unzip: {e}. Instead, you can download " + + f"the dataset from {url} and unzip it in 'examples/hwf/datasets' folder" + ) + def get_dataset(train=True, get_pseudo_label=False): - data_dir = CURRENT_DIR + '/data' - + data_dir = CURRENT_DIR + "/data" + if not os.path.exists(data_dir): print("Dataset not exist, downloading it...") - url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download' + url = "https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download" download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip")) print("Download and extraction complete.") - + if train: file = os.path.join(data_dir, "expr_train.json") else: @@ -59,4 +64,4 @@ def get_dataset(train=True, get_pseudo_label=False): pseudo_label.append(imgs_pseudo_label) Y.append(data[idx]["res"]) - return X, pseudo_label, Y + return X, pseudo_label, Y diff --git a/examples/hwf/hwf.ipynb b/examples/hwf/hwf.ipynb index d4a5f5a..db140f2 100644 --- a/examples/hwf/hwf.ipynb +++ b/examples/hwf/hwf.ipynb @@ -72,21 +72,28 @@ "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", "print()\n", "train_X, train_gt_pseudo_label, train_Y = train_data\n", - "print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n", - " f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n", + "print(\n", + " f\"Length of X, gt_pseudo_label, Y in train_data: \"\n", + " + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n", + ")\n", "test_X, test_gt_pseudo_label, test_Y = test_data\n", - "print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n", - " f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n", + "print(\n", + " f\"Length of X, gt_pseudo_label, Y in test_data: \"\n", + " + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n", + ")\n", "print()\n", "\n", "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", - "print(f\"X is a {type(train_X).__name__}, \" +\n", - " f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\")\n", - "print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n", - " f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n", - " f\"of {type(gt_pseudo_label_0[0]).__name__}.\")\n", - "print(f\"Y is a {type(train_Y).__name__}, \" +\n", - " f\"with each element being a {type(Y_0).__name__}.\")" + "print(\n", + " f\"X is a {type(train_X).__name__}, \"\n", + " + f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\"\n", + ")\n", + "print(\n", + " f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n", + " + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n", + " + f\"of {type(gt_pseudo_label_0[0]).__name__}.\"\n", + ")\n", + "print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")" ] }, { @@ -105,21 +112,25 @@ "X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]\n", "print(f\"X in the 1001st data example (a list of images):\")\n", "for i, x in enumerate(X_1000):\n", - " plt.subplot(1, len(X_1000), i+1)\n", - " plt.axis('off') \n", - " plt.imshow(x.squeeze(), cmap='gray')\n", + " plt.subplot(1, len(X_1000), i + 1)\n", + " plt.axis(\"off\")\n", + " plt.imshow(x.squeeze(), cmap=\"gray\")\n", "plt.show()\n", - "print(f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\")\n", + "print(\n", + " f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\"\n", + ")\n", "print(f\"Y in the 1001st data example (the computed result): {Y_1000}\")\n", "print()\n", "X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000]\n", "print(f\"X in the 3001st data example (a list of images):\")\n", "for i, x in enumerate(X_3000):\n", - " plt.subplot(1, len(X_3000), i+1)\n", - " plt.axis('off') \n", - " plt.imshow(x.squeeze(), cmap='gray')\n", + " plt.subplot(1, len(X_3000), i + 1)\n", + " plt.axis(\"off\")\n", + " plt.imshow(x.squeeze(), cmap=\"gray\")\n", "plt.show()\n", - "print(f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\")\n", + "print(\n", + " f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\"\n", + ")\n", "print(f\"Y in the 3001st data example (the computed result): {Y_3000}\")" ] }, @@ -184,11 +195,15 @@ "source": [ "data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]\n", "pred_idx = base_model.predict(X=data_instances)\n", - "print(f\"Predicted class index for a batch of 32 instances: \" +\n", - " f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n", + "print(\n", + " f\"Predicted class index for a batch of 32 instances: \"\n", + " + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n", + ")\n", "pred_prob = base_model.predict_proba(X=data_instances)\n", - "print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n", - " f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")" + "print(\n", + " f\"Predicted class probabilities for a batch of 32 instances: \"\n", + " + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n", + ")" ] }, { @@ -221,6 +236,7 @@ "outputs": [], "source": [ "from abl.data.structures import ListData\n", + "\n", "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", "data_examples = ListData()\n", "# We use the first 1001st and 3001st data examples in the training set as an illustration\n", @@ -229,15 +245,19 @@ "data_examples.Y = [Y_1000, Y_3000]\n", "\n", "# Perform prediction on the two data examples\n", - "# Remind that, in the 1001st data example, the length of the formula is 3, \n", + "# Remind that, in the 1001st data example, the length of the formula is 3,\n", "# while in the 3001st data example, the length of the formula is 5.\n", - "pred_label, pred_prob = model.predict(data_examples)['label'], model.predict(data_examples)['prob']\n", - "print(f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\" +\n", - " f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"+\n", - " f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\")\n", - "print(f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n", - " f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \" +\n", - " f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\")" + "pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n", + "print(\n", + " f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\"\n", + " + f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"\n", + " + f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\"\n", + ")\n", + "print(\n", + " f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n", + " f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \"\n", + " + f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\"\n", + ")" ] }, { @@ -261,7 +281,9 @@ "outputs": [], "source": [ "class HwfKB(KBBase):\n", - " def __init__(self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]):\n", + " def __init__(\n", + " self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]\n", + " ):\n", " super().__init__(pseudo_label_list)\n", "\n", " def _valid_candidate(self, formula):\n", @@ -273,13 +295,14 @@ " if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"*\", \"/\"]:\n", " return False\n", " return True\n", - " \n", + "\n", " # Implement the deduction function\n", " def logic_forward(self, formula):\n", " if not self._valid_candidate(formula):\n", " return np.inf\n", " return eval(\"\".join(formula))\n", "\n", + "\n", "kb = HwfKB()" ] }, diff --git a/examples/hwf/main.py b/examples/hwf/main.py index 83c60e9..f8e10d9 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -113,19 +113,19 @@ def main(): ) args = parser.parse_args() - + # Build logger print_log("Abductive Learning on the HWF example.", logger="current") ### Working with Data print_log("Working with Data.", logger="current") - + train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) ### Building the Learning Part print_log("Building the Learning Part.", logger="current") - + # Build necessary components for BasicNN cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) loss_fn = nn.CrossEntropyLoss() @@ -148,7 +148,7 @@ def main(): ### Building the Reasoning Part print_log("Building the Reasoning Part.", logger="current") - + # Build knowledge base if args.ground: kb = HwfGroundKB() diff --git a/examples/mnist_add/datasets/__init__.py b/examples/mnist_add/datasets/__init__.py index a2e06bd..5d3ddab 100644 --- a/examples/mnist_add/datasets/__init__.py +++ b/examples/mnist_add/datasets/__init__.py @@ -1,3 +1,3 @@ from .get_dataset import get_dataset -__all__ = ["get_dataset"] \ No newline at end of file +__all__ = ["get_dataset"] diff --git a/examples/mnist_add/datasets/get_dataset.py b/examples/mnist_add/datasets/get_dataset.py index 53423da..bfa7b93 100644 --- a/examples/mnist_add/datasets/get_dataset.py +++ b/examples/mnist_add/datasets/get_dataset.py @@ -5,6 +5,7 @@ from torchvision.transforms import transforms CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + def get_dataset(train=True, get_pseudo_label=True): transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index 0616fc5..025f9d4 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -88,7 +88,7 @@ def main(): ### Building the Learning Part print_log("Building the Learning Part.", logger="current") - + # Build necessary components for BasicNN cls = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) @@ -119,7 +119,7 @@ def main(): ### Building the Reasoning Part print_log("Building the Reasoning Part.", logger="current") - + # Build knowledge base if args.prolog: kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") diff --git a/examples/mnist_add/main_with_model_converter.py b/examples/mnist_add/main_with_model_converter.py index 2fc582d..8f3cde4 100644 --- a/examples/mnist_add/main_with_model_converter.py +++ b/examples/mnist_add/main_with_model_converter.py @@ -89,23 +89,37 @@ def main(): ### Building the Learning Part print_log("Building the Learning Part.", logger="current") - + # Build necessary components for BasicNN - model=FixMatch(network=LeNet5(), threshold=0.95,lambda_u=1.0,mu=7,T=0.5,epoch=1,num_it_epoch=2**20,num_it_total=2**20,device='cuda') + model = FixMatch( + network=LeNet5(), + threshold=0.95, + lambda_u=1.0, + mu=7, + T=0.5, + epoch=1, + num_it_epoch=2**20, + num_it_total=2**20, + device="cuda", + ) loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) optimizer_dict = dict(optimizer=RMSprop, lr=0.0003, alpha=0.9) - scheduler_dict = dict(scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200) + scheduler_dict = dict( + scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200 + ) converter = ModelConverter() - base_model = converter.convert_lambdalearn_to_basicnn(model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict) + base_model = converter.convert_lambdalearn_to_basicnn( + model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict + ) # Build ABLModel model = ABLModel(base_model) ### Building the Reasoning Part print_log("Building the Reasoning Part.", logger="current") - + # Build knowledge base if args.prolog: kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") diff --git a/examples/mnist_add/mnist_add.ipynb b/examples/mnist_add/mnist_add.ipynb index af2de04..b5e323a 100644 --- a/examples/mnist_add/mnist_add.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -87,22 +87,29 @@ "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", "print()\n", "train_X, train_gt_pseudo_label, train_Y = train_data\n", - "print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n", - " f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n", + "print(\n", + " f\"Length of X, gt_pseudo_label, Y in train_data: \"\n", + " + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n", + ")\n", "test_X, test_gt_pseudo_label, test_Y = test_data\n", - "print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n", - " f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n", + "print(\n", + " f\"Length of X, gt_pseudo_label, Y in test_data: \"\n", + " + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n", + ")\n", "print()\n", "\n", "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", - "print(f\"X is a {type(train_X).__name__}, \" +\n", - " f\"with each element being a {type(X_0).__name__} \" +\n", - " f\"of {len(X_0)} {type(X_0[0]).__name__}.\")\n", - "print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n", - " f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n", - " f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\")\n", - "print(f\"Y is a {type(train_Y).__name__}, \" +\n", - " f\"with each element being a {type(Y_0).__name__}.\")" + "print(\n", + " f\"X is a {type(train_X).__name__}, \"\n", + " + f\"with each element being a {type(X_0).__name__} \"\n", + " + f\"of {len(X_0)} {type(X_0[0]).__name__}.\"\n", + ")\n", + "print(\n", + " f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n", + " + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n", + " + f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\"\n", + ")\n", + "print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")" ] }, { @@ -146,14 +153,16 @@ "source": [ "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", "print(f\"X in the first data example (a list of two images):\")\n", - "plt.subplot(1,2,1)\n", - "plt.axis('off') \n", - "plt.imshow(X_0[0].squeeze(), cmap='gray')\n", - "plt.subplot(1,2,2)\n", - "plt.axis('off') \n", - "plt.imshow(X_0[1].squeeze(), cmap='gray')\n", + "plt.subplot(1, 2, 1)\n", + "plt.axis(\"off\")\n", + "plt.imshow(X_0[0].squeeze(), cmap=\"gray\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.axis(\"off\")\n", + "plt.imshow(X_0[1].squeeze(), cmap=\"gray\")\n", "plt.show()\n", - "print(f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\")\n", + "print(\n", + " f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\"\n", + ")\n", "print(f\"Y in the first data example (their sum result): {Y_0}\")" ] }, @@ -219,11 +228,15 @@ "source": [ "data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]\n", "pred_idx = base_model.predict(X=data_instances)\n", - "print(f\"Predicted class index for a batch of 32 instances: \" +\n", - " f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n", + "print(\n", + " f\"Predicted class index for a batch of 32 instances: \"\n", + " + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n", + ")\n", "pred_prob = base_model.predict_proba(X=data_instances)\n", - "print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n", - " f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")" + "print(\n", + " f\"Predicted class probabilities for a batch of 32 instances: \"\n", + " + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n", + ")" ] }, { @@ -268,6 +281,7 @@ ], "source": [ "from abl.data.structures import ListData\n", + "\n", "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", "data_examples = ListData()\n", "# We use the first 100 data examples in the training set as an illustration\n", @@ -276,13 +290,17 @@ "data_examples.Y = train_Y[:100]\n", "\n", "# Perform prediction on the 100 data examples\n", - "pred_label, pred_prob = model.predict(data_examples)['label'], model.predict(data_examples)['prob']\n", - "print(f\"Predicted class labels for the 100 data examples: \\n\" +\n", - " f\"a list of length {len(pred_label)}, and each element is \" +\n", - " f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\")\n", - "print(f\"Predicted class probabilities for the 100 data examples: \\n\" +\n", - " f\"a list of length {len(pred_prob)}, and each element is \" +\n", - " f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\")" + "pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n", + "print(\n", + " f\"Predicted class labels for the 100 data examples: \\n\"\n", + " + f\"a list of length {len(pred_label)}, and each element is \"\n", + " + f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\"\n", + ")\n", + "print(\n", + " f\"Predicted class probabilities for the 100 data examples: \\n\"\n", + " + f\"a list of length {len(pred_prob)}, and each element is \"\n", + " + f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\"\n", + ")" ] }, { @@ -313,6 +331,7 @@ " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", + "\n", "kb = AddKB()" ] }, diff --git a/examples/zoo/get_dataset.py b/examples/zoo/get_dataset.py index 600b338..18c10fd 100644 --- a/examples/zoo/get_dataset.py +++ b/examples/zoo/get_dataset.py @@ -4,27 +4,30 @@ import openml # Function to load and preprocess the dataset def load_and_preprocess_dataset(dataset_id): - dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False) + dataset = openml.datasets.get_dataset( + dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False + ) X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute) # Convert data types - for col in X.select_dtypes(include='bool').columns: + for col in X.select_dtypes(include="bool").columns: X[col] = X[col].astype(int) y = y.cat.codes.astype(int) X, y = X.to_numpy(), y.to_numpy() return X, y + # Function to split data (one shot) -def split_dataset(X, y, test_size = 0.3): +def split_dataset(X, y, test_size=0.3): # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1) label_indices, unlabel_indices, test_indices = [], [], [] for class_label in np.unique(y): idxs = np.where(y == class_label)[0] np.random.shuffle(idxs) - n_train_unlabel = int((1-test_size)*(len(idxs)-1)) + n_train_unlabel = int((1 - test_size) * (len(idxs) - 1)) label_indices.append(idxs[0]) - unlabel_indices.extend(idxs[1:1+n_train_unlabel]) - test_indices.extend(idxs[1+n_train_unlabel:]) + unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel]) + test_indices.extend(idxs[1 + n_train_unlabel :]) X_label, y_label = X[label_indices], y[label_indices] X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices] X_test, y_test = X[test_indices], y[test_indices] - return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test \ No newline at end of file + return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test diff --git a/examples/zoo/kb.py b/examples/zoo/kb.py index 4757184..86b1886 100644 --- a/examples/zoo/kb.py +++ b/examples/zoo/kb.py @@ -11,18 +11,27 @@ class ZooKB(KBBase): self.solver = Solver() # Load information of Zoo dataset - dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False) - X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) + dataset = openml.datasets.get_dataset( + dataset_id=62, + download_data=False, + download_qualities=False, + download_features_meta_data=False, + ) + X, y, categorical_indicator, attribute_names = dataset.get_data( + target=dataset.default_target_attribute + ) self.attribute_names = attribute_names self.target_names = y.cat.categories.tolist() # print("Attribute names are: ", self.attribute_names) # print("Target names are: ", self.target_names) - # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] - # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] + # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] # noqa: E501 + # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] # noqa: E501 # Define variables - for name in self.attribute_names+self.target_names: - exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules + for name in self.attribute_names + self.target_names: + exec( + f"globals()['{name}'] = Int('{name}')" + ) # or use dict to create var and modify rules # Define rules rules = [ Implies(milk == 1, mammal == 1), @@ -54,11 +63,13 @@ class ZooKB(KBBase): Implies(insect == 1, eggs == 1), Implies(insect == 1, Not(backbone == 1)), Implies(insect == 1, legs == 6), - Implies(invertebrate == 1, Not(backbone == 1)) + Implies(invertebrate == 1, Not(backbone == 1)), ] # Define weights and sum of violated weights self.weights = {rule: 1 for rule in rules} - self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights]) + self.total_violation_weight = Sum( + [If(Not(rule), self.weights[rule], 0) for rule in self.weights] + ) def logic_forward(self, pseudo_label, data_point): attribute_names, target_names = self.attribute_names, self.target_names @@ -69,7 +80,7 @@ class ZooKB(KBBase): self.solver.reset() for name, value in zip(attribute_names, data_point): solver.add(eval(f"{name} == {value}")) - for cate, name in zip(self.pseudo_label_list,target_names): + for cate, name in zip(self.pseudo_label_list, target_names): value = 1 if (cate == pseudo_label) else 0 solver.add(eval(f"{name} == {value}")) diff --git a/examples/zoo/main.py b/examples/zoo/main.py index b4da2d1..093976d 100644 --- a/examples/zoo/main.py +++ b/examples/zoo/main.py @@ -14,7 +14,6 @@ from get_dataset import load_and_preprocess_dataset, split_dataset from kb import ZooKB - def consitency(data_example, candidates, candidate_idxs, reasoning_results): pred_prob = data_example.pred_prob model_scores = confidence_dist(pred_prob, candidate_idxs) @@ -22,19 +21,20 @@ def consitency(data_example, candidates, candidate_idxs, reasoning_results): scores = model_scores + rule_scores return scores + def main(): parser = argparse.ArgumentParser(description="Zoo example") parser.add_argument( "--loops", type=int, default=3, help="number of loop iterations (default : 3)" ) args = parser.parse_args() - + # Build logger print_log("Abductive Learning on the ZOO example.", logger="current") ### Working with Data print_log("Working with Data.", logger="current") - + X, y = load_and_preprocess_dataset(dataset_id=62) X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) label_data = tab_data_to_tuple(X_label, y_label) @@ -43,7 +43,7 @@ def main(): ### Building the Learning Part print_log("Building the Learning Part.", logger="current") - + # Build base model base_model = RandomForestClassifier() @@ -52,32 +52,38 @@ def main(): ### Building the Reasoning Part print_log("Building the Reasoning Part.", logger="current") - + # Build knowledge base kb = ZooKB() - + # Create reasoner reasoner = Reasoner(kb, dist_func=consitency) ### Building Evaluation Metrics print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] - + ### Bridging learning and reasoning print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) - + # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") - + # Performing training and testing print_log("------- Use labeled data to pretrain the model -----------", logger="current") base_model.fit(X_label, y_label) print_log("------- Test the initial model -----------", logger="current") bridge.test(test_data) print_log("------- Use ABL to train the model -----------", logger="current") - bridge.train(train_data=train_data, label_data=label_data, loops=args.loops, segment_size=len(X_unlabel), save_dir=weights_dir) + bridge.train( + train_data=train_data, + label_data=label_data, + loops=args.loops, + segment_size=len(X_unlabel), + save_dir=weights_dir, + ) print_log("------- Test the final model -----------", logger="current") bridge.test(test_data) diff --git a/examples/zoo/zoo.ipynb b/examples/zoo/zoo.ipynb index bf21f43..b7effea 100644 --- a/examples/zoo/zoo.ipynb +++ b/examples/zoo/zoo.ipynb @@ -106,9 +106,9 @@ "metadata": {}, "outputs": [], "source": [ - "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", - "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", - "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" + "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result=0)\n", + "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result=0)\n", + "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result=0)" ] }, { @@ -240,6 +240,7 @@ " scores = model_scores + rule_scores\n", " return scores\n", "\n", + "\n", "reasoner = Reasoner(kb, dist_func=consitency)" ] }, @@ -338,7 +339,13 @@ "print_log(\"------- Test the initial model -----------\", logger=\"current\")\n", "bridge.test(test_data)\n", "print_log(\"------- Use ABL to train the model -----------\", logger=\"current\")\n", - "bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n", + "bridge.train(\n", + " train_data=train_data,\n", + " label_data=label_data,\n", + " loops=3,\n", + " segment_size=len(X_unlabel),\n", + " save_dir=weights_dir,\n", + ")\n", "print_log(\"------- Test the final model -----------\", logger=\"current\")\n", "bridge.test(test_data)" ] diff --git a/tests/conftest.py b/tests/conftest.py index f96d9a4..2b71466 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -200,7 +200,7 @@ def kb_add_ground(): @pytest.fixture def kb_add_prolog(): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl") return kb @@ -218,7 +218,7 @@ def kb_hwf2(): @pytest.fixture def kb_hed(): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return kb = HedKB( pseudo_label_list=[1, 0, "+", "="], diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index 5c602b4..d3a08b6 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -28,9 +28,13 @@ class TestKBBase(object): assert result == ([[0, 2], [1, 1], [2, 0]], [2, 2, 2]) def test_abduce_candidates(self, kb_add): - result = kb_add.abduce_candidates([0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0) + result = kb_add.abduce_candidates( + [0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0 + ) assert result == ([[0, 1]], [1]) - result = kb_add.abduce_candidates([1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0) + result = kb_add.abduce_candidates( + [1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0 + ) assert result == ([[1, 0]], [1]) @@ -53,19 +57,19 @@ class TestGroundKB(object): class TestPrologKB(object): def test_init_pl1(self, kb_add_prolog): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return assert kb_add_prolog.pseudo_label_list == list(range(10)) assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl" def test_init_pl2(self, kb_hed): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return assert kb_hed.pseudo_label_list == [1, 0, "+", "="] assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl" def test_prolog_file_not_exist(self): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return pseudo_label_list = [1, 2] non_existing_file = "path/to/non_existing_file.pl" @@ -74,13 +78,13 @@ class TestPrologKB(object): assert non_existing_file in str(excinfo.value) def test_logic_forward_pl1(self, kb_add_prolog): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return result = kb_add_prolog.logic_forward([1, 2]) assert result == 3 def test_logic_forward_pl2(self, kb_hed): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return consist_exs = [ [1, 1, "+", 0, "=", 1, 1], @@ -97,7 +101,7 @@ class TestPrologKB(object): assert kb_hed.logic_forward(inconsist_exs) is False def test_revise_at_idx(self, kb_add_prolog): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return result = kb_add_prolog.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0]) assert result == ([[0, 2]], [2]) @@ -113,34 +117,34 @@ class TestReaonser(object): assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in str( excinfo.value ) - + def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results): cost_list = [np.random.rand() for _ in candidates] return cost_list - + def test_user_defined_dist_func(self, kb_add): reasoner = Reasoner(kb_add, self.random_dist) assert reasoner.dist_func == self.random_dist - + def invalid_dist1(self, candidates): cost_list = np.array([np.random.rand() for _ in candidates]) return cost_list - + def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results): cost_list = np.array([np.random.rand() for _ in candidates]) return np.append(cost_list, np.random.rand()) - + def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add): with pytest.raises(ValueError) as excinfo: Reasoner(kb_add, self.invalid_dist1) - assert 'User-defined dist_func must have exactly four parameters' in str( - excinfo.value - ) + assert "User-defined dist_func must have exactly four parameters" in str(excinfo.value) with pytest.raises(ValueError) as excinfo: reasoner = Reasoner(kb_add, self.invalid_dist2) reasoner.batch_abduce(data_examples_add) - assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str( - excinfo.value + assert ( + "The length of the array returned by dist_func must be " + + "equal to the number of candidates" + in str(excinfo.value) ) @@ -186,7 +190,7 @@ class TestBatchAbduce(object): ] def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1) @@ -208,7 +212,7 @@ class TestBatchAbduce(object): ] def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": return reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)