Merge Dev into Main: ABL-Package Overhaul with Enhanced Documentation and Examples

- Completed a comprehensive rewrite of the ABL-Package, making it easy to use.
- Incorporated detailed docstrings for each function, class, and module to provide clear, in-line documentation and facilitate code understanding.
- Expanded the documentation to ensure thorough understanding and ease of use.
- Provided examples demonstrating the ABL-Package's capabilities, common use cases, and benchmarks.
- Added unit tests covering the updated ABL-Package.
- Thanks to Wen-Chao Hu and En-Hao Gao for their dedication to improving code quality and writing tests, examples, and documentation.
@@ -0,0 +1,14 @@
[report]
show_missing = True

[run]
disable_warnings = include-ignored
include = */abl/*
omit =
    */abl/__init__.py
    abl/bridge/__init__.py
    abl/dataset/__init__.py
    abl/data/__init__.py
    abl/learning/__init__.py
    abl/reasoning/__init__.py
    abl/utils/__init__.py
@@ -0,0 +1,62 @@
name: ABL-Package-CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ['3.7', '3.11']

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Cache Python virtual environment
        uses: actions/cache@v2
        with:
          path: |
            ~/.venv
            !~/.venv/*/lib/python*/no-global-site-packages.txt
          key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-python-${{ matrix.python-version }}-
      - name: Display Python version
        run: python -c "import sys; print(sys.version)"
      - name: Install SWI-Prolog on Ubuntu
        if: matrix.os == 'ubuntu-latest'
        run: sudo apt-get install swi-prolog
      - name: Install SWI-Prolog on Windows
        if: matrix.os == 'windows-latest'
        run: choco install swi-prolog
      - name: Install SWI-Prolog on macOS
        if: matrix.os == 'macos-latest'
        run: brew install swi-prolog
      - name: Install package dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pytest-cov
      - name: Install
        run: pip install -v -e .
      - name: Run tests
        run: |
          pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests
      - name: Publish code coverage
        uses: codecov/codecov-action@v1
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: ./coverage.xml
@@ -0,0 +1,24 @@
name: flake8 Lint

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  flake8-lint:
    runs-on: ubuntu-latest
    name: Lint
    steps:
      - name: Check out source repository
        uses: actions/checkout@v3
      - name: Set up Python environment
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"
      - name: flake8 Lint
        uses: py-actions/flake8@v2
        with:
          max-line-length: "100"
          args: --ignore=E203,W503,F821,E266
@@ -0,0 +1,14 @@
*.pyc
examples/**/*.png
*.pk
*.pth
*.json
*.ckpt
results
raw/
abl.egg-info/
examples/**/*.jpg
.idea/
build/
docs/API/generated/
.history
| @@ -0,0 +1,21 @@ | |||
| MIT License | |||
| Copyright (c) 2024 LAMDA | |||
| Permission is hereby granted, free of charge, to any person obtaining a copy | |||
| of this software and associated documentation files (the "Software"), to deal | |||
| in the Software without restriction, including without limitation the rights | |||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||
| copies of the Software, and to permit persons to whom the Software is | |||
| furnished to do so, subject to the following conditions: | |||
| The above copyright notice and this permission notice shall be included in all | |||
| copies or substantial portions of the Software. | |||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||
| SOFTWARE. | |||
@@ -1,24 +0,0 @@
# ABL Package

This is the code repository of the Abductive Learning Package.

## Environment dependency

...

## Example

share_example.py and nonshare_example.py are examples of grounded abductive learning.

```bash
python share_example.py
```

## Authors

- [Yu-Xuan Huang](http://www.lamda.nju.edu.cn/huangyx/) (Nanjing University)
- [](http://www.lamda.nju.edu.cn//) (Nanjing University)

## NOTICE

This code can only be used for academic purposes. For other purposes, please contact the LAMDA Group (www.lamda.nju.edu.cn).
@@ -0,0 +1,173 @@
[](https://github.com/AbductiveLearning/ABL-Package/blob/Dev/LICENSE)
[](https://github.com/AbductiveLearning/ABL-Package/actions/workflows/lint.yaml)
[](https://github.com/psf/black)
[](https://github.com/AbductiveLearning/ABL-Package/actions/workflows/build-and-test.yaml)

# ABL-Package

**ABL-Package** is an open-source library for **Abductive Learning (ABL)**.

ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available.

Key features of ABL-Package:

- **Great Flexibility**: Adaptable to various machine learning modules and logical reasoning components.
- **User-Friendly**: Provide data, a model, and a KB, and get started with just a few lines of code.
- **High Performance**: Optimized for high accuracy and fast training speed.

ABL-Package encapsulates advanced ABL techniques, providing users with an efficient and convenient package to develop dual-driven ABL systems, which leverage the power of both data and knowledge.

To learn how to use it, please refer to the [documentation](https://www.lamda.nju.edu.cn/abl_test/docs/build/html/index.html).

## Installation

ABL is distributed on [PyPI](https://pypi.org/) and can be installed with ``pip``:

```bash
# (TODO)
$ pip install abl
```

For testing purposes, you can install it using:

```bash
$ pip install -i https://test.pypi.org/simple/ --extra-index-url https://mirrors.nju.edu.cn/pypi/web/simple/ abl
```

Alternatively, to install ABL from source, run the following commands sequentially in your terminal/command line:

```bash
$ git clone https://github.com/AbductiveLearning/ABL-Package.git
$ cd ABL-Package
$ pip install -v -e .
```

(Optional) If the use of a [Prolog-based knowledge base](https://www.lamda.nju.edu.cn/abl_test/docs/build/html/Intro/Reasoning.html#prolog) is necessary, the installation of [SWI-Prolog](https://www.swi-prolog.org/) is also required.

For Linux users:

```bash
$ sudo apt-get install swi-prolog
```

For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md).

## Examples

We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file.

+ [MNIST Addition](https://github.com/AbductiveLearning/ABL-Package/blob/Dev/examples/mnist_add)
+ [Handwritten Formula](https://github.com/AbductiveLearning/ABL-Package/blob/Dev/examples/hwf)
+ [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABL-Package/tree/Dev/examples/hed)
+ [Zoo](https://github.com/AbductiveLearning/ABL-Package/tree/Dev/examples/zoo)

## Quick Start

We use the MNIST Addition task as a quick-start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum.

### Working with Data

ABL-Package requires data in the format of `(X, gt_pseudo_label, Y)`, where `X` is a list of input examples containing instances, `gt_pseudo_label` is the ground-truth label of each example in `X`, and `Y` is the ground-truth reasoning result of each example in `X`. Note that `gt_pseudo_label` is only used to evaluate the machine learning model's performance, not to train it.

In the MNIST Addition task, the data loading looks like:

```python
# The 'datasets' module below is located in 'examples/mnist_add/'
from datasets import get_dataset

# train_data and test_data are tuples in the format of (X, gt_pseudo_label, Y)
train_data = get_dataset(train=True)
test_data = get_dataset(train=False)
```
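For concreteness, the following sketch shows the shape such a tuple takes in MNIST Addition; the contents are illustrative toy values, not drawn from the actual dataset:

```python
# Illustrative only: each example in X is a pair of digit images,
# gt_pseudo_label holds the two digits, and Y holds their sum.
# X               = [[image_of_3, image_of_5], [image_of_1, image_of_7], ...]
# gt_pseudo_label = [[3, 5],                   [1, 7],                   ...]
# Y               = [8,                        8,                        ...]
```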
### Building the Learning Part

The learning part is constructed by first defining a base model for machine learning. ABL-Package offers considerable flexibility, supporting any base model that conforms to the scikit-learn style (i.e., it implements the `fit` and `predict` methods) or any PyTorch-based neural network (i.e., it has a defined architecture and implements the `forward` method). In this example, we build a simple LeNet5 network as the base model.

```python
# The 'models' module below is located in 'examples/mnist_add/'
from models.nn import LeNet5

cls = LeNet5(num_classes=10)
```

To facilitate uniform processing, ABL-Package provides the `BasicNN` class to convert a PyTorch-based neural network into a format compatible with scikit-learn models. To construct a `BasicNN` instance, aside from the network itself, we also need to define a loss function, an optimizer, and the computing device.

```python
import torch
from abl.learning import BasicNN

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = BasicNN(model=cls, loss_fn=loss_fn, optimizer=optimizer, device=device)
```
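Once wrapped, `base_model` behaves like a scikit-learn estimator on instance-level data. A minimal sketch of the intended usage, assuming `BasicNN` exposes `fit`/`predict` over arrays of instances as described above (the exact signatures may differ; consult the API documentation):

```python
import numpy as np

dummy_X = np.random.rand(8, 1, 28, 28).astype(np.float32)  # 8 MNIST-like images
dummy_y = np.random.randint(0, 10, size=8)
# Hypothetical calls following the scikit-learn convention:
# base_model.fit(dummy_X, dummy_y)
# pred = base_model.predict(dummy_X)
```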
The base model built above is trained to make predictions on instance-level data (e.g., a single image), while ABL deals with example-level data. To bridge this gap, we wrap `base_model` into an instance of `ABLModel`. This class serves as a unified wrapper for base models, enabling the learning part to train, test, and predict on example-level data (e.g., the images that comprise an equation).

```python
from abl.learning import ABLModel

model = ABLModel(base_model)
```

### Building the Reasoning Part

To build the reasoning part, we first define a knowledge base by creating a subclass of `KBBase`. In the subclass, we initialize the `pseudo_label_list` parameter and override the `logic_forward` method, which specifies how to perform (deductive) reasoning, mapping the pseudo-labels of an example to its corresponding reasoning result. Specifically for the MNIST Addition task, this `logic_forward` method is tailored to execute the sum operation.

```python
from abl.reasoning import KBBase

class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        return sum(nums)

kb = AddKB()
```
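As a quick sanity check of the deductive direction, `logic_forward` simply sums the pseudo-labels of an example:

```python
assert kb.logic_forward([3, 5]) == 8  # reasoning result of pseudo-labels [3, 5]
```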
Next, we create a reasoner by instantiating the class `Reasoner`, passing the knowledge base as a parameter. Due to the indeterminism of abductive reasoning, there could be multiple candidate pseudo-labels compatible with the knowledge base. In such scenarios, the reasoner can minimize inconsistency and return the candidate with the highest consistency.

```python
from abl.reasoning import Reasoner

reasoner = Reasoner(kb)
```
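To get an intuition for the abductive direction, the snippet below enumerates the candidates a reasoner chooses among; this is a conceptual sketch built only on `logic_forward`, not the package's internal search procedure:

```python
# Suppose the model predicts pseudo-labels [7, 5] but the ground-truth
# reasoning result is Y = 9. Candidates compatible with the KB include
# [4, 5], [7, 2], and so on; the reasoner returns the one most
# consistent with the prediction (e.g., fewest revised positions).
candidates = [
    [a, b]
    for a in range(10)
    for b in range(10)
    if kb.logic_forward([a, b]) == 9
]
```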
### Building Evaluation Metrics

ABL-Package provides two basic metrics, namely `SymbolAccuracy` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the `logic_forward` results, respectively.

```python
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy

metric_list = [SymbolAccuracy(), ReasoningMetric(kb=kb)]
```

### Bridging Learning and Reasoning

Now, we use `SimpleBridge` to combine learning and reasoning in a unified ABL framework.

```python
from abl.bridge import SimpleBridge

bridge = SimpleBridge(model, reasoner, metric_list)
```

Finally, we proceed with training and testing.

```python
bridge.train(train_data, loops=1, segment_size=0.01)
bridge.test(test_data)
```
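Note that a float `segment_size` in (0, 1] is interpreted as a fraction of the training data, so `segment_size=0.01` above trains on segments of 1% of the data at a time (see the handling in `SimpleBridge.train` below).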
## References

For more information about ABL, please refer to [Zhou, 2019](https://link.springer.com/epdf/10.1007/s11432-018-9801-4?author_access_token=jgJe1Ox3Mk-K7ORSnX7jtfe4RwlQNchNByi7wbcMAY7_PxTx-xNLP7Lp0mIZ04ORp3VG4wioIBHSCIAO3B_TBJkj87YzapmdnYVSQvgBIO3aEpQWppxZG25KolINetygc2W_Cj2gtoBdiG_J1hU3pA==) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf).
@@ -1,104 +0,0 @@
# coding: utf-8
# ================================================================#
#   Copyright (C) 2021 Freecss All rights reserved.
#
#   File Name     : abducer_base.py
#   Author        : freecss
#   Email         : karlfreecss@gmail.com
#   Created Date  : 2021/06/03
#   Description   :
#
# ================================================================#

import abc

import numpy as np

from abducer.kb import ClsKB, RegKB
# from kb import ClsKB, RegKB


def hamming_dist(A, B):
    # Number of positions where prediction A differs from each candidate in B.
    B = np.array(B)
    A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=(len(B)))
    return np.sum(A != B, axis=1)


def confidence_dist(A, B):
    # 1 - product of predicted probabilities along each candidate in B.
    B = np.array(B)
    A = np.clip(A, 1e-9, 1)
    A = np.expand_dims(A, axis=0)
    A = A.repeat(axis=0, repeats=(len(B)))
    rows = np.array(range(len(B)))
    rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
    cols = np.array(range(len(B[0])))
    cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
    return 1 - np.prod(A[rows, cols, B], axis=1)


class AbducerBase(abc.ABC):
    def __init__(self, kb, dist_func="hamming", pred_res_parse=None):
        self.kb = kb
        if dist_func == "hamming":
            dist_func = hamming_dist
        elif dist_func == "confidence":
            dist_func = confidence_dist
        self.dist_func = dist_func
        if pred_res_parse is None:
            pred_res_parse = lambda x: x["cls"]
        self.pred_res_parse = pred_res_parse

    def abduce(self, data, max_address_num, require_more_address, length=-1):
        pred_res, ans = data
        if length == -1:
            length = len(pred_res)
        candidates = self.kb.get_candidates(ans, length)
        pred_res = np.array(pred_res)
        cost_list = self.dist_func(pred_res, candidates)
        address_num = np.min(cost_list)
        threshold = min(address_num + require_more_address, max_address_num)
        idxs = np.where(cost_list <= threshold)[0]
        # return [candidates[idx] for idx in idxs], address_num
        if len(idxs) > 1:
            return None
        return [candidates[idx] for idx in idxs][0]

    def batch_abduce(self, Y, C, max_address_num=3, require_more_address=0):
        return [
            self.abduce((y, c), max_address_num, require_more_address)
            for y, c in zip(self.pred_res_parse(Y), C)
        ]

    def __call__(self, Y, C, max_address_num=3, require_more_address=0):
        return self.batch_abduce(Y, C, max_address_num, require_more_address)


if __name__ == "__main__":
    # ["1+1", "0+1", "1+0", "2+0"]
    X = [[1, 3, 1], [0, 3, 1], [1, 2, 0], [3, 2, 0]]
    Y = [2, 1, 1, 2]
    kb = RegKB(X, Y)
    abd = AbducerBase(kb)
    res = abd.abduce(([0, 2, 0], None), 1, 0)
    print(res)
    res = abd.abduce(([0, 2, 0], 0.99), 1, 0)
    print(res)

    A = np.array([[0.5, 0.25, 0.25, 0], [0.3, 0.3, 0.3, 0.1], [0.1, 0.2, 0.3, 0.4]])
    B = [[1, 2, 3], [0, 1, 3]]
    res = confidence_dist(A, B)
    print(res)
    A = np.array([[0.5, 0.25, 0.25, 0], [0.3, 1.0, 0.3, 0.1], [0.1, 0.2, 0.3, 1.0]])
    B = [[0, 1, 3]]
    res = confidence_dist(A, B)
    print(res)

    kb_str = ['10010001011', '00010001100', '00111101011', '11101000011', '11110011001', '11111010001', '10001010010', '11100100001', '10001001100', '11011010001', '00110000100', '11000000111', '01110111111', '11000101100', '10101011010', '00000110110', '11111110010', '11100101100', '10111001111', '10000101100', '01001011101', '01001110000', '01110001110', '01010010001', '10000100010', '01001011011', '11111111100', '01011101101', '00101110101', '11101001101', '10010110000', '10000000011']
    X = [[int(c) for c in s] for s in kb_str]
    kb = RegKB(X, len(X) * [None])
    abd = AbducerBase(kb)
    res = abd.abduce(((1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1), None), 1, 0)
    print(res)
@@ -1,137 +0,0 @@
# coding: utf-8
# ================================================================#
#   Copyright (C) 2021 LAMDA All rights reserved.
#
#   File Name     : kb.py
#   Author        : freecss
#   Email         : karlfreecss@gmail.com
#   Created Date  : 2021/06/03
#   Description   :
#
# ================================================================#

import abc
import bisect
import copy
from collections import defaultdict

import numpy as np


class KBBase(abc.ABC):
    def __init__(self, X=None, Y=None):
        pass

    def get_candidates(self, key=None, length=None):
        pass

    def get_all_candidates(self):
        pass

    def _length(self, length):
        if length is None:
            length = list(self.base.keys())
        if type(length) is int:
            length = [length]
        return length

    def __len__(self):
        pass


class ClsKB(KBBase):
    def __init__(self, X, Y=None):
        super().__init__()
        self.base = {}
        if X is None:
            return
        if Y is None:
            Y = [None] * len(X)
        for x, y in zip(X, Y):
            self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))

    def get_candidates(self, key, length=None):
        if key is None:
            return self.get_all_candidates()
        length = self._length(length)
        return sum([self.base[l][key] for l in length], [])

    def get_all_candidates(self):
        return sum([sum(v.values(), []) for v in self.base.values()], [])

    def _dict_len(self, dic):
        return sum(len(c) for c in dic.values())

    def __len__(self):
        return sum(self._dict_len(v) for v in self.base.values())


class RegKB(KBBase):
    def __init__(self, X, Y=None):
        super().__init__()
        tmp_dict = {}
        for x, y in zip(X, Y):
            tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
        self.base = {}
        for l in tmp_dict.keys():
            data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values())))
            X = [x for y, x in data]
            Y = [y for y, x in data]
            self.base[l] = (X, Y)

    def get_candidates(self, key, length=None):
        if key is None:
            return self.get_all_candidates()
        length = self._length(length)
        min_err = 999999
        candidates = []
        for l in length:
            X, Y = self.base[l]
            idx = bisect.bisect_left(Y, key)
            begin = max(0, idx - 1)
            end = min(idx + 2, len(X))
            for idx in range(begin, end):
                err = abs(Y[idx] - key)
                if abs(err - min_err) < 1e-9:
                    candidates.extend(X[idx])
                elif err < min_err:
                    candidates = copy.deepcopy(X[idx])
                    min_err = err
        return candidates

    def get_all_candidates(self):
        return sum([sum(D[0], []) for D in self.base.values()], [])

    def __len__(self):
        return sum([sum(len(x) for x in D[0]) for D in self.base.values()])


if __name__ == "__main__":
    X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
    Y = [2, 1, 1, 2, 2]
    kb = ClsKB(X, Y)
    print(len(kb))
    res = kb.get_candidates(2, 5)
    print(res)
    res = kb.get_candidates(2, 3)
    print(res)
    res = kb.get_candidates(None)
    print(res)

    X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
    Y = [2, 1, 1, 2, 1.5, 1.5]
    kb = RegKB(X, Y)
    print(len(kb))
    res = kb.get_candidates(1.6)
    print(res)
    res = kb.get_candidates(1.6, length=9)
    print(res)
    res = kb.get_candidates(None)
    print(res)
@@ -0,0 +1,9 @@
from . import bridge, data, learning, reasoning, utils

__all__ = [
    "bridge",
    "data",
    "learning",
    "reasoning",
    "utils",
]
@@ -0,0 +1,4 @@
from .base_bridge import BaseBridge
from .simple_bridge import SimpleBridge

__all__ = ["BaseBridge", "SimpleBridge"]
@@ -0,0 +1,89 @@
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple, Union

from ..data.structures import ListData
from ..learning import ABLModel
from ..reasoning import Reasoner


class BaseBridge(metaclass=ABCMeta):
    """
    A base class for bridging learning and reasoning parts.

    This class provides the necessary methods that need to be overridden in subclasses
    to construct a typical pipeline of Abductive Learning (corresponding to ``train``),
    which involves the following four methods:

    - predict: Predict class indices on the given data examples.
    - idx_to_pseudo_label: Map indices into pseudo-labels.
    - abduce_pseudo_label: Revise pseudo-labels based on abductive reasoning.
    - pseudo_label_to_idx: Map revised pseudo-labels back into indices.

    Parameters
    ----------
    model : ABLModel
        The machine learning model wrapped in ``ABLModel``, which is mainly used for
        prediction and model training.
    reasoner : Reasoner
        The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
    """

    def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
        if not isinstance(model, ABLModel):
            raise TypeError(
                "Expected an instance of ABLModel, but received type: {}".format(type(model))
            )
        if not isinstance(reasoner, Reasoner):
            raise TypeError(
                "Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
            )

        self.model = model
        self.reasoner = reasoner

    @abstractmethod
    def predict(self, data_examples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
        """Placeholder for predicting class indices from input."""

    @abstractmethod
    def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """Placeholder for revising pseudo-labels based on abductive reasoning."""

    @abstractmethod
    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """Placeholder for mapping indices to pseudo-labels."""

    @abstractmethod
    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        """Placeholder for mapping pseudo-labels to indices."""

    def filter_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """Default filter function for pseudo-labels: keep only examples whose
        abduced pseudo-label is non-empty."""
        non_empty_idx = [
            i
            for i in range(len(data_examples.abduced_pseudo_label))
            if data_examples.abduced_pseudo_label[i]
        ]
        data_examples.update(data_examples[non_empty_idx])
        return data_examples

    @abstractmethod
    def train(
        self,
        train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ):
        """Placeholder for the training loop of Abductive Learning."""

    @abstractmethod
    def valid(
        self,
        val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> None:
        """Placeholder for model validation."""

    @abstractmethod
    def test(
        self,
        test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> None:
        """Placeholder for model testing."""
@@ -0,0 +1,356 @@
import os.path as osp
from typing import Any, List, Optional, Tuple, Union

from numpy import ndarray

from ..data.evaluation import BaseMetric
from ..data.structures import ListData
from ..learning import ABLModel
from ..reasoning import Reasoner
from ..utils import print_log
from .base_bridge import BaseBridge


class SimpleBridge(BaseBridge):
    """
    A basic implementation for bridging machine learning and reasoning parts.

    This class implements the typical pipeline of Abductive Learning, which involves
    the following five steps:

    - Predict class probabilities and indices for the given data examples.
    - Map indices into pseudo-labels.
    - Revise pseudo-labels based on abductive reasoning.
    - Map the revised pseudo-labels to indices.
    - Train the model.

    Parameters
    ----------
    model : ABLModel
        The machine learning model wrapped in ``ABLModel``, which is mainly used for
        prediction and model training.
    reasoner : Reasoner
        The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
    metric_list : List[BaseMetric]
        A list of metrics used for evaluating the model's performance.
    """

    def __init__(
        self,
        model: ABLModel,
        reasoner: Reasoner,
        metric_list: List[BaseMetric],
    ) -> None:
        super().__init__(model, reasoner)
        self.metric_list = metric_list

    def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
        """
        Predict class indices and probabilities (if ``predict_proba`` is implemented in
        ``self.model.base_model``) on the given data examples.

        Parameters
        ----------
        data_examples : ListData
            Data examples on which predictions are to be made.

        Returns
        -------
        Tuple[List[ndarray], List[ndarray]]
            A tuple containing lists of predicted indices and probabilities.
        """
        self.model.predict(data_examples)
        return data_examples.pred_idx, data_examples.pred_prob

    def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Revise predicted pseudo-labels of the given data examples using abduction.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing predicted pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of abduced pseudo-labels for the given data examples.
        """
        self.reasoner.batch_abduce(data_examples)
        return data_examples.abduced_pseudo_label

    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Map indices of data examples into pseudo-labels.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing the indices.

        Returns
        -------
        List[List[Any]]
            A list of pseudo-labels converted from indices.
        """
        pred_idx = data_examples.pred_idx
        data_examples.pred_pseudo_label = [
            [self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx
        ]
        return data_examples.pred_pseudo_label

    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        """
        Map pseudo-labels of data examples into indices.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of indices converted from pseudo-labels.
        """
        abduced_idx = [
            [
                self.reasoner.label_to_idx[_abduced_pseudo_label]
                for _abduced_pseudo_label in sub_list
            ]
            for sub_list in data_examples.abduced_pseudo_label
        ]
        data_examples.abduced_idx = abduced_idx
        return data_examples.abduced_idx

    def data_preprocess(
        self,
        prefix: str,
        data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> ListData:
        """
        Transform data in the form of (X, gt_pseudo_label, Y) into ListData.

        Parameters
        ----------
        prefix : str
            A prefix indicating the type of data processing (e.g., 'train', 'test').
        data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Data to be preprocessed. Can be ListData or a tuple of lists.

        Returns
        -------
        ListData
            The preprocessed ListData object.
        """
        if isinstance(data, ListData):
            data_examples = data
            if not (
                hasattr(data_examples, "X")
                and hasattr(data_examples, "gt_pseudo_label")
                and hasattr(data_examples, "Y")
            ):
                raise ValueError(
                    f"{prefix}data should have X, gt_pseudo_label and Y attribute but "
                    f"only {data_examples.all_keys()} are provided."
                )
        else:
            X, gt_pseudo_label, Y = data
            data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)

        return data_examples

    def concat_data_examples(
        self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData]
    ) -> ListData:
        """
        Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data
        examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model.

        Parameters
        ----------
        unlabel_data_examples : ListData
            Unlabeled data examples to concatenate.
        label_data_examples : ListData, optional
            Labeled data examples to concatenate, if available.

        Returns
        -------
        ListData
            Concatenated data examples.
        """
        if label_data_examples is None:
            return unlabel_data_examples

        unlabel_data_examples.X = unlabel_data_examples.X + label_data_examples.X
        unlabel_data_examples.abduced_pseudo_label = (
            unlabel_data_examples.abduced_pseudo_label + label_data_examples.gt_pseudo_label
        )
        unlabel_data_examples.Y = unlabel_data_examples.Y + label_data_examples.Y
        return unlabel_data_examples

    def train(
        self,
        train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
        label_data: Optional[
            Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]
        ] = None,
        val_data: Optional[
            Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
        ] = None,
        loops: int = 50,
        segment_size: Union[int, float] = 1.0,
        eval_interval: int = 1,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
    ):
        """
        A typical training pipeline of Abductive Learning.

        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.
            - ``X`` is a list of sublists representing the input data.
            - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel``,
              not to train it. ``gt_pseudo_label`` can be ``None``.
            - ``Y`` is a list representing the ground-truth reasoning result for each sublist
              in ``X``.
        label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional
            Labeled data should be in the same format as ``train_data``. The only difference is
            that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be
            utilized to train the model. Defaults to None.
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501
            Validation data should be in the same format as ``train_data``. Both
            ``gt_pseudo_label`` and ``Y`` can be either None or not, which depends on the
            evaluation metrics in ``self.metric_list``. If ``val_data`` is None, ``train_data``
            will be used to validate the model during training time. Defaults to None.
        loops : int
            The machine learning part and the reasoning part will be iteratively optimized
            for ``loops`` times, by default 50.
        segment_size : Union[int, float]
            Data will be split into segments of this size, and the data in each segment
            will be used together to train the model, by default 1.0.
        eval_interval : int
            The model will be evaluated every ``eval_interval`` loops during training,
            by default 1.
        save_interval : int, optional
            The model will be saved every ``save_interval`` loops during training, by
            default None.
        save_dir : str, optional
            Directory to save the model, by default None.
        """
        data_examples = self.data_preprocess("train", train_data)

        if label_data is not None:
            label_data_examples = self.data_preprocess("label", label_data)
        else:
            label_data_examples = None

        if val_data is not None:
            val_data_examples = self.data_preprocess("val", val_data)
        else:
            val_data_examples = data_examples

        if isinstance(segment_size, int):
            if segment_size <= 0:
                raise ValueError("segment_size should be positive.")
        elif isinstance(segment_size, float):
            if 0 < segment_size <= 1:
                segment_size = int(segment_size * len(data_examples))
            else:
                raise ValueError("segment_size should be in (0, 1].")
        else:
            raise ValueError("segment_size should be int or float.")

        for loop in range(loops):
            for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
                print_log(
                    f"loop(train) [{loop + 1}/{loops}] segment(train) "
                    f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ",
                    logger="current",
                )

                sub_data_examples = data_examples[
                    seg_idx * segment_size : (seg_idx + 1) * segment_size
                ]
                self.predict(sub_data_examples)
                self.idx_to_pseudo_label(sub_data_examples)
                self.abduce_pseudo_label(sub_data_examples)
                self.filter_pseudo_label(sub_data_examples)
                self.concat_data_examples(sub_data_examples, label_data_examples)
                self.pseudo_label_to_idx(sub_data_examples)
                self.model.train(sub_data_examples)

            if (loop + 1) % eval_interval == 0 or loop == loops - 1:
                print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
                self._valid(val_data_examples)

            if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
                print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
                self.model.save(
                    save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
                )

    def _valid(self, data_examples: ListData) -> None:
        """
        Internal method for validating the model with the given data examples.

        Parameters
        ----------
        data_examples : ListData
            Data examples to be used for validation.
        """
        self.predict(data_examples)
        self.idx_to_pseudo_label(data_examples)

        for metric in self.metric_list:
            metric.process(data_examples)

        res = dict()
        for metric in self.metric_list:
            res.update(metric.evaluate())
        msg = "Evaluation ended, "
        for k, v in res.items():
            msg += k + f": {v:.3f} "
        print_log(msg, logger="current")

    def valid(
        self,
        val_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
        ],
    ) -> None:
        """
        Validate the model with the given validation data.

        Parameters
        ----------
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
            Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label``
            and ``Y`` can be either None or not, which depends on the evaluation metrics in
            ``self.metric_list``.
        """
        val_data_examples = self.data_preprocess("val", val_data)
        self._valid(val_data_examples)

    def test(
        self,
        test_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
        ],
    ) -> None:
        """
        Test the model with the given test data.

        Parameters
        ----------
        test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
            Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
            with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and
            ``Y`` can be either None or not, which depends on the evaluation metrics in
            ``self.metric_list``.
        """
        print_log("Test start:", logger="current")
        test_data_examples = self.data_preprocess("test", test_data)
        self._valid(test_data_examples)
@@ -0,0 +1,4 @@
from .evaluation import BaseMetric, ReasoningMetric, SymbolAccuracy
from .structures import ListData

__all__ = ["BaseMetric", "ReasoningMetric", "SymbolAccuracy", "ListData"]
@@ -0,0 +1,141 @@
from typing import Any, Tuple

from lambdaLearn.Base.TabularMixin import TabularMixin

from abl.utils import tab_data_to_tuple

from .structures.list_data import ListData


class DataConverter:
    """
    This class provides functionality to convert LambdaLearn data to ABL-Package data.
    """

    def __init__(self) -> None:
        pass

    def convert_lambdalearn_to_tuple(
        self, dataset: TabularMixin, reasoning_result: Any
    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
        """
        Convert a LambdaLearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), # noqa: E501
        each containing (data, label, reasoning_result).

        Parameters
        ----------
        dataset : TabularMixin
            The LambdaLearn dataset to be converted.
        reasoning_result : Any
            The reasoning result of the dataset.

        Returns
        -------
        Tuple[Tuple, Tuple, Tuple, Tuple]
            A tuple of (label_data, train_data, valid_data, test_data), where each element is
            a tuple of (data, label, reasoning_result).
        """
        if not isinstance(dataset, TabularMixin):
            raise NotImplementedError(
                "Only support converting the datasets that are instances of TabularMixin. "
                + "Please refer to the documentation and manually convert the dataset into a tuple."
            )

        label_data = tab_data_to_tuple(
            dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result
        )
        train_data = tab_data_to_tuple(
            dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result
        )
        valid_data = tab_data_to_tuple(
            dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result
        )
        test_data = tab_data_to_tuple(
            dataset.test_X, dataset.test_y, reasoning_result=reasoning_result
        )
        return label_data, train_data, valid_data, test_data

    def convert_lambdalearn_to_listdata(
        self, dataset: TabularMixin, reasoning_result: Any
    ) -> Tuple[ListData, ListData, ListData, ListData]:
        """
        Convert a LambdaLearn dataset to a tuple of ListData
        (label_data_examples, train_data_examples, valid_data_examples, test_data_examples).

        Parameters
        ----------
        dataset : TabularMixin
            The LambdaLearn dataset to be converted.
        reasoning_result : Any
            The reasoning result of the dataset.

        Returns
        -------
        Tuple[ListData, ListData, ListData, ListData]
            A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples). # noqa: E501
        """
        if not isinstance(dataset, TabularMixin):
            raise NotImplementedError(
                "Only support converting the datasets that are instances of TabularMixin. "
                + "Please refer to the documentation and manually convert the dataset "
                + "into a ListData."
            )

        label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple(
            dataset, reasoning_result
        )

        # Initialize to None so the return value is well defined even when a
        # split is missing (otherwise an UnboundLocalError could be raised).
        label_data_examples = train_data_examples = None
        valid_data_examples = test_data_examples = None

        if label_data is not None:
            X, gt_pseudo_label, Y = label_data
            label_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
        if train_data is not None:
            X, gt_pseudo_label, Y = train_data
            train_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
        if valid_data is not None:
            X, gt_pseudo_label, Y = valid_data
            valid_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
        if test_data is not None:
            X, gt_pseudo_label, Y = test_data
            test_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)

        return label_data_examples, train_data_examples, valid_data_examples, test_data_examples


if __name__ == "__main__":
    from lambdaLearn.Dataset.Tabular.BreastCancer import BreastCancer

    breast_dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
    dataconverter = DataConverter()

    label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple(
        breast_dataset, 0
    )
    print(
        type(label_data).__name__,
        type(train_data).__name__,
        type(valid_data).__name__,
        type(test_data).__name__,
    )
    print(len(label_data))
    print(len(label_data[0]), len(label_data[1]), len(label_data[2]))
    print(label_data[0][0], label_data[1][0], label_data[2][0])
    print()

    (
        label_data_examples,
        train_data_examples,
        valid_data_examples,
        test_data_examples,
    ) = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0)
    print(
        type(label_data_examples).__name__,
        type(train_data_examples).__name__,
        type(valid_data_examples).__name__,
        type(test_data_examples).__name__,
    )
    print(
        len(label_data_examples.X),
        len(label_data_examples.gt_pseudo_label),
        len(label_data_examples.Y),
    )
    label_data_example = label_data_examples[0]
    print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y)
@@ -0,0 +1,5 @@
from .base_metric import BaseMetric
from .reasoning_metric import ReasoningMetric
from .symbol_accuracy import SymbolAccuracy

__all__ = ["BaseMetric", "ReasoningMetric", "SymbolAccuracy"]
@@ -0,0 +1,85 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional

from ...utils import print_log
from ..structures import ListData


class BaseMetric(metaclass=ABCMeta):
    """
    Base class for metrics.

    A metric first processes each batch of data_examples and appends the processed
    results to the results list. Then, it computes the metrics over the entire dataset.

    Parameters
    ----------
    prefix : str, optional
        The prefix that will be added to the metrics names to disambiguate homonymous
        metrics of different tasks. If a prefix is not provided in the argument,
        ``self.default_prefix`` will be used instead. Defaults to None.
    """

    def __init__(
        self,
        prefix: Optional[str] = None,
    ) -> None:
        self.default_prefix = ""
        self.results: List[Any] = []
        self.prefix = prefix or self.default_prefix

    @abstractmethod
    def process(self, data_examples: ListData) -> None:
        """
        Process one batch of data examples. The processed results should be stored
        in ``self.results``, which will be used to compute the metrics when all
        batches have been processed.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples.
        """

    @abstractmethod
    def compute_metrics(self) -> dict:
        """
        Compute the metrics from the processed results.

        Returns
        -------
        dict
            The computed metrics. The keys are the names of the metrics,
            and the values are the corresponding results.
        """

    def evaluate(self) -> dict:
        """
        Evaluate the model performance on the whole dataset after processing
        all batches.

        Returns
        -------
        dict
            Evaluation metrics dict on the val dataset. The keys are the
            names of the metrics, and the values are the corresponding results.
        """
        if len(self.results) == 0:
            print_log(
                f"{self.__class__.__name__} got empty `self.results`. Please "
                "ensure that the processed results are properly added into "
                "`self.results` in the `process` method.",
                logger="current",
                level=logging.WARNING,
            )
        metrics = self.compute_metrics()
        # Add prefix to metric names
        if self.prefix:
            metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()}
        # Reset the results list
        self.results.clear()
        return metrics
@@ -0,0 +1,79 @@
from typing import Optional

from ...reasoning import KBBase
from ..structures import ListData
from .base_metric import BaseMetric


class ReasoningMetric(BaseMetric):
    """
    A metrics class for evaluating model performance on tasks that need reasoning.

    This class is designed to calculate the accuracy of the reasoning results. Reasoning
    results are generated by first using the learning part to predict pseudo-labels
    and then using a knowledge base (KB) to perform logical reasoning. The reasoning results
    are then compared with the ground truth to calculate the accuracy.

    Parameters
    ----------
    kb : KBBase
        An instance of a knowledge base, used for logical reasoning and validation.
        If not provided, reasoning checks are not performed. Defaults to None.
    prefix : str, optional
        The prefix that will be added to the metrics names to disambiguate homonymous
        metrics of different tasks. Inherits from BaseMetric. Defaults to None.

    Notes
    -----
    ``ReasoningMetric`` expects data_examples to have the attributes ``pred_pseudo_label``,
    ``Y``, and ``X``, corresponding to the predicted pseudo-labels, the ground truth of the
    reasoning results, and the input data, respectively.
    """

    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)
        self.kb = kb

    def process(self, data_examples: ListData) -> None:
        """
        Process a batch of data examples.

        This method takes in a batch of data examples, each containing predicted pseudo-labels
        (pred_pseudo_label), the ground truth of reasoning results (Y), and input data (X). It
        evaluates the reasoning accuracy of each example by comparing the logical reasoning
        result (derived using the knowledge base) of the predicted pseudo-labels against Y.
        The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended
        to ``self.results``.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples.
        """
        pred_pseudo_label_list = data_examples.pred_pseudo_label
        y_list = data_examples.Y
        x_list = data_examples.X

        for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
            # If logic_forward also takes the input x (i.e., it has two arguments),
            # pass x along; otherwise call it with the pseudo-labels only.
            if self.kb._check_equal(
                self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y
            ):
                self.results.append(1)
            else:
                self.results.append(0)

    def compute_metrics(self) -> dict:
        """
        Compute the reasoning accuracy metrics from ``self.results``. It calculates the
        percentage of correctly reasoned examples over all examples.

        Returns
        -------
        dict
            A dictionary containing the computed metrics. It includes the key
            'reasoning_accuracy', which maps to the calculated reasoning accuracy,
            represented as a float between 0 and 1.
        """
        results = self.results
        metrics = dict()
        metrics["reasoning_accuracy"] = sum(results) / len(results)
        return metrics
@@ -0,0 +1,73 @@
from typing import Optional

import numpy as np

from ..structures import ListData
from .base_metric import BaseMetric


class SymbolAccuracy(BaseMetric):
    """
    A metrics class for evaluating symbol-level accuracy.

    This class is designed to assess the accuracy of symbol prediction. Symbol accuracy
    is calculated by comparing predicted pseudo-labels with their ground truth.

    Parameters
    ----------
    prefix : str, optional
        The prefix that will be added to the metrics names to disambiguate homonymous
        metrics of different tasks. Inherits from BaseMetric. Defaults to None.
    """

    def __init__(self, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)

    def process(self, data_examples: ListData) -> None:
        """
        Process a batch of data examples.

        This method takes in a batch of data examples, each containing a list of predicted
        pseudo-labels (pred_pseudo_label) and their ground truth (gt_pseudo_label). It
        calculates the accuracy by comparing the two lists. Then, a tuple of the correct
        symbol count and the total symbol count is appended to ``self.results``.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples, each containing:
            - ``pred_pseudo_label``: List of predicted pseudo-labels.
            - ``gt_pseudo_label``: List of ground-truth pseudo-labels.

        Raises
        ------
        ValueError
            If the lengths of the predicted and ground-truth symbol lists are not equal.
        """
        pred_pseudo_label_list = data_examples.flatten("pred_pseudo_label")
        gt_pseudo_label_list = data_examples.flatten("gt_pseudo_label")

        if len(pred_pseudo_label_list) != len(gt_pseudo_label_list):
            raise ValueError("lengths of pred_pseudo_label and gt_pseudo_label should be equal")

        correct_num = np.sum(np.array(pred_pseudo_label_list) == np.array(gt_pseudo_label_list))
        self.results.append((correct_num, len(pred_pseudo_label_list)))

    def compute_metrics(self) -> dict:
        """
        Compute the symbol accuracy metrics from ``self.results``. It calculates the
        percentage of correctly predicted pseudo-labels over all pseudo-labels.

        Returns
        -------
        dict
            A dictionary containing the computed metrics. It includes the key
            'character_accuracy', which maps to the calculated symbol-level accuracy,
            represented as a float between 0 and 1.
        """
        results = self.results
        metrics = dict()
        metrics["character_accuracy"] = sum(t[0] for t in results) / sum(t[1] for t in results)
        return metrics
@@ -0,0 +1,4 @@
from .base_data_element import BaseDataElement
from .list_data import ListData

__all__ = ["BaseDataElement", "ListData"]
| @@ -0,0 +1,625 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import copy | |||
| from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
| import numpy as np | |||
| import torch | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py | |||
| class BaseDataElement: | |||
| """A base data interface that supports Tensor-like and dict-like | |||
| operations. | |||
| A typical data elements refer to predicted results or ground truth labels | |||
| on a task, such as predicted bboxes, instance masks, semantic | |||
| segmentation masks, etc. Because groundtruth labels and predicted results | |||
| often have similar properties (for example, the predicted bboxes and the | |||
| groundtruth bboxes), MMEngine uses the same abstract data interface to | |||
| encapsulate predicted results and groundtruth labels, and it is recommended | |||
| to use different name conventions to distinguish them, such as using | |||
| ``gt_instances`` and ``pred_instances`` to distinguish between labels and | |||
| predicted results. Additionally, we distinguish data elements at instance | |||
| level, pixel level, and label level. Each of these types has its own | |||
| characteristics. Therefore, MMEngine defines the base class | |||
| ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and | |||
| ``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||
| types of ground truth labels or predictions. | |||
| Another common data element is the data example. A data example consists | |||
| of input data (such as an image) and its annotations and predictions. In | |||
| general, an image can have multiple types of annotations and/or | |||
| predictions at the same time (for example, both pixel-level semantic | |||
| segmentation annotations and instance-level detection bbox annotations). | |||
| All labels and predictions of a training example are often passed between | |||
| the Dataset, Model, Visualizer, and Evaluator components. In order to | |||
| simplify the interface between components, we can treat them as one large | |||
| data element and encapsulate them. Such data elements are generally | |||
| called XXDataSample in OpenMMLab. Therefore, similar to `nn.Module`, a | |||
| `BaseDataElement` allows other `BaseDataElement` instances as its | |||
| attributes. Such a class generally encapsulates all the data of an | |||
| example in the algorithm library, and its attributes generally are | |||
| various types of data elements. For example, MMDetection uses a subclass | |||
| of BaseDataElement to encapsulate all the annotation and prediction data | |||
| of an example in the algorithm library. | |||
| The attributes in ``BaseDataElement`` are divided into two parts, | |||
| the ``metainfo`` and the ``data`` respectively. | |||
| - ``metainfo``: Usually contains the | |||
| information about the image such as filename, | |||
| image_shape, pad_shape, etc. The attributes can be accessed or | |||
| modified by dict-like or object-like operations, such as | |||
| ``.`` (for data access and modification), ``in``, ``del``, | |||
| ``pop(str)``, ``get(str)``, ``metainfo_keys()``, | |||
| ``metainfo_values()``, ``metainfo_items()``, and ``set_metainfo()`` (for | |||
| setting or changing key-value pairs in metainfo). | |||
| - ``data``: Annotations or model predictions are | |||
| stored. The attributes can be accessed or modified by | |||
| dict-like or object-like operations, such as | |||
| ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||
| ``values()``, ``items()``. Users can also apply tensor-like | |||
| methods to all :obj:`torch.Tensor` in the ``data_fields``, | |||
| such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, | |||
| ``to_tensor()``, ``.detach()``. | |||
| Args: | |||
| metainfo (dict, optional): A dict that contains the meta information | |||
| of a single image, such as ``dict(img_shape=(512, 512, 3), | |||
| scale_factor=(1, 1, 1, 1))``. Defaults to None. | |||
| kwargs (dict, optional): A dict that contains annotations of a single | |||
| image or model predictions. Defaults to None. | |||
| Examples: | |||
| >>> import torch | |||
| >>> from mmengine.structures import BaseDataElement | |||
| >>> gt_instances = BaseDataElement() | |||
| >>> bboxes = torch.rand((5, 4)) | |||
| >>> scores = torch.rand((5,)) | |||
| >>> img_id = 0 | |||
| >>> img_shape = (800, 1333) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=img_id, img_shape=img_shape), | |||
| ... bboxes=bboxes, scores=scores) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) | |||
| >>> # new | |||
| >>> gt_instances1 = gt_instances.new( | |||
| ... metainfo=dict(img_id=1, img_shape=(640, 640)), | |||
| ... bboxes=torch.rand((5, 4)), | |||
| ... scores=torch.rand((5,))) | |||
| >>> gt_instances2 = gt_instances1.new() | |||
| >>> # add and process property | |||
| >>> gt_instances = BaseDataElement() | |||
| >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) | |||
| >>> assert 'img_shape' in gt_instances.metainfo_keys() | |||
| >>> assert 'img_shape' in gt_instances | |||
| >>> assert 'img_shape' not in gt_instances.keys() | |||
| >>> assert 'img_shape' in gt_instances.all_keys() | |||
| >>> print(gt_instances.img_shape) | |||
| (100, 100) | |||
| >>> gt_instances.scores = torch.rand((5,)) | |||
| >>> assert 'scores' in gt_instances.keys() | |||
| >>> assert 'scores' in gt_instances | |||
| >>> assert 'scores' in gt_instances.all_keys() | |||
| >>> assert 'scores' not in gt_instances.metainfo_keys() | |||
| >>> print(gt_instances.scores) | |||
| tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) | |||
| >>> gt_instances.bboxes = torch.rand((5, 4)) | |||
| >>> assert 'bboxes' in gt_instances.keys() | |||
| >>> assert 'bboxes' in gt_instances | |||
| >>> assert 'bboxes' in gt_instances.all_keys() | |||
| >>> assert 'bboxes' not in gt_instances.metainfo_keys() | |||
| >>> print(gt_instances.bboxes) | |||
| tensor([[0.0900, 0.0424, 0.1755, 0.4469], | |||
| [0.8648, 0.0592, 0.3484, 0.0913], | |||
| [0.5808, 0.1909, 0.6165, 0.7088], | |||
| [0.5490, 0.4209, 0.9416, 0.2374], | |||
| [0.3652, 0.1218, 0.8805, 0.7523]]) | |||
| >>> # delete and change property | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=dict(img_id=0, img_shape=(640, 640)), | |||
| ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) | |||
| >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) | |||
| >>> gt_instances.img_shape # (1280, 1280) | |||
| >>> gt_instances.bboxes = gt_instances.bboxes * 2 | |||
| >>> gt_instances.get('img_shape', None) # (1280, 1280) | |||
| >>> gt_instances.get('bboxes', None) # 6x4 tensor | |||
| >>> del gt_instances.img_shape | |||
| >>> del gt_instances.bboxes | |||
| >>> assert 'img_shape' not in gt_instances | |||
| >>> assert 'bboxes' not in gt_instances | |||
| >>> gt_instances.pop('img_shape', None) # None | |||
| >>> gt_instances.pop('bboxes', None) # None | |||
| >>> # Tensor-like | |||
| >>> cuda_instances = gt_instances.cuda() | |||
| >>> cuda_instances = gt_instances.to('cuda:0') | |||
| >>> cpu_instances = cuda_instances.cpu() | |||
| >>> cpu_instances = cuda_instances.to('cpu') | |||
| >>> fp16_instances = cuda_instances.to( | |||
| ... device=None, dtype=torch.float16, non_blocking=False, | |||
| ... copy=False, memory_format=torch.preserve_format) | |||
| >>> cpu_instances = cuda_instances.detach() | |||
| >>> np_instances = cpu_instances.numpy() | |||
| >>> metainfo = dict(img_shape=(800, 1196, 3)) | |||
| >>> gt_instances = BaseDataElement( | |||
| ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||
| >>> example = BaseDataElement(metainfo=metainfo, | |||
| ... gt_instances=gt_instances) | |||
| >>> print(example) | |||
| <BaseDataElement( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| DATA FIELDS | |||
| gt_instances: <BaseDataElement( | |||
| META INFORMATION | |||
| img_shape: (800, 1196, 3) | |||
| DATA FIELDS | |||
| det_labels: tensor([0, 1, 2, 3]) | |||
| ) at 0x7f0ec5eadc70> | |||
| ) at 0x7f0fea49e130> | |||
| >>> # inheritance | |||
| >>> class DetDataSample(BaseDataElement): | |||
| ... @property | |||
| ... def proposals(self): | |||
| ... return self._proposals | |||
| ... @proposals.setter | |||
| ... def proposals(self, value): | |||
| ... self.set_field(value, '_proposals', dtype=BaseDataElement) | |||
| ... @proposals.deleter | |||
| ... def proposals(self): | |||
| ... del self._proposals | |||
| ... @property | |||
| ... def gt_instances(self): | |||
| ... return self._gt_instances | |||
| ... @gt_instances.setter | |||
| ... def gt_instances(self, value): | |||
| ... self.set_field(value, '_gt_instances', | |||
| ... dtype=BaseDataElement) | |||
| ... @gt_instances.deleter | |||
| ... def gt_instances(self): | |||
| ... del self._gt_instances | |||
| ... @property | |||
| ... def pred_instances(self): | |||
| ... return self._pred_instances | |||
| ... @pred_instances.setter | |||
| ... def pred_instances(self, value): | |||
| ... self.set_field(value, '_pred_instances', | |||
| ... dtype=BaseDataElement) | |||
| ... @pred_instances.deleter | |||
| ... def pred_instances(self): | |||
| ... del self._pred_instances | |||
| >>> det_example = DetDataSample() | |||
| >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||
| >>> det_example.proposals = proposals | |||
| >>> assert 'proposals' in det_example | |||
| >>> assert det_example.proposals == proposals | |||
| >>> del det_example.proposals | |||
| >>> assert 'proposals' not in det_example | |||
| >>> with self.assertRaises(AssertionError): | |||
| ... det_example.proposals = torch.rand((5, 4)) | |||
| """ | |||
| def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: | |||
| self._metainfo_fields: set = set() | |||
| self._data_fields: set = set() | |||
| if metainfo is not None: | |||
| self.set_metainfo(metainfo=metainfo) | |||
| if kwargs: | |||
| self.set_data(kwargs) | |||
| def set_metainfo(self, metainfo: dict) -> None: | |||
| """Set or change key-value pairs in ``metainfo_field`` by parameter | |||
| ``metainfo``. | |||
| Args: | |||
| metainfo (dict): A dict that contains the meta information | |||
| of an image, such as ``img_shape``, ``scale_factor``, etc. | |||
| """ | |||
| assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||
| meta = copy.deepcopy(metainfo) | |||
| for k, v in meta.items(): | |||
| self.set_field(name=k, value=v, field_type="metainfo", dtype=None) | |||
| def set_data(self, data: dict) -> None: | |||
| """Set or change key-value pairs in ``data_field`` by parameter | |||
| ``data``. | |||
| Args: | |||
| data (dict): A dict that contains annotations of an image or | |||
| model predictions. | |||
| """ | |||
| assert isinstance(data, dict), f"data should be a `dict` but got {data}" | |||
| for k, v in data.items(): | |||
| # Use `setattr()` rather than `self.set_field` to allow `set_data` | |||
| # to set property method. | |||
| setattr(self, k, v) | |||
| def update(self, instance: "BaseDataElement") -> None: | |||
| """The update() method updates the BaseDataElement with the elements | |||
| from another BaseDataElement object. | |||
| Args: | |||
| instance (BaseDataElement): Another BaseDataElement object used to | |||
| update the current object. | |||
| """ | |||
| assert isinstance( | |||
| instance, BaseDataElement | |||
| ), f"instance should be a `BaseDataElement` but got {type(instance)}" | |||
| self.set_metainfo(dict(instance.metainfo_items())) | |||
| self.set_data(dict(instance.items())) | |||
| def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": | |||
| """Return a new data element with same type. If ``metainfo`` and | |||
| ``data`` are None, the new data element will have same metainfo and | |||
| data. If metainfo or data is not None, the new result will overwrite it | |||
| with the input value. | |||
| Args: | |||
| metainfo (dict, optional): A dict contains the meta information | |||
| of image, such as ``img_shape``, ``scale_factor``, etc. | |||
| Defaults to None. | |||
| kwargs (dict): A dict contains annotations of image or | |||
| model predictions. | |||
| Returns: | |||
| BaseDataElement: A new data element with same type. | |||
| """ | |||
| new_data = self.__class__() | |||
| if metainfo is not None: | |||
| new_data.set_metainfo(metainfo) | |||
| else: | |||
| new_data.set_metainfo(dict(self.metainfo_items())) | |||
| if kwargs: | |||
| new_data.set_data(kwargs) | |||
| else: | |||
| new_data.set_data(dict(self.items())) | |||
| return new_data | |||
| def clone(self): | |||
| """Deep copy the current data element. | |||
| Returns: | |||
| BaseDataElement: A copy of the current data element. | |||
| """ | |||
| clone_data = self.__class__() | |||
| clone_data.set_metainfo(dict(self.metainfo_items())) | |||
| clone_data.set_data(dict(self.items())) | |||
| return clone_data | |||
| def keys(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all keys in data_fields. | |||
| """ | |||
| # We assume that the name of the attribute related to property is | |||
| # '_' + the name of the property. We use this rule to filter out | |||
| # private keys. | |||
| # TODO: Use a more robust way to solve this problem | |||
| private_keys = { | |||
| "_" + key | |||
| for key in self._data_fields | |||
| if isinstance(getattr(type(self), key, None), property) | |||
| } | |||
| return list(self._data_fields - private_keys) | |||
| def metainfo_keys(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all keys in metainfo_fields. | |||
| """ | |||
| return list(self._metainfo_fields) | |||
| def values(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all values in data. | |||
| """ | |||
| return [getattr(self, k) for k in self.keys()] | |||
| def metainfo_values(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all values in metainfo. | |||
| """ | |||
| return [getattr(self, k) for k in self.metainfo_keys()] | |||
| def all_keys(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all keys in metainfo and data. | |||
| """ | |||
| return self.metainfo_keys() + self.keys() | |||
| def all_values(self) -> list: | |||
| """ | |||
| Returns: | |||
| list: Contains all values in metainfo and data. | |||
| """ | |||
| return self.metainfo_values() + self.values() | |||
| def all_items(self) -> Iterator[Tuple[str, Any]]: | |||
| """ | |||
| Returns: | |||
| iterator: An iterator object whose element is (key, value) tuple | |||
| pairs for ``metainfo`` and ``data``. | |||
| """ | |||
| for k in self.all_keys(): | |||
| yield (k, getattr(self, k)) | |||
| def items(self) -> Iterator[Tuple[str, Any]]: | |||
| """ | |||
| Returns: | |||
| iterator: An iterator object whose element is (key, value) tuple | |||
| pairs for ``data``. | |||
| """ | |||
| for k in self.keys(): | |||
| yield (k, getattr(self, k)) | |||
| def metainfo_items(self) -> Iterator[Tuple[str, Any]]: | |||
| """ | |||
| Returns: | |||
| iterator: An iterator object whose element is (key, value) tuple | |||
| pairs for ``metainfo``. | |||
| """ | |||
| for k in self.metainfo_keys(): | |||
| yield (k, getattr(self, k)) | |||
| @property | |||
| def metainfo(self) -> dict: | |||
| """dict: A dict contains metainfo of current data element.""" | |||
| return dict(self.metainfo_items()) | |||
| def __setattr__(self, name: str, value: Any): | |||
| """setattr is only used to set data.""" | |||
| if name in ("_metainfo_fields", "_data_fields"): | |||
| if not hasattr(self, name): | |||
| super().__setattr__(name, value) | |||
| else: | |||
| raise AttributeError( | |||
| f"{name} has been used as a " "private attribute, which is immutable." | |||
| ) | |||
| else: | |||
| self.set_field(name=name, value=value, field_type="data", dtype=None) | |||
| def __delattr__(self, item: str): | |||
| """Delete the item in dataelement. | |||
| Args: | |||
| item (str): The key to delete. | |||
| """ | |||
| if item in ("_metainfo_fields", "_data_fields"): | |||
| raise AttributeError( | |||
| f"{item} has been used as a " "private attribute, which is immutable." | |||
| ) | |||
| super().__delattr__(item) | |||
| if item in self._metainfo_fields: | |||
| self._metainfo_fields.remove(item) | |||
| elif item in self._data_fields: | |||
| self._data_fields.remove(item) | |||
| # dict-like methods | |||
| __delitem__ = __delattr__ | |||
| def get(self, key, default=None) -> Any: | |||
| """Get property in data and metainfo as the same as python.""" | |||
| # Use `getattr()` rather than `self.__dict__.get()` to allow getting | |||
| # properties. | |||
| return getattr(self, key, default) | |||
| def pop(self, *args) -> Any: | |||
| """Pop property in data and metainfo as the same as python.""" | |||
| assert len(args) < 3, "``pop`` get more than 2 arguments" | |||
| name = args[0] | |||
| if name in self._metainfo_fields: | |||
| self._metainfo_fields.remove(args[0]) | |||
| return self.__dict__.pop(*args) | |||
| elif name in self._data_fields: | |||
| self._data_fields.remove(args[0]) | |||
| return self.__dict__.pop(*args) | |||
| # with default value | |||
| elif len(args) == 2: | |||
| return args[1] | |||
| else: | |||
| # don't just use 'self.__dict__.pop(*args)' for only popping key in | |||
| # metainfo or data | |||
| raise KeyError(f"{args[0]} is not contained in metainfo or data") | |||
| def __contains__(self, item: str) -> bool: | |||
| """Whether the item is in dataelement. | |||
| Args: | |||
| item (str): The key to inquire. | |||
| """ | |||
| return item in self._data_fields or item in self._metainfo_fields | |||
| def set_field( | |||
| self, | |||
| value: Any, | |||
| name: str, | |||
| dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, | |||
| field_type: str = "data", | |||
| ) -> None: | |||
| """Special method for set union field, used as property.setter | |||
| functions.""" | |||
| assert field_type in ["metainfo", "data"] | |||
| if dtype is not None: | |||
| assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}" | |||
| if field_type == "metainfo": | |||
| if name in self._data_fields: | |||
| raise AttributeError( | |||
| f"Cannot set {name} to be a field of metainfo " | |||
| f"because {name} is already a data field" | |||
| ) | |||
| self._metainfo_fields.add(name) | |||
| else: | |||
| if name in self._metainfo_fields: | |||
| raise AttributeError( | |||
| f"Cannot set {name} to be a field of data " | |||
| f"because {name} is already a metainfo field" | |||
| ) | |||
| self._data_fields.add(name) | |||
| super().__setattr__(name, value) | |||
| # Tensor-like methods | |||
| def to(self, *args, **kwargs) -> "BaseDataElement": | |||
| """Apply same name function to all tensors in data_fields.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if hasattr(v, "to"): | |||
| v = v.to(*args, **kwargs) | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def cpu(self) -> "BaseDataElement": | |||
| """Convert all tensors to CPU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.cpu() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def cuda(self) -> "BaseDataElement": | |||
| """Convert all tensors to GPU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.cuda() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def npu(self) -> "BaseDataElement": | |||
| """Convert all tensors to NPU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.npu() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| def mlu(self) -> "BaseDataElement": | |||
| """Convert all tensors to MLU in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.mlu() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def detach(self) -> "BaseDataElement": | |||
| """Detach all tensors in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.detach() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| # Tensor-like methods | |||
| def numpy(self) -> "BaseDataElement": | |||
| """Convert all tensors to np.ndarray in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
| v = v.detach().cpu().numpy() | |||
| data = {k: v} | |||
| new_data.set_data(data) | |||
| return new_data | |||
| def to_tensor(self) -> "BaseDataElement": | |||
| """Convert all np.ndarray to tensor in data.""" | |||
| new_data = self.new() | |||
| for k, v in self.items(): | |||
| data = {} | |||
| if isinstance(v, np.ndarray): | |||
| v = torch.from_numpy(v) | |||
| data[k] = v | |||
| elif isinstance(v, BaseDataElement): | |||
| v = v.to_tensor() | |||
| data[k] = v | |||
| new_data.set_data(data) | |||
| return new_data | |||
| def to_dict(self) -> dict: | |||
| """Convert BaseDataElement to dict.""" | |||
| return { | |||
| k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items() | |||
| } | |||
| def __repr__(self) -> str: | |||
| """Represent the object.""" | |||
| def _addindent(s_: str, num_spaces: int) -> str: | |||
| """This func is modified from `pytorch` https://github.com/pytorch/ | |||
| pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu | |||
| les/module.py#L29. | |||
| Args: | |||
| s_ (str): The string to which spaces are added. | |||
| num_spaces (int): The number of spaces to add. | |||
| Returns: | |||
| str: The string after adding the indent. | |||
| """ | |||
| s = s_.split("\n") | |||
| # don't do anything for single-line stuff | |||
| if len(s) == 1: | |||
| return s_ | |||
| first = s.pop(0) | |||
| s = [(num_spaces * " ") + line for line in s] | |||
| s = "\n".join(s) # type: ignore | |||
| s = first + "\n" + s # type: ignore | |||
| return s # type: ignore | |||
| def dump(obj: Any) -> str: | |||
| """Represent the object. | |||
| Args: | |||
| obj (Any): The obj to represent. | |||
| Returns: | |||
| str: The represented str. | |||
| """ | |||
| _repr = "" | |||
| if isinstance(obj, dict): | |||
| for k, v in obj.items(): | |||
| _repr += f"\n{k}: {_addindent(dump(v), 4)}" | |||
| elif isinstance(obj, BaseDataElement): | |||
| _repr += "\n\n META INFORMATION" | |||
| metainfo_items = dict(obj.metainfo_items()) | |||
| _repr += _addindent(dump(metainfo_items), 4) | |||
| _repr += "\n\n DATA FIELDS" | |||
| items = dict(obj.items()) | |||
| _repr += _addindent(dump(items), 4) | |||
| classname = obj.__class__.__name__ | |||
| _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" | |||
| else: | |||
| _repr += repr(obj) | |||
| return _repr | |||
| return dump(self) | |||
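| # --- Usage sketch (added for illustration, not part of the original module): | |||
| # a minimal demonstration of the metainfo/data split and the tensor-like | |||
| # conversions documented above. All field names below are illustrative. | |||
| if __name__ == "__main__": | |||
|     element = BaseDataElement( | |||
|         metainfo=dict(img_id=0, img_shape=(32, 32)), | |||
|         scores=torch.rand(4), | |||
|     ) | |||
|     assert "img_shape" in element.metainfo_keys() and "scores" in element.keys() | |||
|     as_numpy = element.numpy()  # tensors in data fields become np.ndarray | |||
|     assert isinstance(as_numpy.scores, np.ndarray) | |||
|     back = as_numpy.to_tensor()  # ...and can be converted back to torch.Tensor | |||
|     assert isinstance(back.scores, torch.Tensor) | |||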
| @@ -0,0 +1,257 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| from typing import List, Union | |||
| import numpy as np | |||
| import torch | |||
| from ...utils import flatten as flatten_list | |||
| from ...utils import to_hashable | |||
| from .base_data_element import BaseDataElement | |||
| BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] | |||
| LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] | |||
| IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
| class ListData(BaseDataElement): | |||
| """ | |||
| Abstract Data Interface used throughout the ABL-Package. | |||
| ``ListData`` is the underlying data structure used in the ABL-Package, | |||
| designed to manage diverse forms of data dynamically generated throughout the | |||
| Abductive Learning (ABL) framework. This includes handling raw data, predicted | |||
| pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc. | |||
| As a fundamental data structure in ABL, ``ListData`` is essential for the smooth | |||
| transfer and manipulation of data across various components of the ABL framework, | |||
| such as prediction, abductive reasoning, and training phases. It provides a | |||
| unified data format across these stages, ensuring compatibility and flexibility | |||
| in handling diverse data forms in the ABL framework. | |||
| The attributes in ``ListData`` are divided into two parts, | |||
| the ``metainfo`` and the ``data`` respectively. | |||
| - ``metainfo``: Usually used to store basic information about data examples, | |||
| such as symbol number, image size, etc. The attributes can be accessed or | |||
| modified by dict-like or object-like operations, such as ``.`` (for data | |||
| access and modification), ``in``, ``del``, ``pop(str)``, ``get(str)``, | |||
| ``metainfo_keys()``, ``metainfo_values()``, ``metainfo_items()``, | |||
| ``set_metainfo()`` (for set or change key-value pairs in metainfo). | |||
| - ``data``: raw data, labels, predictions, and abduced results are stored. | |||
| The attributes can be accessed or modified by dict-like or object-like operations, | |||
| such as ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||
| ``values()``, ``items()``. Users can also apply tensor-like | |||
| methods to all :obj:`torch.Tensor` in the ``data_fields``, such as ``.cuda()``, | |||
| ``.cpu()``, ``.numpy()``, ``.to()``, ``to_tensor()``, ``.detach()``. | |||
| ListData supports ``index`` and ``slice`` for data field. The type of value in | |||
| data field can be either ``None`` or ``list`` of base data structures such as | |||
| ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``. | |||
| This design is inspired by and extends the functionalities of the ``BaseDataElement`` | |||
| class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 | |||
| Examples: | |||
| >>> from abl.data.structures import ListData | |||
| >>> import numpy as np | |||
| >>> import torch | |||
| >>> data_examples = ListData() | |||
| >>> data_examples.X = [list(torch.randn(2)) for _ in range(3)] | |||
| >>> data_examples.Y = [1, 2, 3] | |||
| >>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] | |||
| >>> len(data_examples) | |||
| 3 | |||
| >>> print(data_examples) | |||
| <ListData( | |||
| META INFORMATION | |||
| DATA FIELDS | |||
| Y: [1, 2, 3] | |||
| gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] | |||
| X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 | |||
| ) at 0x7f3bbf1991c0> | |||
| >>> print(data_examples[:1]) | |||
| <ListData( | |||
| META INFORMATION | |||
| DATA FIELDS | |||
| Y: [1] | |||
| gt_pseudo_label: [[1, 2]] | |||
| X: [[tensor(1.1949), tensor(-0.9378)]] | |||
| ) at 0x7f3bbf1a3580> | |||
| >>> print(data_examples.elements_num("X")) | |||
| 6 | |||
| >>> print(data_examples.flatten("gt_pseudo_label")) | |||
| [1, 2, 3, 4, 5, 6] | |||
| >>> print(data_examples.to_tuple("Y")) | |||
| (1, 2, 3) | |||
| """ | |||
| def __setattr__(self, name: str, value: list): | |||
| """setattr is only used to set data. | |||
| The value must have the `__len__` attribute and have the same length | |||
| as this `ListData`. | |||
| """ | |||
| if name in ("_metainfo_fields", "_data_fields"): | |||
| if not hasattr(self, name): | |||
| super().__setattr__(name, value) | |||
| else: | |||
| raise AttributeError( | |||
| f"{name} has been used as a " "private attribute, which is immutable." | |||
| ) | |||
| else: | |||
| # assert isinstance(value, list), "value must be of type `list`" | |||
| # if len(self) > 0: | |||
| # assert len(value) == len(self), ( | |||
| # "The length of " | |||
| # f"values {len(value)} is " | |||
| # "not consistent with " | |||
| # "the length of this " | |||
| # ":obj:`ListData` " | |||
| # f"{len(self)}" | |||
| # ) | |||
| super().__setattr__(name, value) | |||
| __setitem__ = __setattr__ | |||
| def __getitem__(self, item: IndexType) -> "ListData": | |||
| """ | |||
| Args: | |||
| item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, | |||
| :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): | |||
| Get the corresponding values according to item. | |||
| Returns: | |||
| :obj:`ListData`: Corresponding values. | |||
| """ | |||
| assert isinstance(item, IndexType.__args__) | |||
| if isinstance(item, list): | |||
| item = np.array(item) | |||
| if isinstance(item, np.ndarray): | |||
| # The default int type of numpy is platform dependent, int32 for | |||
| # windows and int64 for linux. `torch.Tensor` requires the index | |||
| # should be int64, therefore we simply convert it to int64 here. | |||
| # More details in https://github.com/numpy/numpy/issues/9464 | |||
| item = item.astype(np.int64) if item.dtype == np.int32 else item | |||
| item = torch.from_numpy(item) | |||
| if isinstance(item, str): | |||
| return getattr(self, item) | |||
| new_data = self.__class__(metainfo=self.metainfo) | |||
| if isinstance(item, torch.Tensor): | |||
| assert item.dim() == 1, "Only support getting the values along the first dimension." | |||
| for k, v in self.items(): | |||
| if v is None: | |||
| new_data[k] = None | |||
| elif isinstance(v, torch.Tensor): | |||
| new_data[k] = v[item] | |||
| elif isinstance(v, np.ndarray): | |||
| new_data[k] = v[item.cpu().numpy()] | |||
| elif isinstance(v, (str, list, tuple)) or ( | |||
| hasattr(v, "__getitem__") and hasattr(v, "cat") | |||
| ): | |||
| # convert to indexes from BoolTensor | |||
| if isinstance(item, BoolTypeTensor.__args__): | |||
| indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() | |||
| else: | |||
| indexes = item.cpu().numpy().tolist() | |||
| slice_list = [] | |||
| if indexes: | |||
| for index in indexes: | |||
| slice_list.append(slice(index, None, len(v))) | |||
| else: | |||
| slice_list.append(slice(None, 0, None)) | |||
| r_list = [v[s] for s in slice_list] | |||
| if isinstance(v, (str, list, tuple)): | |||
| new_value = r_list[0] | |||
| for r in r_list[1:]: | |||
| new_value = new_value + r | |||
| else: | |||
| new_value = v.cat(r_list) | |||
| new_data[k] = new_value | |||
| else: | |||
| raise ValueError( | |||
| f"The type of `{k}` is `{type(v)}`, which has no " | |||
| "attribute of `cat`, so it does not " | |||
| "support slice with `bool`" | |||
| ) | |||
| else: | |||
| # item is a slice or int | |||
| for k, v in self.items(): | |||
| if v is None: | |||
| new_data[k] = None | |||
| else: | |||
| new_data[k] = v[item] | |||
| return new_data # type:ignore | |||
| def flatten(self, item: str) -> List: | |||
| """ | |||
| Flatten the list of the attribute specified by ``item``. | |||
| Parameters | |||
| ---------- | |||
| item | |||
| Name of the attribute to be flattened. | |||
| Returns | |||
| ------- | |||
| list | |||
| The flattened list of the attribute specified by ``item``. | |||
| """ | |||
| return flatten_list(self[item]) | |||
| def elements_num(self, item: str) -> int: | |||
| """ | |||
| Return the number of elements in the attribute specified by ``item``. | |||
| Parameters | |||
| ---------- | |||
| item : str | |||
| Name of the attribute for which the number of elements is to be determined. | |||
| Returns | |||
| ------- | |||
| int | |||
| The number of elements in the attribute specified by ``item``. | |||
| """ | |||
| return len(self.flatten(item)) | |||
| def to_tuple(self, item: str) -> tuple: | |||
| """ | |||
| Convert the attribute specified by ``item`` to a tuple. | |||
| Parameters | |||
| ---------- | |||
| item : str | |||
| Name of the attribute to be converted. | |||
| Returns | |||
| ------- | |||
| tuple | |||
| The attribute after conversion to a tuple. | |||
| """ | |||
| return to_hashable(self[item]) | |||
| def __len__(self) -> int: | |||
| """int: The length of ListData, i.e., the length of its first non-None data field.""" | |||
| for data in self._data_fields: | |||
| if getattr(self, data) is not None: | |||
| return len(getattr(self, data)) | |||
| raise ValueError("All data fields are None.") | |||
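| # --- Usage sketch (added for illustration, not part of the original module): | |||
| # boolean-mask indexing, one of the index types supported by the | |||
| # ``__getitem__`` above for list-valued data fields. | |||
| if __name__ == "__main__": | |||
|     data = ListData() | |||
|     data.X = [[1, 2], [3, 4], [5, 6]] | |||
|     data.Y = [0, 1, 0] | |||
|     mask = torch.tensor([True, False, True]) | |||
|     subset = data[mask]  # keeps examples 0 and 2 | |||
|     assert subset.X == [[1, 2], [5, 6]] and subset.Y == [0, 0] | |||
|     assert len(subset) == 2 | |||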
| @@ -0,0 +1,5 @@ | |||
| from .abl_model import ABLModel | |||
| from .basic_nn import BasicNN | |||
| from .torch_dataset import ClassificationDataset, PredictionDataset, RegressionDataset | |||
| __all__ = ["ABLModel", "BasicNN", "ClassificationDataset", "PredictionDataset", "RegressionDataset"] | |||
| @@ -0,0 +1,134 @@ | |||
| import pickle | |||
| from typing import Any, Dict | |||
| from ..data.structures import ListData | |||
| from ..utils import reform_list | |||
| class ABLModel: | |||
| """ | |||
| Serialize data and provide a unified interface for different machine learning models. | |||
| Parameters | |||
| ---------- | |||
| base_model : Machine Learning Model | |||
| The machine learning base model used for training and prediction. This model should | |||
| implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for the | |||
| model to also implement the ``predict_proba`` method for generating | |||
| class probability predictions. | |||
| """ | |||
| def __init__(self, base_model: Any) -> None: | |||
| if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): | |||
| raise NotImplementedError("The base_model should implement fit and predict methods.") | |||
| self.base_model = base_model | |||
| def predict(self, data_examples: ListData) -> Dict: | |||
| """ | |||
| Predict the labels and probabilities for the given data. | |||
| Parameters | |||
| ---------- | |||
| data_examples : ListData | |||
| A batch of data to predict on. | |||
| Returns | |||
| ------- | |||
| dict | |||
| A dictionary containing the predicted labels and probabilities. | |||
| """ | |||
| model = self.base_model | |||
| data_X = data_examples.flatten("X") | |||
| if hasattr(model, "predict_proba"): | |||
| prob = model.predict_proba(X=data_X) | |||
| label = prob.argmax(axis=1) | |||
| prob = reform_list(prob, data_examples.X) | |||
| else: | |||
| prob = None | |||
| label = model.predict(X=data_X) | |||
| label = reform_list(label, data_examples.X) | |||
| data_examples.pred_idx = label | |||
| data_examples.pred_prob = prob | |||
| return {"label": label, "prob": prob} | |||
| def train(self, data_examples: ListData) -> float: | |||
| """ | |||
| Train the model on the given data. | |||
| Parameters | |||
| ---------- | |||
| data_examples : ListData | |||
| A batch of data to train on, which typically contains the data, ``X``, and the | |||
| corresponding labels, ``abduced_idx``. | |||
| Returns | |||
| ------- | |||
| float | |||
| The loss value of the trained model. | |||
| """ | |||
| data_X = data_examples.flatten("X") | |||
| data_y = data_examples.flatten("abduced_idx") | |||
| return self.base_model.fit(X=data_X, y=data_y) | |||
| def valid(self, data_examples: ListData) -> float: | |||
| """ | |||
| Validate the model on the given data. | |||
| Parameters | |||
| ---------- | |||
| data_examples : ListData | |||
| A batch of data to validate on, which typically contains the data, ``X``, | |||
| and the corresponding labels, ``abduced_idx``. | |||
| Returns | |||
| ------- | |||
| float | |||
| The accuracy of the trained model. | |||
| """ | |||
| data_X = data_examples.flatten("X") | |||
| data_y = data_examples.flatten("abduced_idx") | |||
| score = self.base_model.score(X=data_X, y=data_y) | |||
| return score | |||
| def _model_operation(self, operation: str, *args, **kwargs): | |||
| model = self.base_model | |||
| if hasattr(model, operation): | |||
| method = getattr(model, operation) | |||
| method(*args, **kwargs) | |||
| else: | |||
| if f"{operation}_path" not in kwargs.keys(): | |||
| raise ValueError(f"'{operation}_path' should not be None") | |||
| else: | |||
| try: | |||
| if operation == "save": | |||
| with open(kwargs["save_path"], "wb") as file: | |||
| pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
| elif operation == "load": | |||
| with open(kwargs["load_path"], "rb") as file: | |||
| self.base_model = pickle.load(file) | |||
| except (OSError, pickle.PickleError): | |||
| raise NotImplementedError( | |||
| f"{type(model).__name__} object doesn't have the {operation} method \ | |||
| and the default pickle-based {operation} method failed." | |||
| ) | |||
| def save(self, *args, **kwargs) -> None: | |||
| """ | |||
| Save the model to a file. | |||
| This method delegates to the ``save`` method of self.base_model. The arguments passed to | |||
| this method should match those expected by the ``save`` method of self.base_model. | |||
| """ | |||
| self._model_operation("save", *args, **kwargs) | |||
| def load(self, *args, **kwargs) -> None: | |||
| """ | |||
| Load the model from a file. | |||
| This method delegates to the ``load`` method of self.base_model. The arguments passed to | |||
| this method should match those expected by the ``load`` method of self.base_model. | |||
| """ | |||
| self._model_operation("load", *args, **kwargs) | |||
| @@ -0,0 +1,546 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| import os | |||
| from typing import Any, Callable, List, Optional, Tuple, Union | |||
| import numpy | |||
| import torch | |||
| from torch.utils.data import DataLoader | |||
| from ..utils.logger import print_log | |||
| from .torch_dataset import ClassificationDataset, PredictionDataset | |||
| class BasicNN: | |||
| """ | |||
| Wrap NN models into the form of an sklearn estimator. | |||
| Parameters | |||
| ---------- | |||
| model : torch.nn.Module | |||
| The PyTorch model to be trained or used for prediction. | |||
| loss_fn : torch.nn.Module | |||
| The loss function used for training. | |||
| optimizer : torch.optim.Optimizer | |||
| The optimizer used for training. | |||
| scheduler : Callable[..., Any], optional | |||
| The learning rate scheduler used for training, which will be called | |||
| at the end of each run of the ``fit`` method. It should implement the | |||
| ``step`` method, by default None. | |||
| device : Union[torch.device, str] | |||
| The device on which the model will be trained or used for prediction, | |||
| by default torch.device("cpu"). | |||
| batch_size : int, optional | |||
| The batch size used for training, by default 32. | |||
| num_epochs : int, optional | |||
| The number of epochs used for training, by default 1. | |||
| stop_loss : float, optional | |||
| The loss value at which to stop training, by default 0.0001. | |||
| num_workers : int | |||
| The number of workers used for loading data, by default 0. | |||
| save_interval : int, optional | |||
| The model will be saved every ``save_interval`` epochs during training, by default None. | |||
| save_dir : str, optional | |||
| The directory in which to save the model during training, by default None. | |||
| train_transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version used | |||
| in the ``fit`` and ``train_epoch`` methods, by default None. | |||
| test_transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version in the | |||
| ``predict``, ``predict_proba`` and ``score`` methods, by default None. | |||
| collate_fn : Callable[[List[Any]], Any], optional | |||
| The function used to collate data, by default None. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| model: torch.nn.Module, | |||
| loss_fn: torch.nn.Module, | |||
| optimizer: torch.optim.Optimizer, | |||
| scheduler: Optional[Callable[..., Any]] = None, | |||
| device: Union[torch.device, str] = torch.device("cpu"), | |||
| batch_size: int = 32, | |||
| num_epochs: int = 1, | |||
| stop_loss: Optional[float] = 0.0001, | |||
| num_workers: int = 0, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| train_transform: Optional[Callable[..., Any]] = None, | |||
| test_transform: Optional[Callable[..., Any]] = None, | |||
| collate_fn: Optional[Callable[[List[Any]], Any]] = None, | |||
| ) -> None: | |||
| if not isinstance(model, torch.nn.Module): | |||
| raise TypeError("model must be an instance of torch.nn.Module") | |||
| if not isinstance(loss_fn, torch.nn.Module): | |||
| raise TypeError("loss_fn must be an instance of torch.nn.Module") | |||
| if not isinstance(optimizer, torch.optim.Optimizer): | |||
| raise TypeError("optimizer must be an instance of torch.optim.Optimizer") | |||
| if scheduler is not None and not hasattr(scheduler, "step"): | |||
| raise NotImplementedError("scheduler should implement the ``step`` method") | |||
| if not isinstance(device, torch.device): | |||
| if not isinstance(device, str): | |||
| raise TypeError( | |||
| "device must be an instance of torch.device or a str indicating " | |||
| + "the target device" | |||
| ) | |||
| else: | |||
| device = torch.device(device) | |||
| if not isinstance(batch_size, int): | |||
| raise TypeError("batch_size must be an integer") | |||
| if not isinstance(num_epochs, int): | |||
| raise TypeError("num_epochs must be an integer") | |||
| if stop_loss is not None and not isinstance(stop_loss, float): | |||
| raise TypeError("stop_loss must be a float") | |||
| if not isinstance(num_workers, int): | |||
| raise TypeError("num_workers must be an integer") | |||
| if save_interval is not None and not isinstance(save_interval, int): | |||
| raise TypeError("save_interval must be an integer") | |||
| if save_dir is not None and not isinstance(save_dir, str): | |||
| raise TypeError("save_dir must be a string") | |||
| if train_transform is not None and not callable(train_transform): | |||
| raise TypeError("train_transform must be callable") | |||
| if test_transform is not None and not callable(test_transform): | |||
| raise TypeError("test_transform must be callable") | |||
| if collate_fn is not None and not callable(collate_fn): | |||
| raise TypeError("collate_fn must be callable") | |||
| self.model = model.to(device) | |||
| self.loss_fn = loss_fn | |||
| self.optimizer = optimizer | |||
| self.scheduler = scheduler | |||
| self.device = device | |||
| self.batch_size = batch_size | |||
| self.num_epochs = num_epochs | |||
| self.stop_loss = stop_loss | |||
| self.num_workers = num_workers | |||
| self.save_interval = save_interval | |||
| self.save_dir = save_dir | |||
| self.train_transform = train_transform | |||
| self.test_transform = test_transform | |||
| self.collate_fn = collate_fn | |||
| if self.save_interval is not None and self.save_dir is None: | |||
| raise ValueError("save_dir should not be None if save_interval is not None.") | |||
| if self.train_transform is not None and self.test_transform is None: | |||
| print_log( | |||
| "Transform used in the training phase will be used in prediction.", | |||
| logger="current", | |||
| level=logging.WARNING, | |||
| ) | |||
| self.test_transform = self.train_transform | |||
| def _fit(self, data_loader: DataLoader) -> BasicNN: | |||
| """ | |||
| Internal method to fit the model on data for ``self.num_epochs`` epochs, | |||
| with early stopping. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader | |||
| Data loader providing training samples. | |||
| Returns | |||
| ------- | |||
| BasicNN | |||
| The model itself after training. | |||
| """ | |||
| if not isinstance(data_loader, DataLoader): | |||
| raise TypeError( | |||
| f"data_loader must be an instance of torch.utils.data.DataLoader, " | |||
| f"but got {type(data_loader)}" | |||
| ) | |||
| for epoch in range(self.num_epochs): | |||
| loss_value = self.train_epoch(data_loader) | |||
| if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: | |||
| self.save(epoch + 1) | |||
| if self.stop_loss is not None and loss_value < self.stop_loss: | |||
| break | |||
| if self.scheduler is not None: | |||
| self.scheduler.step() | |||
| print_log(f"model loss: {loss_value:.5f}", logger="current") | |||
| return self | |||
| def fit( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| y: Optional[List[int]] = None, | |||
| ) -> BasicNN: | |||
| """ | |||
| Train the model for ``self.num_epochs`` epochs or until the average loss on one epoch | |||
| is less than ``self.stop_loss``. It supports training with either a DataLoader | |||
| object (data_loader) or a pair of input data (X) and target labels (y). If both | |||
| data_loader and (X, y) are provided, the method will prioritize using the data_loader. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader, optional | |||
| The data loader used for training, by default None. | |||
| X : List[Any], optional | |||
| The input data, by default None. | |||
| y : List[int], optional | |||
| The target data, by default None. | |||
| Returns | |||
| ------- | |||
| BasicNN | |||
| The model itself after training. | |||
| """ | |||
| if data_loader is not None and X is not None: | |||
| print_log( | |||
| "data_loader will be used to train the model instead of X and y.", | |||
| logger="current", | |||
| level=logging.WARNING, | |||
| ) | |||
| if data_loader is None: | |||
| if X is None: | |||
| raise ValueError("data_loader and X can not be None simultaneously.") | |||
| else: | |||
| data_loader = self._data_loader(X, y) | |||
| return self._fit(data_loader) | |||
| def train_epoch(self, data_loader: DataLoader) -> float: | |||
| """ | |||
| Train the model with an instance of DataLoader (data_loader) for one epoch. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader | |||
| The data loader used for training. | |||
| Returns | |||
| ------- | |||
| float | |||
| The average loss on one epoch. | |||
| """ | |||
| model = self.model | |||
| loss_fn = self.loss_fn | |||
| optimizer = self.optimizer | |||
| device = self.device | |||
| model.train() | |||
| total_loss, total_num = 0.0, 0 | |||
| for data, target in data_loader: | |||
| data, target = data.to(device), target.to(device) | |||
| out = model(data) | |||
| loss = loss_fn(out, target) | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| optimizer.step() | |||
| total_loss += loss.item() * data.size(0) | |||
| total_num += data.size(0) | |||
| return total_loss / total_num | |||
| def _predict(self, data_loader: DataLoader) -> torch.Tensor: | |||
| """ | |||
| Internal method to predict the outputs given a DataLoader. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader | |||
| The DataLoader providing input samples. | |||
| Returns | |||
| ------- | |||
| torch.Tensor | |||
| Raw output from the model. | |||
| """ | |||
| if not isinstance(data_loader, DataLoader): | |||
| raise TypeError( | |||
| f"data_loader must be an instance of torch.utils.data.DataLoader, " | |||
| f"but got {type(data_loader)}" | |||
| ) | |||
| model = self.model | |||
| device = self.device | |||
| model.eval() | |||
| with torch.no_grad(): | |||
| results = [] | |||
| for data in data_loader: | |||
| data = data.to(device) | |||
| out = model(data) | |||
| results.append(out) | |||
| return torch.cat(results, axis=0) | |||
| def predict( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| ) -> numpy.ndarray: | |||
| """ | |||
| Predict the class of the input data. This method supports prediction with either | |||
| a DataLoader object (data_loader) or a list of input data (X). If both data_loader | |||
| and X are provided, the method will predict the input data in data_loader | |||
| instead of X. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader, optional | |||
| The data loader used for prediction, by default None. | |||
| X : List[Any], optional | |||
| The input data, by default None. | |||
| Returns | |||
| ------- | |||
| numpy.ndarray | |||
| The predicted class of the input data. | |||
| """ | |||
| if data_loader is not None and X is not None: | |||
| print_log( | |||
| "Predict the class of input data in data_loader instead of X.", | |||
| logger="current", | |||
| level=logging.WARNING, | |||
| ) | |||
| if data_loader is None: | |||
| dataset = PredictionDataset(X, self.test_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return self._predict(data_loader).argmax(axis=1).cpu().numpy() | |||
| def predict_proba( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| ) -> numpy.ndarray: | |||
| """ | |||
| Predict the probability of each class for the input data. This method supports | |||
| prediction with either a DataLoader object (data_loader) or a list of input data (X). | |||
| If both data_loader and X are provided, the method will predict the input data in | |||
| data_loader instead of X. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader, optional | |||
| The data loader used for prediction, by default None. | |||
| X : List[Any], optional | |||
| The input data, by default None. | |||
| Returns | |||
| ------- | |||
| numpy.ndarray | |||
| The predicted probability of each class for the input data. | |||
| """ | |||
| if data_loader is not None and X is not None: | |||
| print_log( | |||
| "Predict the class probability of input data in data_loader instead of X.", | |||
| logger="current", | |||
| level=logging.WARNING, | |||
| ) | |||
| if data_loader is None: | |||
| dataset = PredictionDataset(X, self.test_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return self._predict(data_loader).softmax(axis=1).cpu().numpy() | |||
| def _score(self, data_loader: DataLoader) -> Tuple[float, float]: | |||
| """ | |||
| Internal method to compute loss and accuracy for the data provided through a DataLoader. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader | |||
| Data loader to use for evaluation. | |||
| Returns | |||
| ------- | |||
| Tuple[float, float] | |||
| mean_loss: float, The mean loss of the model on the provided data. | |||
| accuracy: float, The accuracy of the model on the provided data. | |||
| """ | |||
| if not isinstance(data_loader, DataLoader): | |||
| raise TypeError( | |||
| f"data_loader must be an instance of torch.utils.data.DataLoader, " | |||
| f"but got {type(data_loader)}" | |||
| ) | |||
| model = self.model | |||
| loss_fn = self.loss_fn | |||
| device = self.device | |||
| model.eval() | |||
| total_correct_num, total_num, total_loss = 0, 0, 0.0 | |||
| with torch.no_grad(): | |||
| for data, target in data_loader: | |||
| data, target = data.to(device), target.to(device) | |||
| out = model(data) | |||
| if len(out.shape) > 1: | |||
| correct_num = (target == out.argmax(axis=1)).sum().item() | |||
| else: | |||
| correct_num = (target == (out > 0.5)).sum().item() | |||
| loss = loss_fn(out, target) | |||
| total_loss += loss.item() * data.size(0) | |||
| total_correct_num += correct_num | |||
| total_num += data.size(0) | |||
| mean_loss = total_loss / total_num | |||
| accuracy = total_correct_num / total_num | |||
| return mean_loss, accuracy | |||
| def score( | |||
| self, | |||
| data_loader: Optional[DataLoader] = None, | |||
| X: Optional[List[Any]] = None, | |||
| y: Optional[List[int]] = None, | |||
| ) -> float: | |||
| """ | |||
| Validate the model. It supports validation with either a DataLoader object (data_loader) | |||
| or a pair of input data (X) and ground truth labels (y). If both data_loader and | |||
| (X, y) are provided, the method will prioritize using the data_loader. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader, optional | |||
| The data loader used for scoring, by default None. | |||
| X : List[Any], optional | |||
| The input data, by default None. | |||
| y : List[int], optional | |||
| The target data, by default None. | |||
| Returns | |||
| ------- | |||
| float | |||
| The accuracy of the model. | |||
| """ | |||
| print_log("Start machine learning model validation", logger="current") | |||
| if data_loader is not None and X is not None: | |||
| print_log( | |||
| "data_loader will be used to validate the model instead of X and y.", | |||
| logger="current", | |||
| level=logging.WARNING, | |||
| ) | |||
| if data_loader is None: | |||
| if X is None or y is None: | |||
| raise ValueError("data_loader and (X, y) can not be None simultaneously.") | |||
| else: | |||
| data_loader = self._data_loader(X, y) | |||
| mean_loss, accuracy = self._score(data_loader) | |||
| print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") | |||
| return accuracy | |||
| def _data_loader( | |||
| self, | |||
| X: Optional[List[Any]], | |||
| y: Optional[List[int]] = None, | |||
| shuffle: bool = True, | |||
| ) -> DataLoader: | |||
| """ | |||
| Generate a DataLoader for user-provided input data and target labels. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| Input samples. | |||
| y : List[int], optional | |||
| Target labels. If None, dummy labels are created, by default None. | |||
| shuffle : bool, optional | |||
| Whether to shuffle the data, by default True. | |||
| Returns | |||
| ------- | |||
| DataLoader | |||
| A DataLoader providing batches of (X, y) pairs. | |||
| """ | |||
| if X is None: | |||
| raise ValueError("X should not be None.") | |||
| if y is None: | |||
| y = [0] * len(X) | |||
| if not (len(y) == len(X)): | |||
| raise ValueError("X and y should have equal length.") | |||
| dataset = ClassificationDataset(X, y, transform=self.train_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| shuffle=shuffle, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return data_loader | |||
| def save(self, epoch_id: int = 0, save_path: Optional[str] = None) -> None: | |||
| """ | |||
| Save the model and the optimizer. Users can either provide a save_path or specify | |||
| the epoch_id at which the model and optimizer are saved. If both save_path and | |||
| epoch_id are provided, save_path will be used. If only epoch_id is specified, the | |||
| model and optimizer will be saved to the path f"model_checkpoint_epoch_{epoch_id}.pth" | |||
| under ``self.save_dir``. ``self.save_dir`` and save_path can not be None simultaneously. | |||
| Parameters | |||
| ---------- | |||
| epoch_id : int | |||
| The epoch id. | |||
| save_path : str, optional | |||
| The path to save the model, by default None. | |||
| """ | |||
| if self.save_dir is None and save_path is None: | |||
| raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.") | |||
| if save_path is not None: | |||
| if not os.path.exists(os.path.dirname(save_path)): | |||
| os.makedirs(os.path.dirname(save_path)) | |||
| else: | |||
| save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth") | |||
| if not os.path.exists(self.save_dir): | |||
| os.makedirs(self.save_dir) | |||
| print_log(f"Checkpoints will be saved to {save_path}", logger="current") | |||
| save_parma_dic = { | |||
| "model": self.model.state_dict(), | |||
| "optimizer": self.optimizer.state_dict(), | |||
| } | |||
| torch.save(save_parma_dic, save_path) | |||
| def load(self, load_path: str) -> None: | |||
| """ | |||
| Load the model and the optimizer. | |||
| Parameters | |||
| ---------- | |||
| load_path : str | |||
| The path from which to load the model. | |||
| """ | |||
| if load_path is None: | |||
| raise ValueError("Load path should not be None.") | |||
| print_log( | |||
| f"Loads checkpoint by local backend from path: {load_path}", | |||
| logger="current", | |||
| ) | |||
| param_dic = torch.load(load_path) | |||
| self.model.load_state_dict(param_dic["model"]) | |||
| if "optimizer" in param_dic.keys(): | |||
| self.optimizer.load_state_dict(param_dic["optimizer"]) | |||
| @@ -0,0 +1,211 @@ | |||
| import copy | |||
| import torch | |||
| from typing import Any, Callable, List, Optional | |||
| from .abl_model import ABLModel | |||
| from .basic_nn import BasicNN | |||
| from lambdaLearn.Base.DeepModelMixin import DeepModelMixin | |||
| class ModelConverter: | |||
| """ | |||
| This class provides functionality to convert LambdaLearn models to ABL-Package models. | |||
| """ | |||
| def __init__(self) -> None: | |||
| pass | |||
| def convert_lambdalearn_to_ablmodel( | |||
| self, | |||
| lambdalearn_model, | |||
| loss_fn: torch.nn.Module, | |||
| optimizer_dict: dict, | |||
| scheduler_dict: Optional[dict] = None, | |||
| device: Optional[torch.device] = None, | |||
| batch_size: int = 32, | |||
| num_epochs: int = 1, | |||
| stop_loss: Optional[float] = 0.0001, | |||
| num_workers: int = 0, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| train_transform: Optional[Callable[..., Any]] = None, | |||
| test_transform: Optional[Callable[..., Any]] = None, | |||
| collate_fn: Optional[Callable[[List[Any]], Any]] = None, | |||
| ): | |||
| """ | |||
| Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of | |||
| DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn | |||
| model should implement ``fit`` and ``predict`` methods. | |||
| Parameters | |||
| ---------- | |||
| lambdalearn_model : Union[DeepModelMixin, Any] | |||
| The LambdaLearn model to be converted. | |||
| loss_fn : torch.nn.Module | |||
| The loss function used for training. | |||
| optimizer_dict : dict | |||
| A dict that contains the necessary parameters to construct an optimizer used for training. | |||
| The optimizer class is specified by the ``optimizer`` key. | |||
| scheduler_dict : dict, optional | |||
| A dict that contains the necessary parameters to construct a learning rate scheduler used | |||
| for training, which will be called at the end of each run of the ``fit`` method. | |||
| The scheduler class is specified by the ``scheduler`` key. It should implement the | |||
| ``step`` method, by default None. | |||
| device : torch.device, optional | |||
| The device on which the model will be trained or used for prediction, | |||
| by default torch.device("cpu"). | |||
| batch_size : int, optional | |||
| The batch size used for training, by default 32. | |||
| num_epochs : int, optional | |||
| The number of epochs used for training, by default 1. | |||
| stop_loss : float, optional | |||
| The loss value at which to stop training, by default 0.0001. | |||
| num_workers : int | |||
| The number of workers used for loading data, by default 0. | |||
| save_interval : int, optional | |||
| The model will be saved every ``save_interval`` epochs during training, by default None. | |||
| save_dir : str, optional | |||
| The directory in which to save the model during training, by default None. | |||
| train_transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version used | |||
| in the `fit` and `train_epoch` methods, by default None. | |||
| test_transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version in the | |||
| `predict`, `predict_proba` and `score` methods, by default None. | |||
| collate_fn : Callable[[List[Any]], Any], optional | |||
| The function used to collate data, by default None. | |||
| Returns | |||
| ------- | |||
| ABLModel | |||
| The converted ABLModel instance. | |||
| """ | |||
| if isinstance(lambdalearn_model, DeepModelMixin): | |||
| base_model = self.convert_lambdalearn_to_basicnn( | |||
| lambdalearn_model, | |||
| loss_fn, | |||
| optimizer_dict, | |||
| scheduler_dict, | |||
| device, | |||
| batch_size, | |||
| num_epochs, | |||
| stop_loss, | |||
| num_workers, | |||
| save_interval, | |||
| save_dir, | |||
| train_transform, | |||
| test_transform, | |||
| collate_fn, | |||
| ) | |||
| return ABLModel(base_model) | |||
| if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")): | |||
| raise NotImplementedError( | |||
| "The lambdalearn_model should be an instance of DeepModelMixin, or implement " | |||
| + "fit and predict methods." | |||
| ) | |||
| return ABLModel(lambdalearn_model) | |||
| def convert_lambdalearn_to_basicnn( | |||
| self, | |||
| lambdalearn_model: DeepModelMixin, | |||
| loss_fn: torch.nn.Module, | |||
| optimizer_dict: dict, | |||
| scheduler_dict: Optional[dict] = None, | |||
| device: Optional[torch.device] = None, | |||
| batch_size: int = 32, | |||
| num_epochs: int = 1, | |||
| stop_loss: Optional[float] = 0.0001, | |||
| num_workers: int = 0, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| train_transform: Optional[Callable[..., Any]] = None, | |||
| test_transform: Optional[Callable[..., Any]] = None, | |||
| collate_fn: Optional[Callable[[List[Any]], Any]] = None, | |||
| ): | |||
| """ | |||
| Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of | |||
| DeepModelMixin, its network will be used as the model of BasicNN. | |||
| Parameters | |||
| ---------- | |||
| lambdalearn_model : DeepModelMixin | |||
| The LambdaLearn model to be converted. | |||
| loss_fn : torch.nn.Module | |||
| The loss function used for training. | |||
| optimizer_dict : dict | |||
| A dict containing the necessary parameters to construct an optimizer used for training. | |||
| The optimizer class is specified by the ``optimizer`` key. | |||
| scheduler_dict : dict, optional | |||
| The dict contains necessary parameters to construct a learning rate scheduler used | |||
| for training, which will be called at the end of each run of the ``fit`` method. | |||
| The scheduler class is specified by the ``scheduler`` key. It should implement the | |||
| ``step`` method, by default None. | |||
| device : torch.device, optional | |||
| The device on which the model will be trained or used for prediction, | |||
| by default torch.device("cpu"). | |||
| batch_size : int, optional | |||
| The batch size used for training, by default 32. | |||
| num_epochs : int, optional | |||
| The number of epochs used for training, by default 1. | |||
| stop_loss : float, optional | |||
| The loss value at which to stop training, by default 0.0001. | |||
| num_workers : int | |||
| The number of workers used for loading data, by default 0. | |||
| save_interval : int, optional | |||
| The model will be saved every ``save_interval`` epochs during training, by default None. | |||
| save_dir : str, optional | |||
| The directory in which to save the model during training, by default None. | |||
| train_transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version used | |||
| in the `fit` and `train_epoch` methods, by default None. | |||
| test_transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version in the | |||
| `predict`, `predict_proba` and `score` methods, by default None. | |||
| collate_fn : Callable[[List[Any]], Any], optional | |||
| The function used to collate data, by default None. | |||
| Returns | |||
| ------- | |||
| BasicNN | |||
| The converted BasicNN instance. | |||
| """ | |||
| if isinstance(lambdalearn_model, DeepModelMixin): | |||
| if not isinstance(lambdalearn_model.network, torch.nn.Module): | |||
| raise NotImplementedError( | |||
| "Expected lambdalearn_model.network to be a torch.nn.Module, " | |||
| + f"but got {type(lambdalearn_model.network)}" | |||
| ) | |||
| # Only use the network part and device of the lambdalearn model | |||
| network = copy.deepcopy(lambdalearn_model.network) | |||
| optimizer_class = optimizer_dict["optimizer"] | |||
| optimizer_dict.pop("optimizer") | |||
| optimizer = optimizer_class(network.parameters(), **optimizer_dict) | |||
| if scheduler_dict is not None: | |||
| scheduler_class = scheduler_dict["scheduler"] | |||
| scheduler_dict.pop("scheduler") | |||
| scheduler = scheduler_class(optimizer, **scheduler_dict) | |||
| else: | |||
| scheduler = None | |||
| device = lambdalearn_model.device if device is None else device | |||
| base_model = BasicNN( | |||
| model=network, | |||
| loss_fn=loss_fn, | |||
| optimizer=optimizer, | |||
| scheduler=scheduler, | |||
| device=device, | |||
| batch_size=batch_size, | |||
| num_epochs=num_epochs, | |||
| stop_loss=stop_loss, | |||
| num_workers=num_workers, | |||
| save_interval=save_interval, | |||
| save_dir=save_dir, | |||
| train_transform=train_transform, | |||
| test_transform=test_transform, | |||
| collate_fn=collate_fn, | |||
| ) | |||
| return base_model | |||
| else: | |||
| raise NotImplementedError( | |||
| "The lambdalearn_model should be an instance of DeepModelMixin." | |||
| ) | |||
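| # --- Usage sketch (illustrative) --- | |||
| # The optimizer/scheduler dicts carry the class object under the | |||
| # "optimizer"/"scheduler" key and every other entry as a constructor keyword | |||
| # argument; ``lambdalearn_model`` is assumed to be a DeepModelMixin instance | |||
| # constructed elsewhere. | |||
| # | |||
| #     import torch | |||
| # | |||
| #     converter = ModelConverter() | |||
| #     model = converter.convert_lambdalearn_to_ablmodel( | |||
| #         lambdalearn_model, | |||
| #         loss_fn=torch.nn.CrossEntropyLoss(), | |||
| #         optimizer_dict=dict(optimizer=torch.optim.RMSprop, lr=1e-3, alpha=0.9), | |||
| #         scheduler_dict=dict(scheduler=torch.optim.lr_scheduler.StepLR, step_size=10), | |||
| #     ) | |||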
| @@ -0,0 +1,9 @@ | |||
| from .classification_dataset import ClassificationDataset | |||
| from .prediction_dataset import PredictionDataset | |||
| from .regression_dataset import RegressionDataset | |||
| __all__ = [ | |||
| "ClassificationDataset", | |||
| "PredictionDataset", | |||
| "RegressionDataset", | |||
| ] | |||
| @@ -0,0 +1,66 @@ | |||
| from typing import Any, Callable, List, Tuple, Optional | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| class ClassificationDataset(Dataset): | |||
| """ | |||
| Dataset used for classification task. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| The input data. | |||
| Y : List[int] | |||
| The target data. | |||
| transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version. | |||
| Defaults to None. | |||
| """ | |||
| def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None): | |||
| if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
| raise ValueError("X and Y should be of type list.") | |||
| if len(X) != len(Y): | |||
| raise ValueError("Length of X and Y must be equal.") | |||
| self.X = X | |||
| self.Y = torch.LongTensor(Y) | |||
| self.transform = transform | |||
| def __len__(self) -> int: | |||
| """ | |||
| Return the length of the dataset. | |||
| Returns | |||
| ------- | |||
| int | |||
| The length of the dataset. | |||
| """ | |||
| return len(self.X) | |||
| def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: | |||
| """ | |||
| Get the item at the given index. | |||
| Parameters | |||
| ---------- | |||
| index : int | |||
| The index of the item to get. | |||
| Returns | |||
| ------- | |||
| Tuple[Any, torch.Tensor] | |||
| A tuple containing the object and its label. | |||
| """ | |||
| if index >= len(self): | |||
| raise ValueError("index range error") | |||
| x = self.X[index] | |||
| if self.transform is not None: | |||
| x = self.transform(x) | |||
| y = self.Y[index] | |||
| return x, y | |||
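| # --- Usage sketch (illustrative) --- | |||
| # Wrapping raw Python lists for use with a DataLoader; the tensors below are | |||
| # dummy data standing in for real inputs. | |||
| # | |||
| #     from torch.utils.data import DataLoader | |||
| # | |||
| #     X = [torch.randn(1, 28, 28) for _ in range(4)] | |||
| #     Y = [0, 1, 0, 1] | |||
| #     dataset = ClassificationDataset(X, Y, transform=lambda t: t / 255.0) | |||
| #     loader = DataLoader(dataset, batch_size=2) | |||
| #     for batch_x, batch_y in loader: | |||
| #         print(batch_x.shape, batch_y)  # torch.Size([2, 1, 28, 28]), LongTensor labels | |||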
| @@ -0,0 +1,58 @@ | |||
| from typing import Any, Callable, List, Optional | |||
| from torch.utils.data import Dataset | |||
| class PredictionDataset(Dataset): | |||
| """ | |||
| Dataset used for prediction. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| The input data. | |||
| transform : Callable[..., Any], optional | |||
| A function/transform that takes an object and returns a transformed version. | |||
| Defaults to None. | |||
| """ | |||
| def __init__(self, X: List[Any], transform: Optional[Callable[..., Any]] = None): | |||
| if not isinstance(X, list): | |||
| raise ValueError("X should be of type list.") | |||
| self.X = X | |||
| self.transform = transform | |||
| def __len__(self) -> int: | |||
| """ | |||
| Return the length of the dataset. | |||
| Returns | |||
| ------- | |||
| int | |||
| The length of the dataset. | |||
| """ | |||
| return len(self.X) | |||
| def __getitem__(self, index: int) -> Any: | |||
| """ | |||
| Get the item at the given index. | |||
| Parameters | |||
| ---------- | |||
| index : int | |||
| The index of the item to get. | |||
| Returns | |||
| ------- | |||
| Any | |||
| The (possibly transformed) object at the given index. | |||
| """ | |||
| if index >= len(self): | |||
| raise ValueError("index range error") | |||
| x = self.X[index] | |||
| if self.transform is not None: | |||
| x = self.transform(x) | |||
| return x | |||
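| # --- Usage sketch (illustrative) --- | |||
| # Unlike ClassificationDataset, items are returned without labels, so batches | |||
| # drawn from a DataLoader contain inputs only. | |||
| # | |||
| #     dataset = PredictionDataset([torch.randn(1, 28, 28) for _ in range(4)]) | |||
| #     x0 = dataset[0]  # a single (possibly transformed) input | |||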
| @@ -0,0 +1,56 @@ | |||
| from typing import Any, List, Tuple | |||
| from torch.utils.data import Dataset | |||
| class RegressionDataset(Dataset): | |||
| """ | |||
| Dataset used for regression task. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| A list of objects representing the input data. | |||
| Y : List[Any] | |||
| A list of objects representing the output data. | |||
| """ | |||
| def __init__(self, X: List[Any], Y: List[Any]): | |||
| if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
| raise ValueError("X and Y should be of type list.") | |||
| if len(X) != len(Y): | |||
| raise ValueError("Length of X and Y must be equal.") | |||
| self.X = X | |||
| self.Y = Y | |||
| def __len__(self): | |||
| """Return the length of the dataset. | |||
| Returns | |||
| ------- | |||
| int | |||
| The length of the dataset. | |||
| """ | |||
| return len(self.X) | |||
| def __getitem__(self, index: int) -> Tuple[Any, Any]: | |||
| """Get an item from the dataset. | |||
| Parameters | |||
| ---------- | |||
| index : int | |||
| The index of the item to retrieve. | |||
| Returns | |||
| ------- | |||
| Tuple[Any, Any] | |||
| A tuple containing the input and output data at the specified index. | |||
| """ | |||
| if index >= len(self): | |||
| raise ValueError("index range error") | |||
| x = self.X[index] | |||
| y = self.Y[index] | |||
| return x, y | |||
| @@ -0,0 +1,4 @@ | |||
| from .kb import GroundKB, KBBase, PrologKB | |||
| from .reasoner import Reasoner | |||
| __all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"] | |||
| @@ -0,0 +1,622 @@ | |||
| import bisect | |||
| import inspect | |||
| import logging | |||
| import os | |||
| from abc import ABC, abstractmethod | |||
| from collections import defaultdict | |||
| from itertools import combinations, product | |||
| from multiprocessing import Pool | |||
| from typing import Any, Callable, List, Optional, Tuple | |||
| import numpy as np | |||
| from ..utils.cache import abl_cache | |||
| from ..utils.logger import print_log | |||
| from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable | |||
| class KBBase(ABC): | |||
| """ | |||
| Base class for knowledge base. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label_list : List[Any] | |||
| List of possible pseudo-labels. It's recommended to arrange the pseudo-labels in this | |||
| list so that each aligns with its corresponding index in the base model: the first with | |||
| the 0th index, the second with the 1st, and so forth. | |||
| max_err : float, optional | |||
| The upper tolerance limit when comparing the similarity between the reasoning result of | |||
| pseudo-labels and the ground truth. This is only applicable when the reasoning | |||
| result is of a numerical type. This is particularly relevant for regression problems where | |||
| exact matches might not be feasible. Defaults to 1e-10. | |||
| use_cache : bool, optional | |||
| Whether to use abl_cache for previously abduced candidates to speed up subsequent | |||
| operations. Defaults to True. | |||
| key_func : Callable, optional | |||
| A function employed for hashing in abl_cache. This is only operational when use_cache | |||
| is set to True. Defaults to ``to_hashable``. | |||
| cache_size: int, optional | |||
| The cache size in abl_cache. This is only operational when use_cache is set to | |||
| True. Defaults to 4096. | |||
| Notes | |||
| ----- | |||
| Users should derive from this base class to build their own knowledge base. For the | |||
| user-built KB (a derived subclass), the user is only required to provide the | |||
| ``pseudo_label_list`` and override the ``logic_forward`` function (specifying how to | |||
| perform logical reasoning). After that, other operations (e.g. how to perform abductive | |||
| reasoning) will be automatically set up. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list: List[Any], | |||
| max_err: float = 1e-10, | |||
| use_cache: bool = True, | |||
| key_func: Callable = to_hashable, | |||
| cache_size: int = 4096, | |||
| ): | |||
| if not isinstance(pseudo_label_list, list): | |||
| raise TypeError(f"pseudo_label_list should be list, got {type(pseudo_label_list)}") | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.max_err = max_err | |||
| self.use_cache = use_cache | |||
| self.key_func = key_func | |||
| self.cache_size = cache_size | |||
| argspec = inspect.getfullargspec(self.logic_forward) | |||
| self._num_args = len(argspec.args) - 1 | |||
| if ( | |||
| self._num_args == 2 and self.use_cache | |||
| ): # If the logic_forward function has 2 arguments, then disable cache | |||
| self.use_cache = False | |||
| print_log( | |||
| "The logic_forward function has 2 arguments, so the cache is disabled. ", | |||
| logger="current", | |||
| level=logging.WARNING, | |||
| ) | |||
| # TODO: add support for semi-supervised learning | |||
| # TODO: add consistency measure with max_err error tolerance | |||
| @abstractmethod | |||
| def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: | |||
| """ | |||
| How to perform (deductive) logical reasoning, i.e. matching pseudo-labels to | |||
| their reasoning result. Users are required to provide this. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example. | |||
| x : List[Any], optional | |||
| The example. If deductive logical reasoning does not require any | |||
| information from the example, the overridden function provided by the user can omit | |||
| this parameter. | |||
| Returns | |||
| ------- | |||
| Any | |||
| The reasoning result. | |||
| """ | |||
| def abduce_candidates( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| require_more_revision: int, | |||
| ) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| Perform abductive reasoning to get a candidate compatible with the knowledge base. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example (to be revised by abductive reasoning). | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example. If the information from the example | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| max_revision_num : int | |||
| The upper limit on the number of revised labels for each example. | |||
| require_more_revision : int | |||
| Specifies the additional number of revisions permitted beyond the minimum required. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision) | |||
| def _check_equal(self, reasoning_result: Any, y: Any) -> bool: | |||
| """ | |||
| Check whether the reasoning result of a pseudo-label example is equal to the ground truth | |||
| (or, within the maximum error allowed for numerical results). | |||
| Returns | |||
| ------- | |||
| bool | |||
| The result of the check. | |||
| """ | |||
| if reasoning_result is None: | |||
| return False | |||
| if isinstance(reasoning_result, (int, float)) and isinstance(y, (int, float)): | |||
| return abs(reasoning_result - y) <= self.max_err | |||
| else: | |||
| return reasoning_result == y | |||
| def revise_at_idx( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| Revise the pseudo-labels at specified index positions. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example (to be revised). | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example. If the information from the example | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| revision_idx : List[int] | |||
| A list specifying indices of where revisions should be made to the pseudo-labels. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| candidates, reasoning_results = [], [] | |||
| abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||
| for c in abduce_c: | |||
| candidate = pseudo_label.copy() | |||
| for i, idx in enumerate(revision_idx): | |||
| candidate[idx] = c[i] | |||
| reasoning_result = self.logic_forward(candidate, *(x,) if self._num_args == 2 else ()) | |||
| if self._check_equal(reasoning_result, y): | |||
| candidates.append(candidate) | |||
| reasoning_results.append(reasoning_result) | |||
| return candidates, reasoning_results | |||
| def _revision( | |||
| self, | |||
| revision_num: int, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| ) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| For a specified number of labels in a pseudo-labels to revise, iterate through | |||
| all possible indices to find any candidates that are compatible with the knowledge base. | |||
| """ | |||
| new_candidates, new_reasoning_results = [], [] | |||
| revision_idx_list = combinations(range(len(pseudo_label)), revision_num) | |||
| for revision_idx in revision_idx_list: | |||
| candidates, reasoning_results = self.revise_at_idx(pseudo_label, y, x, revision_idx) | |||
| new_candidates.extend(candidates) | |||
| new_reasoning_results.extend(reasoning_results) | |||
| return new_candidates, new_reasoning_results | |||
| @abl_cache() | |||
| def _abduce_by_search( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| require_more_revision: int, | |||
| ) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and | |||
| continuously increase the number of labels to revise, until | |||
| candidates that are compatible with the knowledge base are found. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example (to be revised). | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example. If the information from the example | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| max_revision_num : int | |||
| The upper limit on the number of revisions. | |||
| require_more_revision : int | |||
| If larger than 0, then after having found any candidates compatible with the | |||
| knowledge base, continue to increase the number of labels to | |||
| revise to get more possible compatible candidates. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| candidates, reasoning_results = [], [] | |||
| for revision_num in range(len(pseudo_label) + 1): | |||
| new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x) | |||
| candidates.extend(new_candidates) | |||
| reasoning_results.extend(new_reasoning_results) | |||
| if len(candidates) > 0: | |||
| min_revision_num = revision_num | |||
| break | |||
| if revision_num >= max_revision_num: | |||
| return [], [] | |||
| for revision_num in range( | |||
| min_revision_num + 1, min_revision_num + require_more_revision + 1 | |||
| ): | |||
| if revision_num > max_revision_num: | |||
| return candidates, reasoning_results | |||
| new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x) | |||
| candidates.extend(new_candidates) | |||
| reasoning_results.extend(new_reasoning_results) | |||
| return candidates, reasoning_results | |||
| def __repr__(self): | |||
| return ( | |||
| f"{self.__class__.__name__} is a KB with " | |||
| f"pseudo_label_list={self.pseudo_label_list!r}, " | |||
| f"max_err={self.max_err!r}, " | |||
| f"use_cache={self.use_cache!r}." | |||
| ) | |||
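| # --- Example (illustrative) --- | |||
| # A minimal user-built KB following the Notes above: only ``pseudo_label_list`` | |||
| # and ``logic_forward`` are provided. ``AddKB`` is a hypothetical name, not part | |||
| # of the package. | |||
| # | |||
| #     class AddKB(KBBase): | |||
| #         def __init__(self): | |||
| #             super().__init__(pseudo_label_list=list(range(10))) | |||
| # | |||
| #         def logic_forward(self, pseudo_label): | |||
| #             return sum(pseudo_label) | |||
| # | |||
| #     kb = AddKB() | |||
| #     # Revise [1, 1] so that it reasons to 8; [7, 1] and [1, 7] are candidates. | |||
| #     candidates, results = kb.abduce_candidates( | |||
| #         [1, 1], y=8, x=None, max_revision_num=2, require_more_revision=0 | |||
| #     ) | |||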
| class GroundKB(KBBase): | |||
| """ | |||
| Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon | |||
| class initialization, storing all potential candidates along with their respective | |||
| reasoning result. Ground KB can accelerate abductive reasoning in ``abduce_candidates``. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label_list : List[Any] | |||
| Refer to class ``KBBase``. | |||
| GKB_len_list : List[int] | |||
| List of possible lengths for pseudo-labels of an example. | |||
| max_err : float, optional | |||
| Refer to class ``KBBase``. | |||
| Notes | |||
| ----- | |||
| Users can also inherit from this class to build their own knowledge base. Similar | |||
| to ``KBBase``, users are only required to provide the ``pseudo_label_list`` and override | |||
| the ``logic_forward`` function. Additionally, users should provide the ``GKB_len_list``. | |||
| After that, other operations (e.g. auto-construction of GKB, and how to perform | |||
| abductive reasoning) will be automatically set up. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list: List[Any], | |||
| GKB_len_list: List[int], | |||
| max_err: float = 1e-10, | |||
| ): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| if not isinstance(GKB_len_list, list): | |||
| raise TypeError("GKB_len_list should be list, but got {type(GKB_len_list)}") | |||
| if self._num_args == 2: | |||
| raise NotImplementedError( | |||
| "GroundKB only supports 1-argument logic_forward, but got " | |||
| + f"{self._num_args}-argument logic_forward" | |||
| ) | |||
| self.GKB_len_list = GKB_len_list | |||
| self.GKB = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| def _get_XY_list(self, args): | |||
| pre_x, post_x_it = args[0], args[1] | |||
| XY_list = [] | |||
| for post_x in post_x_it: | |||
| x = (pre_x,) + post_x | |||
| y = self.logic_forward(x) | |||
| if y is not None: | |||
| XY_list.append((x, y)) | |||
| return XY_list | |||
| def _get_GKB(self): | |||
| """ | |||
| Prebuild the GKB according to ``pseudo_label_list`` and ``GKB_len_list``. | |||
| """ | |||
| X, Y = [], [] | |||
| for length in self.GKB_len_list: | |||
| arg_list = [] | |||
| for pre_x in self.pseudo_label_list: | |||
| post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
| arg_list.append((pre_x, post_x_it)) | |||
| with Pool(processes=len(arg_list)) as pool: | |||
| ret_list = pool.map(self._get_XY_list, arg_list) | |||
| for XY_list in ret_list: | |||
| if len(XY_list) == 0: | |||
| continue | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| if Y and isinstance(Y[0], (int, float)): | |||
| X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||
| return X, Y | |||
| def abduce_candidates( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| max_revision_num: int, | |||
| require_more_revision: int, | |||
| ) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| Perform abductive reasoning by directly retrieving compatible candidates from | |||
| the prebuilt GKB. In this way, the time-consuming exhaustive search can be | |||
| avoided. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example (to be revised by abductive reasoning). | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The example (unused in GroundKB). | |||
| max_revision_num : int | |||
| The upper limit on the number of revised labels for each example. | |||
| require_more_revision : int | |||
| Specifies the additional number of revisions permitted beyond the minimum required. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list: | |||
| return [], [] | |||
| all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y) | |||
| if len(all_candidates) == 0: | |||
| return [], [] | |||
| cost_list = hamming_dist(pseudo_label, all_candidates) | |||
| min_revision_num = np.min(cost_list) | |||
| revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||
| idxs = np.where(cost_list <= revision_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| reasoning_results = [all_reasoning_results[idx] for idx in idxs] | |||
| return candidates, reasoning_results | |||
| def _find_candidate_GKB(self, pseudo_label: List[Any], y: Any) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results, | |||
| return all candidates and their corresponding reasoning results which fall within the | |||
| [y - max_err, y + max_err] range. | |||
| """ | |||
| if isinstance(y, (int, float)): | |||
| potential_candidates = self.GKB[len(pseudo_label)] | |||
| key_list = list(potential_candidates.keys()) | |||
| low_key = bisect.bisect_left(key_list, y - self.max_err) | |||
| high_key = bisect.bisect_right(key_list, y + self.max_err) | |||
| all_candidates, all_reasoning_results = [], [] | |||
| for key in key_list[low_key:high_key]: | |||
| for candidate in potential_candidates[key]: | |||
| all_candidates.append(candidate) | |||
| all_reasoning_results.append(key) | |||
| else: | |||
| all_candidates = self.GKB[len(pseudo_label)][y] | |||
| all_reasoning_results = [y] * len(all_candidates) | |||
| return all_candidates, all_reasoning_results | |||
| def __repr__(self): | |||
| GKB_info_parts = [] | |||
| for i in self.GKB_len_list: | |||
| num_candidates = len(self.GKB[i]) if i in self.GKB else 0 | |||
| GKB_info_parts.append(f"{num_candidates} candidates of length {i}") | |||
| GKB_info = ", ".join(GKB_info_parts) | |||
| return ( | |||
| f"{self.__class__.__name__} is a KB with " | |||
| f"pseudo_label_list={self.pseudo_label_list!r}, " | |||
| f"max_err={self.max_err!r}, " | |||
| f"use_cache={self.use_cache!r}. " | |||
| f"It has a prebuilt GKB with " | |||
| f"GKB_len_list={self.GKB_len_list!r}, " | |||
| f"and there are " | |||
| f"{GKB_info}" | |||
| f" in the GKB." | |||
| ) | |||
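| # --- Example (illustrative) --- | |||
| # The same addition KB with a prebuilt ground KB: ``GKB_len_list`` enumerates | |||
| # the pseudo-label lengths to precompute, so abduction becomes a lookup rather | |||
| # than a search. ``AddGroundKB`` is a hypothetical name; constructing it | |||
| # triggers the (possibly slow) GKB build in ``_get_GKB``. | |||
| # | |||
| #     class AddGroundKB(GroundKB): | |||
| #         def __init__(self): | |||
| #             super().__init__(pseudo_label_list=list(range(10)), GKB_len_list=[2]) | |||
| # | |||
| #         def logic_forward(self, pseudo_label): | |||
| #             return sum(pseudo_label) | |||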
| class PrologKB(KBBase): | |||
| """ | |||
| Knowledge base provided by a Prolog (.pl) file. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label_list : List[Any] | |||
| Refer to class ``KBBase``. | |||
| pl_file : str | |||
| Prolog file containing the KB. | |||
| Notes | |||
| ----- | |||
| Users can instantiate this class to build their own knowledge base. During the | |||
| instantiation, users are only required to provide the ``pseudo_label_list`` and ``pl_file``. | |||
| To use the default logic forward and abductive reasoning methods in this class, in the | |||
| Prolog (.pl) file, there needs to be a rule which is strictly formatted as | |||
| ``logic_forward(Pseudo_labels, Res).``, e.g., ``logic_forward([A,B], C) :- C is A+B``. | |||
| For specifics, refer to the ``logic_forward`` and ``get_query_string`` functions in this | |||
| class. Users are also welcome to override related functions for more flexible support. | |||
| """ | |||
| def __init__(self, pseudo_label_list: List[Any], pl_file: str): | |||
| super().__init__(pseudo_label_list) | |||
| try: | |||
| import pyswip | |||
| except (IndexError, ImportError) as e: | |||
| raise ImportError( | |||
| "A Prolog-based knowledge base is in use. Please install SWI-Prolog using the " | |||
| + "command 'sudo apt-get install swi-prolog' for Linux users, or download it " | |||
| + "following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md " | |||
| + "for Windows and Mac users." | |||
| ) from e | |||
| self.prolog = pyswip.Prolog() | |||
| self.pl_file = pl_file | |||
| if not os.path.exists(self.pl_file): | |||
| raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") | |||
| self.prolog.consult(self.pl_file) | |||
| def logic_forward(self, pseudo_label: List[Any]) -> Any: | |||
| """ | |||
| Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the | |||
| returned ``Res`` as the reasoning results. To use this default function, there must be | |||
| a ``logic_forward`` method in the pl file to perform reasoning. | |||
| Otherwise, users would override this function. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example. | |||
| """ | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"] | |||
| if result == "true": | |||
| return True | |||
| elif result == "false": | |||
| return False | |||
| return result | |||
| def _revision_pseudo_label( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> List[Any]: | |||
| import re | |||
| revision_pseudo_label = pseudo_label.copy() | |||
| revision_pseudo_label = flatten(revision_pseudo_label) | |||
| for idx in revision_idx: | |||
| revision_pseudo_label[idx] = "P" + str(idx) | |||
| revision_pseudo_label = reform_list(revision_pseudo_label, pseudo_label) | |||
| regex = r"'P\d+'" | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label)) | |||
| def get_query_string( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> str: | |||
| """ | |||
| Get the query to be used for consulting Prolog. | |||
| This is a default function for demo, users would override this function to adapt to | |||
| their own Prolog file. In this demo function, return query | |||
| ``logic_forward([kept_labels, Revise_labels], Res).``. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example (to be revised by abductive reasoning). | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| revision_idx : List[int] | |||
| A list specifying indices of where revisions should be made to the pseudo-labels. | |||
| Returns | |||
| ------- | |||
| str | |||
| A string of the query. | |||
| """ | |||
| query_string = "logic_forward(" | |||
| query_string += self._revision_pseudo_label(pseudo_label, revision_idx) | |||
| key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None) | |||
| query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
| return query_string | |||
| def revise_at_idx( | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> Tuple[List[List[Any]], List[Any]]: | |||
| """ | |||
| Revise the pseudo-labels at specified index positions by querying Prolog. | |||
| Parameters | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example (to be revised). | |||
| y : Any | |||
| Ground truth of the reasoning result for the example. | |||
| x : List[Any] | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| revision_idx : List[int] | |||
| A list specifying indices of where revisions should be made to the pseudo-labels. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[Any]] | |||
| A tuple of two elements. The first element is a list of candidate revisions, i.e. revised | |||
| pseudo-labels of the example that are compatible with the knowledge base. The second | |||
| element is a list of reasoning results corresponding to each candidate, i.e., the | |||
| outcome of the ``logic_forward`` function. | |||
| """ | |||
| candidates, reasoning_results = [], [] | |||
| query_string = self.get_query_string(pseudo_label, y, x, revision_idx) | |||
| save_pseudo_label = pseudo_label | |||
| pseudo_label = flatten(pseudo_label) | |||
| abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] | |||
| for c in abduce_c: | |||
| candidate = pseudo_label.copy() | |||
| for i, idx in enumerate(revision_idx): | |||
| candidate[idx] = c[i] | |||
| candidate = reform_list(candidate, save_pseudo_label) | |||
| candidates.append(candidate) | |||
| reasoning_results.append(y) | |||
| return candidates, reasoning_results | |||
| def __repr__(self): | |||
| return ( | |||
| f"{self.__class__.__name__} is a KB with " | |||
| f"pseudo_label_list={self.pseudo_label_list!r}, " | |||
| f"defined by " | |||
| f"Prolog file {self.pl_file!r}." | |||
| ) | |||
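| # --- Example (illustrative) --- | |||
| # Backing the KB with a Prolog file. Assuming a hypothetical ``add.pl`` that | |||
| # contains the rule required by the default ``logic_forward``: | |||
| # | |||
| #     logic_forward([A, B], Res) :- Res is A + B. | |||
| # | |||
| # the KB can then be instantiated and queried as: | |||
| # | |||
| #     kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
| #     assert kb.logic_forward([3, 4]) == 7 | |||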
| @@ -0,0 +1,351 @@ | |||
| import inspect | |||
| from typing import Any, Callable, List, Optional, Union | |||
| import numpy as np | |||
| from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||
| from ..data.structures import ListData | |||
| from ..reasoning import KBBase | |||
| from ..utils.utils import confidence_dist, hamming_dist | |||
| class Reasoner: | |||
| """ | |||
| Reasoner for minimizing the inconsistency between the knowledge base and learning models. | |||
| Parameters | |||
| ---------- | |||
| kb : class KBBase | |||
| The knowledge base to be used for reasoning. | |||
| dist_func : Union[str, Callable], optional | |||
| The distance function used to determine the cost list between each | |||
| candidate and the given prediction. The cost is also referred to as a consistency | |||
| measure, wherein the candidate with lowest cost is selected as the final | |||
| abduced label. It can be either a string representing a predefined distance | |||
| function or a callable function. The available predefined distance functions: | |||
| 'hamming' | 'confidence'. 'hamming': directly calculates the Hamming | |||
| distance between the predicted pseudo-label in the data example and each | |||
| candidate, 'confidence': calculates the distance between the prediction | |||
| and each candidate based on confidence derived from the predicted probability | |||
| in the data example. The callable function should have the signature | |||
| dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must | |||
| return a cost list. Each element in this cost list should be a numerical value | |||
| representing the cost for each candidate, and the list should have the same length | |||
| as candidates. Defaults to 'confidence'. | |||
| idx_to_label : dict, optional | |||
| A mapping from index in the base model to label. If not provided, a default | |||
| order-based index to label mapping is created. Defaults to None. | |||
| max_revision : Union[int, float], optional | |||
| The upper limit on the number of revisions for each data example when | |||
| performing abductive reasoning. If float, denotes the fraction of the total | |||
| length that can be revised. A value of -1 implies no restriction on the | |||
| number of revisions. Defaults to -1. | |||
| require_more_revision : int, optional | |||
| Specifies the additional number of revisions permitted beyond the minimum required | |||
| when performing abductive reasoning. Defaults to 0. | |||
| use_zoopt : bool, optional | |||
| Whether to use ZOOpt library during abductive reasoning. Defaults to False. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| kb: KBBase, | |||
| dist_func: Union[str, Callable] = "confidence", | |||
| idx_to_label: Optional[dict] = None, | |||
| max_revision: Union[int, float] = -1, | |||
| require_more_revision: int = 0, | |||
| use_zoopt: bool = False, | |||
| ): | |||
| self.kb = kb | |||
| self._check_valid_dist(dist_func) | |||
| self.dist_func = dist_func | |||
| self.use_zoopt = use_zoopt | |||
| self.max_revision = max_revision | |||
| self.require_more_revision = require_more_revision | |||
| if idx_to_label is None: | |||
| self.idx_to_label = { | |||
| index: label for index, label in enumerate(self.kb.pseudo_label_list) | |||
| } | |||
| else: | |||
| self._check_valid_idx_to_label(idx_to_label) | |||
| self.idx_to_label = idx_to_label | |||
| self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys())) | |||
| def _check_valid_dist(self, dist_func): | |||
| if isinstance(dist_func, str): | |||
| if dist_func not in ["hamming", "confidence"]: | |||
| raise NotImplementedError( | |||
| 'Valid options for predefined dist_func include "hamming" ' | |||
| + f'and "confidence", but got {dist_func}.' | |||
| ) | |||
| return | |||
| elif callable(dist_func): | |||
| params = inspect.signature(dist_func).parameters.values() | |||
| if len(params) != 4: | |||
| raise ValueError( | |||
| "User-defined dist_func must have exactly four parameters, " | |||
| + f"but got {len(params)}." | |||
| ) | |||
| return | |||
| else: | |||
| raise TypeError( | |||
| f"dist_func must be a string or a callable function, but got {type(dist_func)}." | |||
| ) | |||
| def _check_valid_idx_to_label(self, idx_to_label): | |||
| if not isinstance(idx_to_label, dict): | |||
| raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.") | |||
| for key, value in idx_to_label.items(): | |||
| if not isinstance(key, int): | |||
| raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.") | |||
| if value not in self.kb.pseudo_label_list: | |||
| raise ValueError( | |||
| "All values in the idx_to_label must be in the pseudo_label_list, " | |||
| + f"but got {value}." | |||
| ) | |||
| def _get_one_candidate( | |||
| self, | |||
| data_example: ListData, | |||
| candidates: List[List[Any]], | |||
| reasoning_results: List[Any], | |||
| ) -> List[Any]: | |||
| """ | |||
| Due to the nondeterminism of abductive reasoning, there could be multiple candidates | |||
| satisfying the knowledge base. When this happens, return one candidate that has the | |||
| minimum cost. If no candidates are provided, an empty list is returned. | |||
| Parameters | |||
| ---------- | |||
| data_example : ListData | |||
| Data example. | |||
| candidates : List[List[Any]] | |||
| Multiple possible candidates. | |||
| reasoning_results : List[Any] | |||
| Corresponding reasoning results of the candidates. | |||
| Returns | |||
| ------- | |||
| List[Any] | |||
| A selected candidate. | |||
| """ | |||
| if len(candidates) == 0: | |||
| return [] | |||
| elif len(candidates) == 1: | |||
| return candidates[0] | |||
| else: | |||
| cost_array = self._get_cost_list(data_example, candidates, reasoning_results) | |||
| candidate = candidates[np.argmin(cost_array)] | |||
| return candidate | |||
| def _get_cost_list( | |||
| self, | |||
| data_example: ListData, | |||
| candidates: List[List[Any]], | |||
| reasoning_results: List[Any], | |||
| ) -> Union[List[Union[int, float]], np.ndarray]: | |||
| """ | |||
| Get the list of costs between each candidate and the given data example. | |||
| Parameters | |||
| ---------- | |||
| data_example : ListData | |||
| Data example. | |||
| candidates : List[List[Any]] | |||
| Multiple possible candidates. | |||
| reasoning_results : List[Any] | |||
| Corresponding reasoning results of the candidates. | |||
| Returns | |||
| ------- | |||
| Union[List[Union[int, float]], np.ndarray] | |||
| The list of costs. | |||
| """ | |||
| if self.dist_func == "hamming": | |||
| return hamming_dist(data_example.pred_pseudo_label, candidates) | |||
| elif self.dist_func == "confidence": | |||
| candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||
| return confidence_dist(data_example.pred_prob, candidates_idxs) | |||
| else: | |||
| candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||
| cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) | |||
| if len(cost_list) != len(candidates): | |||
| raise ValueError( | |||
| "The length of the array returned by dist_func must be equal to the number " | |||
| + f"of candidates. Expected length {len(candidates)}, but got {len(cost_list)}." | |||
| ) | |||
| return cost_list | |||
| def _zoopt_get_solution( | |||
| self, | |||
| symbol_num: int, | |||
| data_example: ListData, | |||
| max_revision_num: int, | |||
| ) -> Solution: | |||
| """ | |||
| Get the optimal solution using ZOOpt library. From the solution, we can get a list of | |||
| boolean values, where '1' (True) indicates the indices chosen to be revised. | |||
| Parameters | |||
| ---------- | |||
| symbol_num : int | |||
| Number of total symbols. | |||
| data_example : ListData | |||
| Data example. | |||
| max_revision_num : int | |||
| Specifies the maximum number of revisions allowed. | |||
| Returns | |||
| ------- | |||
| Solution | |||
| The solution for ZOOpt library. | |||
| """ | |||
| dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||
| objective = Objective( | |||
| lambda sol: self.zoopt_score(symbol_num, data_example, sol), | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
| ) | |||
| parameter = Parameter( | |||
| budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True | |||
| ) | |||
| solution = Opt.min(objective, parameter) | |||
| return solution | |||
| def zoopt_score( | |||
| self, | |||
| symbol_num: int, | |||
| data_example: ListData, | |||
| sol: Solution, | |||
| ) -> float: | |||
| """ | |||
| Set the score for a solution. A lower score suggests that ZOOpt library | |||
| has a higher preference for this solution. | |||
| Parameters | |||
| ---------- | |||
| symbol_num : int | |||
| Number of total symbols. | |||
| data_example : ListData | |||
| Data example. | |||
| sol: Solution | |||
| The solution for ZOOpt library. | |||
| Returns | |||
| ------- | |||
| float | |||
| The score for the solution. | |||
| """ | |||
| revision_idx = np.where(np.array(sol.get_x()) != 0)[0] | |||
| candidates, reasoning_results = self.kb.revise_at_idx( | |||
| data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx | |||
| ) | |||
| if len(candidates) > 0: | |||
| return np.min(self._get_cost_list(data_example, candidates, reasoning_results)) | |||
| else: | |||
| return symbol_num | |||
| def zoopt_budget(self, symbol_num: int) -> int: | |||
| """ | |||
| Set the budget for ZOOpt optimization. The function, in its default implementation, | |||
| returns a fixed budget value of 100. However, it can be adjusted to return other fixed | |||
| values, or a dynamic budget based on the number of symbols, if desired. For example, | |||
| one might choose to set the budget as 100 times ``symbol_num``. | |||
| Parameters | |||
| ---------- | |||
| symbol_num : int | |||
| The number of symbols to be considered in the ZOOpt optimization process. Although this | |||
| parameter can be used to compute a dynamic optimization budget, by default it is not | |||
| utilized in the calculation. | |||
| Returns | |||
| ------- | |||
| int | |||
| The budget for ZOOpt optimization. By default, this is a fixed value of 100, | |||
| irrespective of the symbol_num value. | |||
| """ | |||
| return 100 | |||
| def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: | |||
| """ | |||
| Constrain that the total number of revisions chosen by the solution does not exceed | |||
| maximum number of revisions allowed. | |||
| """ | |||
| x = solution.get_x() | |||
| return max_revision_num - int(np.sum(x)) | |||
| def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int: | |||
| """ | |||
| Get the maximum revision number according to input ``max_revision``. | |||
| """ | |||
| if not isinstance(max_revision, (int, float)): | |||
| raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}") | |||
| if max_revision == -1: | |||
| return symbol_num | |||
| elif isinstance(max_revision, float): | |||
| if not (0 <= max_revision <= 1): | |||
| raise ValueError( | |||
| "If max_revision is a float, it must be between 0 and 1, " | |||
| + f"but got {max_revision}" | |||
| ) | |||
| return round(symbol_num * max_revision) | |||
| else: | |||
| if max_revision < 0: | |||
| raise ValueError( | |||
| f"If max_revision is an int, it must be non-negative, but got {max_revision}" | |||
| ) | |||
| return max_revision | |||
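| # Worked example: with symbol_num = 10, max_revision = -1 permits all 10 | |||
| # revisions, max_revision = 0.3 permits round(10 * 0.3) = 3, and | |||
| # max_revision = 2 permits exactly 2. | |||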
| def abduce(self, data_example: ListData) -> List[Any]: | |||
| """ | |||
| Perform abductive reasoning on the given data example. | |||
| Parameters | |||
| ---------- | |||
| data_example : ListData | |||
| Data example. | |||
| Returns | |||
| ------- | |||
| List[Any] | |||
| The revised pseudo-labels of the example, obtained through abductive reasoning and compatible | |||
| with the knowledge base. | |||
| """ | |||
| symbol_num = data_example.elements_num("pred_pseudo_label") | |||
| max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||
| if self.use_zoopt: | |||
| solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||
| revision_idx = np.where(np.array(solution.get_x()) != 0)[0] | |||
| candidates, reasoning_results = self.kb.revise_at_idx( | |||
| pseudo_label=data_example.pred_pseudo_label, | |||
| y=data_example.Y, | |||
| x=data_example.X, | |||
| revision_idx=revision_idx, | |||
| ) | |||
| else: | |||
| candidates, reasoning_results = self.kb.abduce_candidates( | |||
| pseudo_label=data_example.pred_pseudo_label, | |||
| y=data_example.Y, | |||
| x=data_example.X, | |||
| max_revision_num=max_revision_num, | |||
| require_more_revision=self.require_more_revision, | |||
| ) | |||
| candidate = self._get_one_candidate(data_example, candidates, reasoning_results) | |||
| return candidate | |||
| def batch_abduce(self, data_examples: ListData) -> List[List[Any]]: | |||
| """ | |||
| Perform abductive reasoning on the given prediction data examples. | |||
| For detailed information, refer to ``abduce``. | |||
| """ | |||
| abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples] | |||
| data_examples.abduced_pseudo_label = abduced_pseudo_label | |||
| return abduced_pseudo_label | |||
| def __call__(self, data_examples: ListData) -> List[List[Any]]: | |||
| return self.batch_abduce(data_examples) | |||
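| # --- Usage sketch (illustrative) --- | |||
| # A minimal pairing of a knowledge base with a Reasoner. ``data_examples`` is | |||
| # assumed to be a ListData whose examples carry ``pred_pseudo_label``, | |||
| # ``pred_prob``, ``X`` and ``Y``, as consumed by ``abduce`` above; ``AddKB`` is | |||
| # the hypothetical addition KB sketched in kb.py. | |||
| # | |||
| #     kb = AddKB() | |||
| #     reasoner = Reasoner(kb, dist_func="hamming", max_revision=0.5) | |||
| #     abduced_pseudo_label = reasoner.batch_abduce(data_examples) | |||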
| @@ -0,0 +1,23 @@ | |||
| from .cache import Cache, abl_cache | |||
| from .logger import ABLLogger, print_log | |||
| from .utils import ( | |||
| confidence_dist, | |||
| flatten, | |||
| hamming_dist, | |||
| reform_list, | |||
| to_hashable, | |||
| tab_data_to_tuple, | |||
| ) | |||
| __all__ = [ | |||
| "Cache", | |||
| "ABLLogger", | |||
| "print_log", | |||
| "confidence_dist", | |||
| "flatten", | |||
| "hamming_dist", | |||
| "reform_list", | |||
| "to_hashable", | |||
| "abl_cache", | |||
| "tab_data_to_tuple", | |||
| ] | |||
| @@ -0,0 +1,99 @@ | |||
| from typing import Callable, Generic, TypeVar | |||
| K = TypeVar("K") | |||
| T = TypeVar("T") | |||
| PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields | |||
| class Cache(Generic[K, T]): | |||
| def __init__(self, func: Callable[[K], T]): | |||
| """Create cache | |||
| :param func: Function this cache evaluates | |||
| :param cache: If true, do in memory caching. | |||
| :param cache_root: If not None, cache to files at the provided path. | |||
| :param key_func: Convert the key into a hashable object if needed | |||
| """ | |||
| self.func = func | |||
| self.has_init = False | |||
| def __getitem__(self, obj, *args) -> T: | |||
| return self.get_from_dict(obj, *args) | |||
| def clear_cache(self): | |||
| """Invalidate entire cache.""" | |||
| self.cache_dict.clear() | |||
| def _init_cache(self, obj): | |||
| if self.has_init: | |||
| return | |||
| self.cache = True | |||
| self.cache_dict = dict() | |||
| self.key_func = obj.key_func | |||
| self.max_size = obj.cache_size | |||
| self.hits, self.misses = 0, 0 | |||
| self.full = False | |||
| self.root = [] # root of the circular doubly linked list | |||
| self.root[:] = [self.root, self.root, None, None] | |||
| self.has_init = True | |||
| def get_from_dict(self, obj, *args) -> T: | |||
| """Implements dict based cache.""" | |||
| # x is not used in cache key | |||
| pred_pseudo_label, y, x, *res_args = args | |||
| cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) | |||
| link = self.cache_dict.get(cache_key) | |||
| if link is not None: | |||
| # Move the link to the front of the circular queue | |||
| link_prev, link_next, _key, result = link | |||
| link_prev[NEXT] = link_next | |||
| link_next[PREV] = link_prev | |||
| last = self.root[PREV] | |||
| last[NEXT] = self.root[PREV] = link | |||
| link[PREV] = last | |||
| link[NEXT] = self.root | |||
| self.hits += 1 | |||
| return result | |||
| self.misses += 1 | |||
| result = self.func(obj, *args) | |||
| if self.full: | |||
| # Use the old root to store the new key and result. | |||
| oldroot = self.root | |||
| oldroot[KEY] = cache_key | |||
| oldroot[RESULT] = result | |||
| # Empty the oldest link and make it the new root. | |||
| self.root = oldroot[NEXT] | |||
| oldkey = self.root[KEY] | |||
| self.root[KEY] = self.root[RESULT] = None | |||
| # Now update the cache dictionary. | |||
| del self.cache_dict[oldkey] | |||
| self.cache_dict[cache_key] = oldroot | |||
| else: | |||
| # Put result in a new link at the front of the queue. | |||
| last = self.root[PREV] | |||
| link = [last, self.root, cache_key, result] | |||
| last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link | |||
| if isinstance(self.max_size, int): | |||
| self.full = len(self.cache_dict) >= self.max_size | |||
| return result | |||
| def abl_cache(): | |||
| def decorator(func): | |||
| cache_instance = Cache(func) | |||
| def wrapper(obj, *args): | |||
| if obj.use_cache: | |||
| cache_instance._init_cache(obj) | |||
| return cache_instance.get_from_dict(obj, *args) | |||
| else: | |||
| return func(obj, *args) | |||
| return wrapper | |||
| return decorator | |||
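| # --- Usage sketch (illustrative) --- | |||
| # ``abl_cache`` expects the owning object to expose ``use_cache``, ``key_func`` | |||
| # and ``cache_size`` (see ``_init_cache``); the first three positional | |||
| # arguments after ``self`` are unpacked as (pseudo_label, y, x), with ``x`` | |||
| # deliberately excluded from the cache key. ``MyKB`` below is hypothetical. | |||
| # | |||
| #     class MyKB: | |||
| #         def __init__(self): | |||
| #             self.use_cache = True | |||
| #             self.key_func = lambda k: tuple(k) if isinstance(k, list) else k | |||
| #             self.cache_size = 4096 | |||
| # | |||
| #         @abl_cache() | |||
| #         def reason(self, pseudo_label, y, x): | |||
| #             return sum(pseudo_label) == y  # stand-in for an expensive search | |||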
| @@ -0,0 +1,344 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import logging | |||
| import os | |||
| import os.path as osp | |||
| import sys | |||
| from logging import Logger, LogRecord | |||
| from typing import Optional, Union | |||
| from termcolor import colored | |||
| from .manager import ManagerMixin, _accquire_lock, _release_lock | |||
| class FilterDuplicateWarning(logging.Filter): | |||
| """ | |||
| Filter for eliminating repeated warning messages in logging. | |||
| This filter checks for duplicate warning messages and allows only the first occurrence of | |||
| each message to be logged, filtering out subsequent duplicates. | |||
| Parameters | |||
| ---------- | |||
| name : str, optional | |||
| The name of the filter, by default "abl". | |||
| """ | |||
| def __init__(self, name: Optional[str] = "abl"): | |||
| super().__init__(name) | |||
| self.seen: set = set() | |||
| def filter(self, record: LogRecord) -> bool: | |||
| """Filter the repeated warning message. | |||
| Args: | |||
| record (LogRecord): The log record. | |||
| Returns: | |||
| bool: Whether to output the log record. | |||
| """ | |||
| if record.levelno != logging.WARNING: | |||
| return True | |||
| if record.msg not in self.seen: | |||
| self.seen.add(record.msg) | |||
| return True | |||
| return False | |||
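| # --- Usage sketch (illustrative) --- | |||
| # Attaching the filter to a handler so each distinct warning message is | |||
| # emitted only once; non-warning records always pass through. | |||
| # | |||
| #     handler = logging.StreamHandler() | |||
| #     handler.addFilter(FilterDuplicateWarning()) | |||
| #     logger = logging.getLogger("abl") | |||
| #     logger.addHandler(handler) | |||
| #     logger.warning("same message")  # logged | |||
| #     logger.warning("same message")  # filtered out | |||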
| class ABLFormatter(logging.Formatter): | |||
| """ | |||
| Colorful format for ABLLogger. If the log level is error, the logger will | |||
| additionally output the location of the code. | |||
| Parameters | |||
| ---------- | |||
| color : bool, optional | |||
| Whether to use colorful format. A FileHandler must not use the | |||
| colorful format, otherwise the log file will be garbled. | |||
| Defaults to True. | |||
| blink : bool, optional | |||
| Whether to blink the ``INFO`` and ``DEBUG`` logging | |||
| level. Defaults to False. | |||
| kwargs : dict | |||
| Keyword arguments passed to | |||
| :meth:``logging.Formatter.__init__``. | |||
| """ | |||
| _color_mapping: dict = dict(ERROR="red", WARNING="yellow", INFO="white", DEBUG="green") | |||
| def __init__(self, color: bool = True, blink: bool = False, **kwargs): | |||
| super().__init__(**kwargs) | |||
| assert not (not color and blink), "blink should only be available when color is True" | |||
| # Get prefix format according to color. | |||
| error_prefix = self._get_prefix("ERROR", color, blink=True) | |||
| warn_prefix = self._get_prefix("WARNING", color, blink=True) | |||
| info_prefix = self._get_prefix("INFO", color, blink) | |||
| debug_prefix = self._get_prefix("DEBUG", color, blink) | |||
| # Config output format. | |||
| self.err_format = ( | |||
| f"%(asctime)s - %(name)s - {error_prefix} - " | |||
| "%(pathname)s - %(funcName)s - %(lineno)d - " | |||
| "%(message)s" | |||
| ) | |||
| self.warn_format = f"%(asctime)s - %(name)s - {warn_prefix} - %(message)s" | |||
| self.info_format = f"%(asctime)s - %(name)s - {info_prefix} - %(message)s" | |||
| self.debug_format = f"%(asctime)s - %(name)s - {debug_prefix} - %(message)s" | |||
| def _get_prefix(self, level: str, color: bool, blink: bool = False) -> str: | |||
| """ | |||
| Get the prefix of the target log level. | |||
| Parameters | |||
| ---------- | |||
| level : str | |||
| Log level. | |||
| color : bool | |||
| Whether to get a colorful prefix. | |||
| blink : bool, optional | |||
| Whether the prefix will blink. Defaults to False. | |||
| Returns | |||
| ------- | |||
| str | |||
| The plain or colorful prefix. | |||
| """ | |||
| if color: | |||
| attrs = ["underline"] | |||
| if blink: | |||
| attrs.append("blink") | |||
| prefix = colored(level, self._color_mapping[level], attrs=attrs) | |||
| else: | |||
| prefix = level | |||
| return prefix | |||
| def format(self, record: LogRecord) -> str: | |||
| """ | |||
| Override the ``logging.Formatter.format`` method. Output the | |||
| message according to the specified log level. | |||
| Parameters | |||
| ---------- | |||
| record : LogRecord | |||
| A LogRecord instance representing an event being logged. | |||
| Returns | |||
| ------- | |||
| str | |||
| Formatted result. | |||
| """ | |||
| if record.levelno == logging.ERROR: | |||
| self._style._fmt = self.err_format | |||
| elif record.levelno == logging.WARNING: | |||
| self._style._fmt = self.warn_format | |||
| elif record.levelno == logging.INFO: | |||
| self._style._fmt = self.info_format | |||
| elif record.levelno == logging.DEBUG: | |||
| self._style._fmt = self.debug_format | |||
| result = logging.Formatter.format(self, record) | |||
| return result | |||
| class ABLLogger(Logger, ManagerMixin): | |||
| """ | |||
| Formatted logger used to record messages with different log levels and features. | |||
| ``ABLLogger`` provides a formatted logger that can log messages with different | |||
| log levels. It allows the creation of logger instances in a similar manner to ``ManagerMixin``. | |||
| The logger has features like distributed log storage and colored terminal output for different | |||
| log levels. | |||
| Parameters | |||
| ---------- | |||
| name : str | |||
| Global instance name. | |||
| logger_name : str, optional | |||
| ``name`` attribute of ``logging.Logger`` instance. Defaults to 'abl'. | |||
| log_file : str, optional | |||
| The log filename. If specified, a ``FileHandler`` will be added to the logger. | |||
| Defaults to None. | |||
| log_level : Union[int, str], optional | |||
| The log level of the handler. Defaults to 'INFO'. | |||
| file_mode : str, optional | |||
| The file mode used to open log file. Defaults to 'w'. | |||
| Notes | |||
| ----- | |||
| - The ``name`` of the logger and the ``instance_name`` of ``ABLLogger`` could be different. | |||
| ``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not | |||
| ``logging.getLogger``. This ensures ``ABLLogger`` is not influenced by third-party logging | |||
| configurations. | |||
| - Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without | |||
| ``Handler``. | |||
| Examples | |||
| -------- | |||
| >>> logger = ABLLogger.get_instance(name='ABLLogger', logger_name='Logger') | |||
| >>> # Although logger has a name attribute like ``logging.Logger`` | |||
| >>> # We cannot get logger instance by ``logging.getLogger``. | |||
| >>> assert logger.name == 'Logger' | |||
| >>> assert logger.instance_name == 'ABLLogger' | |||
| >>> assert id(logger) != id(logging.getLogger('Logger')) | |||
| >>> # Get a logger that stores logs in an auto-generated timestamped directory. | |||
| >>> logger1 = ABLLogger.get_instance('logger1') | |||
| >>> # Get a logger that stores logs in the specified file. | |||
| >>> logger2 = ABLLogger.get_instance('logger2', log_file='out.log') | |||
| """ | |||
| def __init__( | |||
| self, | |||
| name: str, | |||
| logger_name="abl", | |||
| log_file: Optional[str] = None, | |||
| log_level: Union[int, str] = "INFO", | |||
| file_mode: str = "w", | |||
| ): | |||
| Logger.__init__(self, logger_name) | |||
| ManagerMixin.__init__(self, name) | |||
| if isinstance(log_level, str): | |||
| log_level = logging._nameToLevel[log_level] | |||
| stream_handler = logging.StreamHandler(stream=sys.stdout) | |||
| # ``StreamHandler`` records month, day, hour, minute, and second timestamps. | |||
| stream_handler.setFormatter(ABLFormatter(color=True, datefmt="%m/%d %H:%M:%S")) | |||
| stream_handler.setLevel(log_level) | |||
| stream_handler.addFilter(FilterDuplicateWarning(logger_name)) | |||
| self.handlers.append(stream_handler) | |||
| if log_file is None: | |||
| import time | |||
| local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) | |||
| _log_dir = os.path.join("results", local_time) | |||
| self._log_dir = _log_dir | |||
| if not os.path.exists(_log_dir): | |||
| os.makedirs(_log_dir) | |||
| log_file = osp.join(_log_dir, local_time + ".log") | |||
| file_handler = logging.FileHandler(log_file, file_mode) | |||
| file_handler.setFormatter(ABLFormatter(color=False, datefmt="%Y/%m/%d %H:%M:%S")) | |||
| file_handler.setLevel(log_level) | |||
| file_handler.addFilter(FilterDuplicateWarning(logger_name)) | |||
| self.handlers.append(file_handler) | |||
| self._log_file = log_file | |||
| @property | |||
| def log_file(self): | |||
| return self._log_file | |||
| @property | |||
| def log_dir(self): | |||
| return self._log_dir | |||
| @classmethod | |||
| def get_current_instance(cls) -> "ABLLogger": | |||
| """ | |||
| Get the latest created ``ABLLogger`` instance. | |||
| Returns | |||
| ------- | |||
| ABLLogger | |||
| The latest created ``ABLLogger`` instance. If no instance has been created, | |||
| returns a logger with the instance name "abl". | |||
| """ | |||
| if not cls._instance_dict: | |||
| cls.get_instance("abl") | |||
| return super().get_current_instance() | |||
| def callHandlers(self, record: LogRecord) -> None: | |||
| """ | |||
| Pass a record to all relevant handlers. | |||
| Override the ``callHandlers`` method in ``logging.Logger`` to avoid | |||
| multiple warning messages in DDP mode. This method loops through all | |||
| handlers of the logger instance and its parents in the logger hierarchy. | |||
| Parameters | |||
| ---------- | |||
| record : LogRecord | |||
| A ``LogRecord`` instance containing the logged message. | |||
| """ | |||
| for handler in self.handlers: | |||
| if record.levelno >= handler.level: | |||
| handler.handle(record) | |||
| def setLevel(self, level): | |||
| """ | |||
| Set the logging level of this logger. | |||
| Override the ``setLevel`` method to clear caches of all ``ABLLogger`` instances | |||
| managed by ``ManagerMixin``. The level must be an int or a str. | |||
| Parameters | |||
| ---------- | |||
| level : Union[int, str] | |||
| The logging level to set. | |||
| """ | |||
| self.level = logging._checkLevel(level) | |||
| _accquire_lock() | |||
| # The same logic as ``logging.Manager._clear_cache``. | |||
| for logger in ABLLogger._instance_dict.values(): | |||
| logger._cache.clear() | |||
| _release_lock() | |||
| def print_log( | |||
| msg, | |||
| logger: Optional[Union[Logger, str]] = None, | |||
| level: Optional[int] = logging.INFO, | |||
| ) -> None: | |||
| """ | |||
| Print a log message using the specified logger or a default method. | |||
| This function logs a message with a given logger, if provided, or prints it using | |||
| the standard ``print`` function. It supports special logger types such as 'silent' | |||
| and 'current'. | |||
| Parameters | |||
| ---------- | |||
| msg : str | |||
| The message to be logged. | |||
| logger : Union[Logger, str], optional | |||
| The logger to use for logging the message. It can be a ``logging.Logger`` instance, a string | |||
| specifying the logger name, 'silent', 'current', or None. If None, the ``print`` | |||
| method is used. | |||
| - 'silent': No message will be printed. | |||
| - 'current': Use the latest created logger to log the message. | |||
| - other str: The instance name of the logger. A ``ValueError`` is raised if the logger has | |||
| not been created. | |||
| - None: The ``print()`` method is used for logging. | |||
| level : int, optional | |||
| The logging level. This is only applicable when ``logger`` is a Logger object, 'current', | |||
| or a named logger instance. The default is ``logging.INFO``. | |||
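| Examples | |||
| -------- | |||
| A few illustrative calls (the format of logged output depends on the handlers): | |||
| >>> print_log("plain message")  # ``logger`` is None, falls back to ``print`` | |||
| plain message | |||
| >>> logger = ABLLogger.get_instance("abl") | |||
| >>> print_log("logged message", logger="current") | |||
| >>> print_log("suppressed message", logger="silent") | |||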
| """ | |||
| if logger is None: | |||
| print(msg) | |||
| elif isinstance(logger, logging.Logger): | |||
| logger.log(level, msg) | |||
| elif logger == "silent": | |||
| pass | |||
| elif logger == "current": | |||
| logger_instance = ABLLogger.get_current_instance() | |||
| logger_instance.log(level, msg) | |||
| elif isinstance(logger, str): | |||
| # If the type of ``logger`` is ``str``, but not with value of ``current`` or | |||
| # ``silent``, we assume it indicates the name of the logger. If the | |||
| # corresponding logger has not been created, ``print_log`` will raise | |||
| # a ``ValueError``. | |||
| if ABLLogger.check_instance_created(logger): | |||
| logger_instance = ABLLogger.get_instance(logger) | |||
| logger_instance.log(level, msg) | |||
| else: | |||
| raise ValueError(f"ABLLogger: {logger} has not been created!") | |||
| else: | |||
| raise TypeError( | |||
| "``logger`` should be either a logging.Logger object, str, " | |||
| f'"silent", "current" or None, but got {type(logger)}' | |||
| ) | |||
| @@ -0,0 +1,169 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import inspect | |||
| import threading | |||
| import warnings | |||
| from collections import OrderedDict | |||
| from typing import Type, TypeVar | |||
| _lock = threading.RLock() | |||
| T = TypeVar("T") | |||
| def _accquire_lock() -> None: | |||
| """Acquire the module-level lock for serializing access to shared data. | |||
| This should be released with _release_lock(). | |||
| """ | |||
| if _lock: | |||
| _lock.acquire() | |||
| def _release_lock() -> None: | |||
| """Release the module-level lock acquired by calling _accquire_lock().""" | |||
| if _lock: | |||
| _lock.release() | |||
| class ManagerMeta(type): | |||
| """The metaclass for global accessible class. | |||
| The subclasses inheriting from ``ManagerMeta`` will manage their | |||
| own ``_instance_dict``. The constructors of subclasses | |||
| must contain the ``name`` argument. | |||
| Examples: | |||
| >>> class SubClass1(metaclass=ManagerMeta): | |||
| >>> def __init__(self, *args, **kwargs): | |||
| >>> pass | |||
| AssertionError: <class '__main__.SubClass1'>.__init__ must have the | |||
| name argument. | |||
| >>> class SubClass2(metaclass=ManagerMeta): | |||
| >>> def __init__(self, name): | |||
| >>> pass | |||
| >>> # valid format. | |||
| """ | |||
| def __init__(cls, *args): | |||
| cls._instance_dict = OrderedDict() | |||
| params = inspect.getfullargspec(cls) | |||
| params_names = params[0] if params[0] else [] | |||
| assert "name" in params_names, f"{cls} must have the `name` argument" | |||
| super().__init__(*args) | |||
| class ManagerMixin(metaclass=ManagerMeta): | |||
| """``ManagerMixin`` is the base class for classes that have global access | |||
| requirements. | |||
| The subclasses inheriting from ``ManagerMixin`` can get their | |||
| global instances. | |||
| Examples: | |||
| >>> class GlobalAccessible(ManagerMixin): | |||
| >>> def __init__(self, name=''): | |||
| >>> super().__init__(name) | |||
| >>> | |||
| >>> GlobalAccessible.get_instance('name') | |||
| >>> instance_1 = GlobalAccessible.get_instance('name') | |||
| >>> instance_2 = GlobalAccessible.get_instance('name') | |||
| >>> assert id(instance_1) == id(instance_2) | |||
| Args: | |||
| name (str): Name of the instance. Defaults to ''. | |||
| """ | |||
| def __init__(self, name: str = "", **kwargs): | |||
| assert isinstance(name, str) and name, "name argument must be a non-empty string." | |||
| self._instance_name = name | |||
| @classmethod | |||
| def get_instance(cls: Type[T], name: str, **kwargs) -> T: | |||
| """Get subclass instance by name if the name exists. | |||
| If corresponding name instance has not been created, ``get_instance`` | |||
| will create an instance, otherwise ``get_instance`` will return the | |||
| corresponding instance. | |||
| Examples | |||
| >>> instance1 = GlobalAccessible.get_instance('name1') | |||
| >>> # Create name1 instance. | |||
| >>> instance1.instance_name | |||
| name1 | |||
| >>> instance2 = GlobalAccessible.get_instance('name1') | |||
| >>> # Get name1 instance. | |||
| >>> assert id(instance1) == id(instance2) | |||
| Args: | |||
| name (str): Name of the instance. | |||
| Returns: | |||
| object: The instance corresponding to the given name. | |||
| """ | |||
| _accquire_lock() | |||
| assert isinstance(name, str), f"type of name should be str, but got {type(cls)}" | |||
| instance_dict = cls._instance_dict # type: ignore | |||
| # Get the instance by name. | |||
| if name not in instance_dict: | |||
| instance = cls(name=name, **kwargs) # type: ignore | |||
| instance_dict[name] = instance # type: ignore | |||
| elif kwargs: | |||
| warnings.warn( | |||
| f"{cls} instance named {name} has already been created; " | |||
| "the method `get_instance` should not accept any other " | |||
| "arguments in this case" | |||
| ) | |||
| _release_lock() | |||
| return instance_dict[name] | |||
| @classmethod | |||
| def get_current_instance(cls): | |||
| """Get latest created instance. | |||
| Before calling ``get_current_instance``, the subclass must have called | |||
| ``get_instance(xxx)`` at least once. | |||
| Examples | |||
| >>> instance = GlobalAccessible.get_current_instance() | |||
| RuntimeError: Before calling GlobalAccessible.get_current_instance(), you should call get_instance(name=xxx) at least once. | |||
| >>> instance = GlobalAccessible.get_instance('name1') | |||
| >>> instance.instance_name | |||
| name1 | |||
| >>> instance = GlobalAccessible.get_current_instance() | |||
| >>> instance.instance_name | |||
| name1 | |||
| Returns: | |||
| object: Latest created instance. | |||
| """ | |||
| _accquire_lock() | |||
| if not cls._instance_dict: | |||
| raise RuntimeError( | |||
| f"Before calling {cls.__name__}.get_current_instance(), you " | |||
| "should call get_instance(name=xxx) at least once." | |||
| ) | |||
| name = next(iter(reversed(cls._instance_dict))) | |||
| _release_lock() | |||
| return cls._instance_dict[name] | |||
| @classmethod | |||
| def check_instance_created(cls, name: str) -> bool: | |||
| """Check whether the name corresponding instance exists. | |||
| Args: | |||
| name (str): Name of instance. | |||
| Returns: | |||
| bool: Whether the corresponding instance exists. | |||
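| Examples | |||
| >>> instance = GlobalAccessible.get_instance('name1') | |||
| >>> GlobalAccessible.check_instance_created('name1') | |||
| True | |||
| >>> GlobalAccessible.check_instance_created('unknown') | |||
| False | |||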
| """ | |||
| return name in cls._instance_dict | |||
| @property | |||
| def instance_name(self) -> str: | |||
| """Get the name of instance. | |||
| Returns: | |||
| str: Name of instance. | |||
| """ | |||
| return self._instance_name | |||
| @@ -0,0 +1,180 @@ | |||
| from typing import List, Any, Union, Tuple, Optional | |||
| import numpy as np | |||
| def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[Any]: | |||
| """ | |||
| Flattens a nested list at the first level. | |||
| Parameters | |||
| ---------- | |||
| nested_list : List[Union[Any, List[Any], Tuple[Any, ...]]] | |||
| A list which might contain sublists or tuples at the first level. | |||
| Returns | |||
| ------- | |||
| List[Any] | |||
| A flattened version of the input list, where only the first | |||
| level of sublists and tuples are reduced. | |||
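| Examples | |||
| -------- | |||
| >>> flatten([1, [2, 3], (4, 5)]) | |||
| [1, 2, 3, 4, 5] | |||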
| """ | |||
| if not isinstance(nested_list, list): | |||
| return nested_list | |||
| flattened_list = [] | |||
| for item in nested_list: | |||
| if isinstance(item, (list, tuple)): | |||
| flattened_list.extend(item) | |||
| else: | |||
| flattened_list.append(item) | |||
| return flattened_list | |||
| def reform_list( | |||
| flattened_list: List[Any], structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]] | |||
| ) -> List[List[Any]]: | |||
| """ | |||
| Reform the list based on the structure of ``structured_list``. | |||
| Parameters | |||
| ---------- | |||
| flattened_list : List[Any] | |||
| A flattened list of elements. | |||
| structured_list : List[Union[Any, List[Any], Tuple[Any, ...]]] | |||
| A list that reflects the desired structure, which may contain sublists or tuples. | |||
| Returns | |||
| ------- | |||
| List[List[Any]] | |||
| A reformed list that mimics the structure of ``structured_list``. | |||
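| Examples | |||
| -------- | |||
| Only the shape of ``structured_list`` matters; its element values are ignored: | |||
| >>> reform_list([1, 2, 3, 4], [[0, 0], [0, 0]]) | |||
| [[1, 2], [3, 4]] | |||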
| """ | |||
| if not isinstance(structured_list[0], (list, tuple)): | |||
| return flattened_list | |||
| reformed_list = [] | |||
| idx_start = 0 | |||
| for elem in structured_list: | |||
| idx_end = idx_start + len(elem) | |||
| reformed_list.append(flattened_list[idx_start:idx_end]) | |||
| idx_start = idx_end | |||
| return reformed_list | |||
| def hamming_dist(pred_pseudo_label: List[Any], candidates: List[List[Any]]) -> np.ndarray: | |||
| """ | |||
| Compute the Hamming distance between two arrays. | |||
| Parameters | |||
| ---------- | |||
| pred_pseudo_label : List[Any] | |||
| Pseudo-labels of an example. | |||
| candidates : List[List[Any]] | |||
| Multiple possible candidates. | |||
| Returns | |||
| ------- | |||
| np.ndarray | |||
| Hamming distances computed for each candidate. | |||
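| Examples | |||
| -------- | |||
| >>> hamming_dist([1, 0, 1], [[1, 0, 0], [0, 1, 0]]) | |||
| array([1, 3]) | |||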
| """ | |||
| pred_pseudo_label = np.array(pred_pseudo_label) | |||
| candidates = np.array(candidates) | |||
| # Ensuring that pred_pseudo_label is broadcastable to the shape of candidates | |||
| pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0) | |||
| return np.sum(pred_pseudo_label != candidates, axis=1) | |||
| def confidence_dist(pred_prob: List[np.ndarray], candidates_idxs: List[List[Any]]) -> np.ndarray: | |||
| """ | |||
| Compute the confidence distance between prediction probabilities and candidates. | |||
| Parameters | |||
| ---------- | |||
| pred_prob : List[np.ndarray] | |||
| Prediction probability distributions, each element is an ndarray | |||
| representing the probability distribution of a particular prediction. | |||
| candidates_idxs : List[List[Any]] | |||
| Multiple possible candidates' indices. | |||
| Returns | |||
| ------- | |||
| np.ndarray | |||
| Confidence distances computed for each candidate. | |||
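| Examples | |||
| -------- | |||
| A minimal illustration with two instances (rows) and two candidates: | |||
| >>> pred_prob = np.array([[0.9, 0.1], [0.2, 0.8]]) | |||
| >>> confidence_dist(pred_prob, [[0, 1], [1, 0]]) | |||
| array([0.28, 0.98]) | |||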
| """ | |||
| pred_prob = np.clip(pred_prob, 1e-9, 1) | |||
| _, cols = np.indices((len(candidates_idxs), len(candidates_idxs[0]))) | |||
| return 1 - np.prod(pred_prob[cols, candidates_idxs], axis=1) | |||
| def to_hashable(x: Union[List[Any], Any]) -> Union[Tuple[Any, ...], Any]: | |||
| """ | |||
| Convert a nested list to a nested tuple so it is hashable. | |||
| Parameters | |||
| ---------- | |||
| x : Union[List[Any], Any] | |||
| A potentially nested list to convert to a tuple. | |||
| Returns | |||
| ------- | |||
| Union[Tuple[Any, ...], Any] | |||
| The input converted to a tuple if it was a list, | |||
| otherwise the original input. | |||
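| Examples | |||
| -------- | |||
| >>> to_hashable([1, [2, 3]]) | |||
| (1, (2, 3)) | |||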
| """ | |||
| if isinstance(x, list): | |||
| return tuple(to_hashable(item) for item in x) | |||
| return x | |||
| def restore_from_hashable(x): | |||
| """ | |||
| Convert a nested tuple back to a nested list. | |||
| Parameters | |||
| ---------- | |||
| x : Union[Tuple[Any, ...], Any] | |||
| A potentially nested tuple to convert to a list. | |||
| Returns | |||
| ------- | |||
| Union[List[Any], Any] | |||
| The input converted to a list if it was a tuple, | |||
| otherwise the original input. | |||
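| Examples | |||
| -------- | |||
| >>> restore_from_hashable((1, (2, 3))) | |||
| [1, [2, 3]] | |||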
| """ | |||
| if isinstance(x, tuple): | |||
| return [restore_from_hashable(item) for item in x] | |||
| return x | |||
| def tab_data_to_tuple( | |||
| X: Union[List[Any], Any], y: Union[List[Any], Any], reasoning_result: Optional[Any] = 0 | |||
| ) -> Optional[Tuple[List[List[Any]], List[List[Any]], List[Any]]]: | |||
| """ | |||
| Convert tabular data to a tuple by adding a dimension to each element of | |||
| X and y. The tuple contains three elements: data, label, and reasoning result. | |||
| If X is None, return None. | |||
| Parameters | |||
| ---------- | |||
| X : Union[List[Any], Any] | |||
| The data. | |||
| y : Union[List[Any], Any] | |||
| The label. | |||
| reasoning_result : Any, optional | |||
| The reasoning result, by default 0. | |||
| Returns | |||
| ------- | |||
| Tuple[List[List[Any]], List[List[Any]], List[Any]] | |||
| A tuple of (data, label, reasoning_result). | |||
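| Examples | |||
| -------- | |||
| >>> tab_data_to_tuple([1, 2], [3, 4]) | |||
| ([[1], [2]], [[3], [4]], [0, 0]) | |||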
| """ | |||
| if X is None: | |||
| return None | |||
| if len(X) != len(y): | |||
| raise ValueError( | |||
| "The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)) | |||
| ) | |||
| return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) | |||
| @@ -1,186 +0,0 @@ | |||
| # coding: utf-8 | |||
| #================================================================# | |||
| # Copyright (C) 2020 Freecss All rights reserved. | |||
| # | |||
| # File Name :data_generator.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2020/04/02 | |||
| # Description : | |||
| # | |||
| #================================================================# | |||
| from itertools import product | |||
| import math | |||
| import numpy as np | |||
| import random | |||
| import pickle as pk | |||
| import random | |||
| from multiprocessing import Pool | |||
| import copy | |||
| #def hamming_code_generator(data_len, p_len): | |||
| # ret = [] | |||
| # for data in product((0, 1), repeat=data_len): | |||
| # p_idxs = [2 ** i for i in range(p_len)] | |||
| # total_len = data_len + p_len | |||
| # data_idx = 0 | |||
| # hamming_code = [] | |||
| # for idx in range(total_len): | |||
| # if idx + 1 in p_idxs: | |||
| # hamming_code.append(0) | |||
| # else: | |||
| # hamming_code.append(data[data_idx]) | |||
| # data_idx += 1 | |||
| # | |||
| # for idx in range(total_len): | |||
| # if idx + 1 in p_idxs: | |||
| # for i in range(total_len): | |||
| # if (i + 1) & (idx + 1) != 0: | |||
| # hamming_code[idx] ^= hamming_code[i] | |||
| # #hamming_code = "".join([str(x) for x in hamming_code]) | |||
| # ret.append(hamming_code) | |||
| # return ret | |||
| def code_generator(code_len, code_num, letter_num = 2): | |||
| codes = list(product(list(range(letter_num)), repeat = code_len)) | |||
| random.shuffle(codes) | |||
| return codes[:code_num] | |||
| def hamming_distance_static(codes): | |||
| min_dist = len(codes) | |||
| avg_dist = 0. | |||
| avg_min_dist = 0. | |||
| relation_num = 0. | |||
| for code1 in codes: | |||
| tmp_min_dist = len(codes) | |||
| for code2 in codes: | |||
| if code1 == code2: | |||
| continue | |||
| dist = 0 | |||
| relation_num += 1 | |||
| for c1, c2 in zip(code1, code2): | |||
| if c1 != c2: | |||
| dist += 1 | |||
| avg_dist += dist | |||
| if tmp_min_dist > dist: | |||
| tmp_min_dist = dist | |||
| avg_min_dist += tmp_min_dist | |||
| if min_dist > tmp_min_dist: | |||
| min_dist = tmp_min_dist | |||
| return avg_dist / relation_num, avg_min_dist / len(codes) | |||
| def generate_cosin_data(codes, err, repeat, letter_num): | |||
| Y = np.random.random(100000) * letter_num * 3 - 3 | |||
| X = np.random.random(100000) * 20 - 10 | |||
| data_X = np.concatenate((X.reshape(-1, 1), Y.reshape(-1, 1)), axis = 1) | |||
| samples = {} | |||
| all_sign = list(set(sum([[c for c in code] for code in codes], []))) | |||
| for d, sign in enumerate(all_sign): | |||
| labels = np.logical_and(Y < np.cos(X) + 2 * d, Y > np.cos(X) + 2 * d - 2) | |||
| samples[sign] = data_X[labels] | |||
| data = [] | |||
| labels = [] | |||
| count = 0 | |||
| for _ in range(repeat): | |||
| if (count > 100000): | |||
| break | |||
| for code in codes: | |||
| tmp = [] | |||
| count += 1 | |||
| for d in code: | |||
| if random.random() < err: | |||
| candidates = copy.deepcopy(all_sign) | |||
| candidates.remove(d) | |||
| d = candidates[random.randint(0, letter_num - 2)] | |||
| idx = random.randint(0, len(samples[d]) - 1) | |||
| tmp.append(samples[d][idx]) | |||
| data.append(tmp) | |||
| labels.append(code) | |||
| data = np.array(data) | |||
| labels = np.array(labels) | |||
| return data, labels | |||
| #codes = """110011001 | |||
| #100011001 | |||
| #101101101 | |||
| #011111001 | |||
| #100100001 | |||
| #111111101 | |||
| #101110001 | |||
| #111100101 | |||
| #101000101 | |||
| #001001101 | |||
| #111110101 | |||
| #100101001 | |||
| #010010101 | |||
| #110100101 | |||
| #001111101 | |||
| #111111001""" | |||
| #codes = codes.split() | |||
| def generate_data_via_codes(codes, err, letter_num): | |||
| #codes = code_generator(code_len, code_num) | |||
| data, labels = generate_cosin_data(codes, err, 100000, letter_num) | |||
| return data, labels | |||
| def generate_data(params): | |||
| code_len = params["code_len"] | |||
| times = params["times"] | |||
| p = params["p"] | |||
| code_num = params["code_num"] | |||
| err = p / 20. | |||
| codes = code_generator(code_len, code_num) | |||
| data, labels = generate_cosin_data(codes, err) | |||
| data_name = "code_%d_%d" % (code_len, code_num) | |||
| pk.dump((codes, data, labels), open("generated_data/%d_%s_%.2f.pk" % (times, data_name, err), "wb")) | |||
| return True | |||
| def generate_multi_data(): | |||
| pool = Pool(64) | |||
| params_list = [] | |||
| #for code_len in [7, 9, 11, 13, 15]: | |||
| for code_len in [7, 11, 15]: | |||
| for times in range(20): | |||
| for p in range(0, 11): | |||
| for code_num_power in range(1, code_len): | |||
| code_num = 2 ** code_num_power | |||
| params_list.append({"code_len" : code_len, "times" : times, "p" : p, "code_num" : code_num}) | |||
| return list(pool.map(generate_data, params_list)) | |||
| def read_lexicon(file_path): | |||
| ret = [] | |||
| with open(file_path) as fin: | |||
| ret = [s.strip() for s in fin] | |||
| all_sign = list(set(sum([[c for c in s] for s in ret], []))) | |||
| #ret = ["".join(str(all_sign.index(t)) for t in tmp) for tmp in ret] | |||
| return ret, len(all_sign) | |||
| import os | |||
| if __name__ == "__main__": | |||
| for root, dirs, files in os.walk("lexicons"): | |||
| if root != "lexicons": | |||
| continue | |||
| for file_name in files: | |||
| file_path = os.path.join(root, file_name) | |||
| codes, letter_num = read_lexicon(file_path) | |||
| data, labels = generate_data_via_codes(codes, 0, letter_num) | |||
| save_path = os.path.join("dataset", file_name.split(".")[0] + ".pk") | |||
| pk.dump((data, labels, codes), open(save_path, "wb")) | |||
| #res = read_lexicon("add2.txt") | |||
| #print(res) | |||
| exit(0) | |||
| generate_multi_data() | |||
| exit() | |||
| @@ -0,0 +1,7 @@ | |||
| abl.bridge | |||
| ================== | |||
| .. automodule:: abl.bridge | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -0,0 +1,18 @@ | |||
| abl.data | |||
| =================== | |||
| ``structures`` | |||
| -------------- | |||
| .. autoclass:: abl.data.structures.ListData | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| ``evaluation`` | |||
| -------------- | |||
| .. automodule:: abl.data.evaluation | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -0,0 +1,20 @@ | |||
| abl.learning | |||
| ================== | |||
| .. autoclass:: abl.learning.ABLModel | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| .. autoclass:: abl.learning.BasicNN | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| ``torch_dataset`` | |||
| ----------------- | |||
| .. automodule:: abl.learning.torch_dataset | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -0,0 +1,7 @@ | |||
| abl.reasoning | |||
| ================== | |||
| .. automodule:: abl.reasoning | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -0,0 +1,7 @@ | |||
| abl.utils | |||
| ================== | |||
| .. automodule:: abl.utils | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -0,0 +1,299 @@ | |||
| Handwritten Equation Decipherment (HED) | |||
| ======================================= | |||
| .. raw:: html | |||
| <p>For detailed code implementation, please view on <a class="reference external" href="https://github.com/AbductiveLearning/ABL-Package/tree/Dev/examples/hed" target="_blank">GitHub</a>.</p> | |||
| Below shows an implementation of `Handwritten Equation | |||
| Decipherment <https://proceedings.neurips.cc/paper_files/paper/2019/file/9c19a2aa1d84e04b0bd4bc888792bd1e-Paper.pdf>`__. | |||
| In this task, the handwritten equations are given, which consist of | |||
| sequential pictures of characters. The equations are generated with | |||
| unknown operation rules from images of symbols (‘0’, ‘1’, ‘+’ and ‘=’), | |||
| and each equation is associated with a label indicating whether the | |||
| equation is correct (i.e., positive) or not (i.e., negative). Also, we | |||
| are given a knowledge base which involves the structure of the equations | |||
| and a recursive definition of bit-wise operations. The task is to learn | |||
| from a training set of the above-mentioned equations and then to predict | |||
| labels of unseen equations. | |||
| Intuitively, we first use a machine learning model (learning part) to | |||
| obtain the pseudo-labels (‘0’, ‘1’, ‘+’ and ‘=’) for the observed | |||
| pictures. We then use the knowledge base (reasoning part) to perform | |||
| abductive reasoning so as to yield ground hypotheses as possible | |||
| explanations to the observed facts, suggesting some pseudo-labels to be | |||
| revised. This process enables us to further update the machine learning | |||
| model. | |||
| .. code:: ipython3 | |||
| # Import necessary libraries and modules | |||
| import os.path as osp | |||
| import matplotlib.pyplot as plt | |||
| import torch | |||
| import torch.nn as nn | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.utils import ABLLogger, print_log | |||
| from bridge import HedBridge | |||
| from consistency_metric import ConsistencyMetric | |||
| from datasets import get_dataset, split_equation | |||
| from models.nn import SymbolNet | |||
| from reasoning import HedKB, HedReasoner | |||
| Working with Data | |||
| ----------------- | |||
| First, we get the datasets of handwritten equations: | |||
| .. code:: ipython3 | |||
| total_train_data = get_dataset(train=True) | |||
| train_data, val_data = split_equation(total_train_data, 3, 1) | |||
| test_data = get_dataset(train=False) | |||
| The datasets are shown below: | |||
| .. code:: ipython3 | |||
| true_train_equation = train_data[1] | |||
| false_train_equation = train_data[0] | |||
| print(f"Equations in the dataset is organized by equation length, " + | |||
| f"from {min(train_data[0].keys())} to {max(train_data[0].keys())}") | |||
| print() | |||
| true_train_equation_with_length_5 = true_train_equation[5] | |||
| false_train_equation_with_length_5 = false_train_equation[5] | |||
| print(f"For each euqation length, there are {len(true_train_equation_with_length_5)} " + | |||
| f"true equation and {len(false_train_equation_with_length_5)} false equation " + | |||
| f"in the training set") | |||
| true_val_equation = val_data[1] | |||
| false_val_equation = val_data[0] | |||
| true_val_equation_with_length_5 = true_val_equation[5] | |||
| false_val_equation_with_length_5 = false_val_equation[5] | |||
| print(f"For each euqation length, there are {len(true_val_equation_with_length_5)} " + | |||
| f"true equation and {len(false_val_equation_with_length_5)} false equation " + | |||
| f"in the validation set") | |||
| true_test_equation = test_data[1] | |||
| false_test_equation = test_data[0] | |||
| true_test_equation_with_length_5 = true_test_equation[5] | |||
| false_test_equation_with_length_5 = false_test_equation[5] | |||
| print(f"For each euqation length, there are {len(true_test_equation_with_length_5)} " + | |||
| f"true equation and {len(false_test_equation_with_length_5)} false equation " + | |||
| f"in the test set") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Equations in the dataset are organized by equation length, from 5 to 26 | |||
| For each equation length, there are 225 true equations and 225 false equations in the training set | |||
| For each equation length, there are 75 true equations and 75 false equations in the validation set | |||
| For each equation length, there are 300 true equations and 300 false equations in the test set | |||
| As illustrations, we show four equations in the training dataset: | |||
| .. code:: ipython3 | |||
| true_train_equation_with_length_5 = true_train_equation[5] | |||
| true_train_equation_with_length_8 = true_train_equation[8] | |||
| print(f"First true equation with length 5 in the training dataset:") | |||
| for i, x in enumerate(true_train_equation_with_length_5[0]): | |||
| plt.subplot(1, 5, i+1) | |||
| plt.axis('off') | |||
| plt.imshow(x.squeeze(), cmap='gray') | |||
| plt.show() | |||
| print(f"First true equation with length 8 in the training dataset:") | |||
| for i, x in enumerate(true_train_equation_with_length_8[0]): | |||
| plt.subplot(1, 8, i+1) | |||
| plt.axis('off') | |||
| plt.imshow(x.squeeze(), cmap='gray') | |||
| plt.show() | |||
| false_train_equation_with_length_5 = false_train_equation[5] | |||
| false_train_equation_with_length_8 = false_train_equation[8] | |||
| print(f"First false equation with length 5 in the training dataset:") | |||
| for i, x in enumerate(false_train_equation_with_length_5[0]): | |||
| plt.subplot(1, 5, i+1) | |||
| plt.axis('off') | |||
| plt.imshow(x.squeeze(), cmap='gray') | |||
| plt.show() | |||
| print(f"First false equation with length 8 in the training dataset:") | |||
| for i, x in enumerate(false_train_equation_with_length_8[0]): | |||
| plt.subplot(1, 8, i+1) | |||
| plt.axis('off') | |||
| plt.imshow(x.squeeze(), cmap='gray') | |||
| plt.show() | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| First true equation with length 5 in the training dataset: | |||
| .. image:: ../_static/img/hed_dataset1.png | |||
| :width: 300px | |||
| .. code:: none | |||
| :class: code-out | |||
| First true equation with length 8 in the training dataset: | |||
| .. image:: ../_static/img/hed_dataset2.png | |||
| :width: 480px | |||
| .. code:: none | |||
| :class: code-out | |||
| First false equation with length 5 in the training dataset: | |||
| .. image:: ../_static/img/hed_dataset3.png | |||
| :width: 300px | |||
| .. code:: none | |||
| :class: code-out | |||
| First false equation with length 8 in the training dataset: | |||
| .. image:: ../_static/img/hed_dataset4.png | |||
| :width: 480px | |||
| Building the Learning Part | |||
| -------------------------- | |||
| To build the learning part, we need to first build a machine learning | |||
| base model. We use SymbolNet, and encapsulate it within a ``BasicNN`` | |||
| object to create the base model. ``BasicNN`` is a class that | |||
| encapsulates a PyTorch model, transforming it into a base model with an | |||
| sklearn-style interface. | |||
| .. code:: ipython3 | |||
| # class of symbol may be one of ['0', '1', '+', '='], total of 4 classes | |||
| cls = SymbolNet(num_classes=4) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4) | |||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| base_model = BasicNN( | |||
| cls, | |||
| loss_fn, | |||
| optimizer, | |||
| device=device, | |||
| batch_size=32, | |||
| num_epochs=1, | |||
| stop_loss=None, | |||
| ) | |||
| However, the base model built above deals with instance-level data | |||
| (i.e., individual images), and cannot directly deal with example-level | |||
| data (i.e., a list of images comprising the equation). Therefore, we | |||
| wrap the base model into ``ABLModel``, which enables the learning part | |||
| to train, test, and predict on example-level data. | |||
| .. code:: ipython3 | |||
| model = ABLModel(base_model) | |||
| Building the Reasoning Part | |||
| --------------------------- | |||
| In the reasoning part, we first build a knowledge base. As mentioned | |||
| before, the knowledge base in this task involves the structure of the | |||
| equations and a recursive definition of bit-wise operations, which are | |||
| defined in Prolog file ``examples/hed/reasoning/BK.pl`` | |||
| and ``examples/hed/reasoning/learn_add.pl``, respectively. | |||
| Specifically, the knowledge about the structure of equations is a set of DCG | |||
| rules, which recursively define that a digit is a sequence of ‘0’ and ‘1’, | |||
| and that equations share the structure X+Y=Z, though the lengths of X, Y and Z | |||
| may vary. The knowledge about bit-wise operations is a recursive | |||
| logic program, which reversely calculates X+Y, i.e., it operates on | |||
| X and Y digit-by-digit, from the last digit to the first. | |||
| The knowledge base is already built in ``HedKB``. | |||
| ``HedKB`` is derived from class ``PrologKB``, and is built upon the aforementioned Prolog | |||
| files. | |||
| .. code:: ipython3 | |||
| kb = HedKB() | |||
| .. note:: | |||
| Please note that the specific rules for calculating the | |||
| operations are undefined in the knowledge base, i.e., the results of ‘0+0’, | |||
| ‘0+1’ and ‘1+1’ could be ‘0’, ‘1’, ‘00’, ‘01’ or even ‘10’. The missing | |||
| calculation rules are required to be learned from the data. Therefore, | |||
| ``HedKB`` incorporates methods for abducing rules from data. Interested | |||
| users can refer to the specific implementation of ``HedKB`` in | |||
| ``examples/hed/reasoning/reasoning.py``. | |||
| Then, we create a reasoner. Due to the indeterminism of abductive | |||
| reasoning, there could be multiple candidates compatible with the | |||
| knowledge base. When this happens, the reasoner minimizes the inconsistency | |||
| between the knowledge base and the pseudo-labels predicted by the learning | |||
| part, and then returns only the candidate with the highest | |||
| consistency. | |||
| In this task, we create the reasoner by instantiating the class | |||
| ``HedReasoner``, which is a reasoner derived from ``Reasoner`` and | |||
| tailored specifically for this task. ``HedReasoner`` leverages `ZOOpt | |||
| library <https://github.com/polixir/ZOOpt>`__ for acceleration, and | |||
| employs a specific strategy to better harness ZOOpt’s capabilities. | |||
| Additionally, methods for abducing rules from data have been | |||
| incorporated. Interested users can refer to the specific implementation | |||
| of ``HedReasoner`` in ``examples/hed/reasoning/reasoning.py``. | |||
| .. code:: ipython3 | |||
| reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10) | |||
| Building Evaluation Metrics | |||
| --------------------------- | |||
| Next, we set up evaluation metrics. These metrics will be used to | |||
| evaluate the model performance during training and testing. | |||
| Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are | |||
| used to evaluate the accuracy of the machine learning model’s | |||
| predictions and the accuracy of the final reasoning results, | |||
| respectively. | |||
| .. code:: ipython3 | |||
| # Set up metrics | |||
| metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")] | |||
| Bridge Learning and Reasoning | |||
| ----------------------------- | |||
| Now, the last step is to bridge the learning and reasoning parts. We | |||
| do this by creating an instance of ``HedBridge``, which is | |||
| derived from ``SimpleBridge`` and tailored specifically for this task. | |||
| .. code:: ipython3 | |||
| bridge = HedBridge(model, reasoner, metric_list) | |||
| Perform pretraining, training and testing by invoking the ``pretrain``, ``train`` and ``test`` methods of ``HedBridge``. | |||
| .. code:: ipython3 | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| # Retrieve the directory of the log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| bridge.pretrain("./weights") | |||
| bridge.train(train_data, val_data, save_dir=weights_dir) | |||
| bridge.test(test_data) | |||
| @@ -0,0 +1,472 @@ | |||
| Handwritten Formula (HWF) | |||
| ========================= | |||
| .. raw:: html | |||
| <p>For detailed code implementation, please view on <a class="reference external" href="https://github.com/AbductiveLearning/ABL-Package/tree/Dev/examples/hwf" target="_blank">GitHub</a>.</p> | |||
| Below shows an implementation of `Handwritten | |||
| Formula <https://arxiv.org/abs/2006.06649>`__. In this | |||
| task, handwritten images of decimal formulas and their computed results | |||
| are given, along with a domain knowledge base containing information on | |||
| how to compute the decimal formula. The task is to recognize the symbols | |||
| (which can be digits or operators ‘+’, ‘-’, ‘×’, ‘÷’) in the handwritten | |||
| images and accurately determine their results. | |||
| Intuitively, we first use a machine learning model (learning part) to | |||
| convert the input images to symbols (we call them pseudo-labels), and | |||
| then use the knowledge base (reasoning part) to calculate the results of | |||
| these symbols. Since we do not have ground-truth of the symbols, in | |||
| Abductive Learning, the reasoning part will leverage domain knowledge | |||
| and revise the initial symbols yielded by the learning part through | |||
| abductive reasoning. This process enables us to further update the | |||
| machine learning model. | |||
| .. code:: ipython3 | |||
| # Import necessary libraries and modules | |||
| import os.path as osp | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from abl.bridge import SimpleBridge | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.reasoning import KBBase, Reasoner | |||
| from abl.utils import ABLLogger, print_log | |||
| from datasets import get_dataset | |||
| from models.nn import SymbolNet | |||
| Working with Data | |||
| ----------------- | |||
| First, we get the training and testing datasets: | |||
| .. code:: ipython3 | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| Both ``train_data`` and ``test_data`` have the same structure: a tuple | |||
| with three components: X (list where each element is a list of images), | |||
| gt_pseudo_label (list where each element is a list of symbols, i.e., | |||
| pseudo-labels) and Y (list where each element is the computed result). | |||
| The lengths and structures of the datasets are illustrated as follows. | |||
| .. note:: | |||
| ``gt_pseudo_label`` is only used to evaluate the performance of | |||
| the learning part but not to train the model. | |||
| .. code:: ipython3 | |||
| print(f"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y") | |||
| print() | |||
| train_X, train_gt_pseudo_label, train_Y = train_data | |||
| print(f"Length of X, gt_pseudo_label, Y in train_data: " + | |||
| f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}") | |||
| test_X, test_gt_pseudo_label, test_Y = test_data | |||
| print(f"Length of X, gt_pseudo_label, Y in test_data: " + | |||
| f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}") | |||
| print() | |||
| X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0] | |||
| print(f"X is a {type(train_X).__name__}, " + | |||
| f"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.") | |||
| print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " + | |||
| f"with each element being a {type(gt_pseudo_label_0).__name__} " + | |||
| f"of {type(gt_pseudo_label_0[0]).__name__}.") | |||
| print(f"Y is a {type(train_Y).__name__}, " + | |||
| f"with each element being a {type(Y_0).__name__}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y | |||
| Length of X, gt_pseudo_label, Y in train_data: 10000, 10000, 10000 | |||
| Length of X, gt_pseudo_label, Y in test_data: 2000, 2000, 2000 | |||
| X is a list, with each element being a list of Tensor. | |||
| gt_pseudo_label is a list, with each element being a list of str. | |||
| Y is a list, with each element being a int. | |||
| The ith element of X, gt_pseudo_label, and Y together constitute the ith | |||
| data example. Here we use two of them (the 1001st and the 3001st) as | |||
| illustrations: | |||
| .. code:: ipython3 | |||
| X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000] | |||
| print(f"X in the 1001st data example (a list of images):") | |||
| for i, x in enumerate(X_1000): | |||
| plt.subplot(1, len(X_1000), i+1) | |||
| plt.axis('off') | |||
| plt.imshow(x.squeeze(), cmap='gray') | |||
| plt.show() | |||
| print(f"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}") | |||
| print(f"Y in the 1001st data example (the computed result): {Y_1000}") | |||
| print() | |||
| X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000] | |||
| print(f"X in the 3001st data example (a list of images):") | |||
| for i, x in enumerate(X_3000): | |||
| plt.subplot(1, len(X_3000), i+1) | |||
| plt.axis('off') | |||
| plt.imshow(x.squeeze(), cmap='gray') | |||
| plt.show() | |||
| print(f"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}") | |||
| print(f"Y in the 3001st data example (the computed result): {Y_3000}") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| X in the 1001st data example (a list of images): | |||
| .. image:: ../_static/img/hwf_dataset1.png | |||
| :width: 210px | |||
| .. code:: none | |||
| :class: code-out | |||
| gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): ['5', '-', '3'] | |||
| Y in the 1001st data example (the computed result): 2 | |||
| .. code:: none | |||
| :class: code-out | |||
| X in the 3001st data example (a list of images): | |||
| .. image:: ../_static/img/hwf_dataset2.png | |||
| :width: 350px | |||
| .. code:: none | |||
| :class: code-out | |||
| gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): ['4', '/', '6', '*', '5'] | |||
| Y in the 3001st data example (the computed result): 3.333333333333333 | |||
| .. note:: | |||
| The symbols in the HWF dataset can be digits or operators | |||
| '+', '-', '×', '÷'. | |||
| We may see that, in the 1001st data example, the length of the | |||
| formula is 3, while in the 3001st data example, the length of the | |||
| formula is 5. In the HWF dataset, the length of the formula varies from | |||
| 1 to 7. | |||
| Building the Learning Part | |||
| -------------------------- | |||
| To build the learning part, we need to first build a machine learning | |||
| base model. We use SymbolNet, and encapsulate it within a ``BasicNN`` | |||
| object to create the base model. ``BasicNN`` is a class that | |||
| encapsulates a PyTorch model, transforming it into a base model with an | |||
| sklearn-style interface. | |||
| .. code:: ipython3 | |||
| # class of symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], total of 13 classes | |||
| cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| base_model = BasicNN( | |||
| model=cls, | |||
| loss_fn=loss_fn, | |||
| optimizer=optimizer, | |||
| device=device, | |||
| batch_size=128, | |||
| num_epochs=3, | |||
| ) | |||
| ``BasicNN`` offers methods like ``predict`` and ``predict_prob``, which | |||
| are used to predict the class index and the probabilities of each class | |||
| for images. As shown below: | |||
| .. code:: ipython3 | |||
| data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)] | |||
| pred_idx = base_model.predict(X=data_instances) | |||
| print(f"Predicted class index for a batch of 32 instances: " + | |||
| f"{type(pred_idx).__name__} with shape {pred_idx.shape}") | |||
| pred_prob = base_model.predict_proba(X=data_instances) | |||
| print(f"Predicted class probabilities for a batch of 32 instances: " + | |||
| f"{type(pred_prob).__name__} with shape {pred_prob.shape}") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Predicted class index for a batch of 32 instances: ndarray with shape (32,) | |||
| Predicted class probabilities for a batch of 32 instances: ndarray with shape (32, 13) | |||
| However, the base model built above deals with instance-level data | |||
| (i.e., individual images), and cannot directly deal with example-level | |||
| data (i.e., a list of images comprising the formula). Therefore, we wrap | |||
| the base model into ``ABLModel``, which enables the learning part to | |||
| train, test, and predict on example-level data. | |||
| .. code:: ipython3 | |||
| model = ABLModel(base_model) | |||
| As an illustration, consider this example of predicting on example-level | |||
| data using the ``predict`` method in ``ABLModel``. In this process, the | |||
| method accepts data examples as input and outputs the class labels and | |||
| the probabilities of each class for all instances within these data | |||
| examples. | |||
| .. code:: ipython3 | |||
| from abl.data.structures import ListData | |||
| # ListData is a data structure provided by ABL-Package that can be used to organize data examples | |||
| data_examples = ListData() | |||
| # We use the 1001st and 3001st data examples in the training set as an illustration | |||
| data_examples.X = [X_1000, X_3000] | |||
| data_examples.gt_pseudo_label = [gt_pseudo_label_1000, gt_pseudo_label_3000] | |||
| data_examples.Y = [Y_1000, Y_3000] | |||
| # Perform prediction on the two data examples | |||
| # Remind that, in the 1001st data example, the length of the formula is 3, | |||
| # while in the 3001st data example, the length of the formula is 5. | |||
| prediction = model.predict(data_examples) | |||
| pred_label, pred_prob = prediction['label'], prediction['prob'] | |||
| print(f"Predicted class labels for the two data examples: a list of length {len(pred_label)}, \n" + | |||
| f"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, "+ | |||
| f"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\n") | |||
| print(f"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \n" | |||
| f"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, " + | |||
| f"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Predicted class labels for the two data examples: a list of length 2, | |||
| the first element is a ndarray of shape (3,), and the second element is a ndarray of shape (5,). | |||
| Predicted class probabilities for the two data examples: a list of length 2, | |||
| the first element is a ndarray of shape (3, 13), and the second element is a ndarray of shape (5, 13). | |||
| Building the Reasoning Part | |||
| --------------------------- | |||
| In the reasoning part, we first build a knowledge base which contains | |||
| information on how to compute a formula. We build it by | |||
| creating a subclass of ``KBBase``. In the derived subclass, we | |||
| initialize the ``pseudo_label_list`` parameter, specifying the list of | |||
| possible pseudo-labels, and override the ``logic_forward`` function, | |||
| defining how to perform (deductive) reasoning. | |||
| .. code:: ipython3 | |||
| class HwfKB(KBBase): | |||
| def __init__(self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"]): | |||
| super().__init__(pseudo_label_list) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: | |||
| return False | |||
| return True | |||
| # Implement the deduction function | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return np.inf | |||
| return eval("".join(formula)) | |||
| kb = HwfKB() | |||
| The knowledge base can perform logical reasoning (both deductive | |||
| reasoning and abductive reasoning). Below is an example of performing | |||
| (deductive) reasoning, and users can refer to :ref:`Performing abductive | |||
| reasoning in the knowledge base <kb-abd>` for details of abductive reasoning. | |||
| .. code:: ipython3 | |||
| pseudo_labels = ["1", "-", "2", "*", "5"] | |||
| reasoning_result = kb.logic_forward(pseudo_labels) | |||
| print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Reasoning result of pseudo-labels ['1', '-', '2', '*', '5'] is -9. | |||
| .. note:: | |||
| In addition to building a knowledge base based on ``KBBase``, we | |||
| can also establish a knowledge base with a ground KB using ``GroundKB``. | |||
| The corresponding code can be found in the ``examples/hwf/main.py`` file. Those | |||
| interested are encouraged to examine it for further insights. | |||
| Also, when building the knowledge base, we can set the | |||
| ``max_err`` parameter during initialization, as shown in the | |||
| ``examples/hwf/main.py`` file. This parameter specifies the upper tolerance | |||
| limit on the difference between the reasoning result of pseudo-labels and | |||
| the ground truth during abductive reasoning, with a default | |||
| value of 1e-10. A minimal sketch of setting it is given below the note. | |||
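| Below is a minimal sketch of setting ``max_err``. Note that ``TolerantHwfKB`` is a | |||
| hypothetical name used only for illustration, and we assume, per the note above, | |||
| that ``KBBase.__init__`` accepts the ``max_err`` keyword; see ``examples/hwf/main.py`` | |||
| for the actual usage. | |||
| .. code:: ipython3 | |||
| class TolerantHwfKB(HwfKB): | |||
|     def __init__(self, max_err=1e-3): | |||
|         # Assumption: KBBase.__init__ accepts max_err, the tolerance used when | |||
|         # comparing reasoning results with the ground truth during abduction. | |||
|         KBBase.__init__( | |||
|             self, | |||
|             pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
|             max_err=max_err, | |||
|         ) | |||
| kb_tolerant = TolerantHwfKB() | |||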
| Then, we create a reasoner by instantiating the class ``Reasoner``. Due | |||
| to the indeterminism of abductive reasoning, there could be multiple | |||
| candidates compatible with the knowledge base. When this happens, the reasoner | |||
| minimizes the inconsistency between the knowledge base and the | |||
| pseudo-labels predicted by the learning part, and then returns only the | |||
| candidate with the highest consistency. | |||
| .. code:: ipython3 | |||
| reasoner = Reasoner(kb) | |||
| .. note:: | |||
| When creating the reasoner, the definition of “consistency” can be | |||
| customized via the ``dist_func`` parameter. In the code above, we | |||
| employ a consistency measurement based on confidence, which calculates | |||
| the consistency between the data example and candidates based on the | |||
| confidence derived from the predicted probability. In ``examples/hwf/main.py``, we | |||
| provide options for utilizing other forms of consistency measurement. | |||
| Also, during the process of inconsistency minimization, we can | |||
| leverage `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for | |||
| acceleration. Options for this are also available in ``examples/hwf/main.py``. Those | |||
| interested are encouraged to explore these features. | |||
| Building Evaluation Metrics | |||
| --------------------------- | |||
| Next, we set up evaluation metrics. These metrics will be used to | |||
| evaluate the model performance during training and testing. | |||
| Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are | |||
| used to evaluate the accuracy of the machine learning model’s | |||
| predictions and the accuracy of the final reasoning results, | |||
| respectively. | |||
| .. code:: ipython3 | |||
| metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] | |||
| Bridge Learning and Reasoning | |||
| ----------------------------- | |||
| Now, the last step is to bridge the learning and reasoning parts. We | |||
| do this by creating an instance of ``SimpleBridge``. | |||
| .. code:: ipython3 | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| Perform training and testing by invoking the ``train`` and ``test`` | |||
| methods of ``SimpleBridge``. | |||
| .. code:: ipython3 | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| bridge.train(train_data, train_data, loops=3, segment_size=1000, save_dir=weights_dir) | |||
| bridge.test(test_data) | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| abl - INFO - Abductive Learning on the HWF example. | |||
| abl - INFO - loop(train) [1/3] segment(train) [1/10] | |||
| abl - INFO - model loss: 0.00024 | |||
| abl - INFO - loop(train) [1/3] segment(train) [2/10] | |||
| abl - INFO - model loss: 0.00053 | |||
| abl - INFO - loop(train) [1/3] segment(train) [3/10] | |||
| abl - INFO - model loss: 0.00260 | |||
| abl - INFO - loop(train) [1/3] segment(train) [4/10] | |||
| abl - INFO - model loss: 0.00162 | |||
| abl - INFO - loop(train) [1/3] segment(train) [5/10] | |||
| abl - INFO - model loss: 0.00073 | |||
| abl - INFO - loop(train) [1/3] segment(train) [6/10] | |||
| abl - INFO - model loss: 0.00055 | |||
| abl - INFO - loop(train) [1/3] segment(train) [7/10] | |||
| abl - INFO - model loss: 0.00148 | |||
| abl - INFO - loop(train) [1/3] segment(train) [8/10] | |||
| abl - INFO - model loss: 0.00034 | |||
| abl - INFO - loop(train) [1/3] segment(train) [9/10] | |||
| abl - INFO - model loss: 0.00167 | |||
| abl - INFO - loop(train) [1/3] segment(train) [10/10] | |||
| abl - INFO - model loss: 0.00185 | |||
| abl - INFO - Evaluation start: loop(val) [1] | |||
| abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 0.999 | |||
| abl - INFO - Saving model: loop(save) [1] | |||
| abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth | |||
| abl - INFO - loop(train) [2/3] segment(train) [1/10] | |||
| abl - INFO - model loss: 0.00219 | |||
| abl - INFO - loop(train) [2/3] segment(train) [2/10] | |||
| abl - INFO - model loss: 0.00069 | |||
| abl - INFO - loop(train) [2/3] segment(train) [3/10] | |||
| abl - INFO - model loss: 0.00013 | |||
| abl - INFO - loop(train) [2/3] segment(train) [4/10] | |||
| abl - INFO - model loss: 0.00013 | |||
| abl - INFO - loop(train) [2/3] segment(train) [5/10] | |||
| abl - INFO - model loss: 0.00248 | |||
| abl - INFO - loop(train) [2/3] segment(train) [6/10] | |||
| abl - INFO - model loss: 0.00010 | |||
| abl - INFO - loop(train) [2/3] segment(train) [7/10] | |||
| abl - INFO - model loss: 0.00020 | |||
| abl - INFO - loop(train) [2/3] segment(train) [8/10] | |||
| abl - INFO - model loss: 0.00076 | |||
| abl - INFO - loop(train) [2/3] segment(train) [9/10] | |||
| abl - INFO - model loss: 0.00061 | |||
| abl - INFO - loop(train) [2/3] segment(train) [10/10] | |||
| abl - INFO - model loss: 0.00117 | |||
| abl - INFO - Evaluation start: loop(val) [2] | |||
| abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 1.000 | |||
| abl - INFO - Saving model: loop(save) [2] | |||
| abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth | |||
| abl - INFO - loop(train) [3/3] segment(train) [1/10] | |||
| abl - INFO - model loss: 0.00120 | |||
| abl - INFO - loop(train) [3/3] segment(train) [2/10] | |||
| abl - INFO - model loss: 0.00114 | |||
| abl - INFO - loop(train) [3/3] segment(train) [3/10] | |||
| abl - INFO - model loss: 0.00071 | |||
| abl - INFO - loop(train) [3/3] segment(train) [4/10] | |||
| abl - INFO - model loss: 0.00027 | |||
| abl - INFO - loop(train) [3/3] segment(train) [5/10] | |||
| abl - INFO - model loss: 0.00017 | |||
| abl - INFO - loop(train) [3/3] segment(train) [6/10] | |||
| abl - INFO - model loss: 0.00018 | |||
| abl - INFO - loop(train) [3/3] segment(train) [7/10] | |||
| abl - INFO - model loss: 0.00141 | |||
| abl - INFO - loop(train) [3/3] segment(train) [8/10] | |||
| abl - INFO - model loss: 0.00099 | |||
| abl - INFO - loop(train) [3/3] segment(train) [9/10] | |||
| abl - INFO - model loss: 0.00145 | |||
| abl - INFO - loop(train) [3/3] segment(train) [10/10] | |||
| abl - INFO - model loss: 0.00215 | |||
| abl - INFO - Evaluation start: loop(val) [3] | |||
| abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 1.000 | |||
| abl - INFO - Saving model: loop(save) [3] | |||
| abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth | |||
| abl - INFO - Evaluation ended, hwf/character_accuracy: 0.996 hwf/reasoning_accuracy: 0.977 | |||
| @@ -0,0 +1,381 @@ | |||
| MNIST Addition | |||
| ============== | |||
| .. raw:: html | |||
| <p>For detailed code implementation, please view on <a class="reference external" href="https://github.com/AbductiveLearning/ABL-Package/tree/Dev/examples/mnist_add" target="_blank">GitHub</a>.</p> | |||
| Below is an implementation of `MNIST | |||
| Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of | |||
| MNIST handwritten images and their sums are given, along with a domain | |||
| knowledge base containing information on how to perform addition | |||
| operations. The task is to recognize the digits of handwritten images | |||
| and accurately determine their sum. | |||
| Intuitively, we first use a machine learning model (learning part) to | |||
| convert the input images to digits (we call them pseudo-labels), and | |||
| then use the knowledge base (reasoning part) to calculate the sum of | |||
| these digits. Since we do not have the ground truth of the digits, in | |||
| Abductive Learning, the reasoning part will leverage domain knowledge | |||
| and revise the initial digits yielded by the learning part through | |||
| abductive reasoning. This process enables us to further update the | |||
| machine learning model. | |||
| .. code:: ipython3 | |||
| # Import necessary libraries and modules | |||
| import os.path as osp | |||
| import matplotlib.pyplot as plt | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch.optim import RMSprop, lr_scheduler | |||
| from abl.bridge import SimpleBridge | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.reasoning import KBBase, Reasoner | |||
| from abl.utils import ABLLogger, print_log | |||
| from datasets import get_dataset | |||
| from models.nn import LeNet5 | |||
| Working with Data | |||
| ----------------- | |||
| First, we get the training and testing datasets: | |||
| .. code:: ipython3 | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ``train_data`` and ``test_data`` share identical structures: | |||
| tuples with three components: X (list where each element is a | |||
| list of two images), gt_pseudo_label (list where each element | |||
| is a list of two digits, i.e., pseudo-labels) and Y (list where | |||
| each element is the sum of the two digits). The lengths and structures | |||
| of the datasets are illustrated as follows. | |||
| .. note:: | |||
| ``gt_pseudo_label`` is only used to evaluate the performance of | |||
| the learning part but not to train the model. | |||
| .. code:: ipython3 | |||
| print(f"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y") | |||
| print("\n") | |||
| train_X, train_gt_pseudo_label, train_Y = train_data | |||
| print(f"Length of X, gt_pseudo_label, Y in train_data: " + | |||
| f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}") | |||
| test_X, test_gt_pseudo_label, test_Y = test_data | |||
| print(f"Length of X, gt_pseudo_label, Y in test_data: " + | |||
| f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}") | |||
| print("\n") | |||
| X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0] | |||
| print(f"X is a {type(train_X).__name__}, " + | |||
| f"with each element being a {type(X_0).__name__} " + | |||
| f"of {len(X_0)} {type(X_0[0]).__name__}.") | |||
| print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " + | |||
| f"with each element being a {type(gt_pseudo_label_0).__name__} " + | |||
| f"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.") | |||
| print(f"Y is a {type(train_Y).__name__}, " + | |||
| f"with each element being a {type(Y_0).__name__}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y | |||
| Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000 | |||
| Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000 | |||
| X is a list, with each element being a list of 2 Tensor. | |||
| gt_pseudo_label is a list, with each element being a list of 2 int. | |||
| Y is a list, with each element being a int. | |||
| The ith element of X, gt_pseudo_label, and Y together constitute the ith | |||
| data example. As an illustration, in the first data example of the | |||
| training set, we have: | |||
| .. code:: ipython3 | |||
| X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0] | |||
| print(f"X in the first data example (a list of two images):") | |||
| plt.subplot(1,2,1) | |||
| plt.axis('off') | |||
| plt.imshow(X_0[0].squeeze(), cmap='gray') | |||
| plt.subplot(1,2,2) | |||
| plt.axis('off') | |||
| plt.imshow(X_0[1].squeeze(), cmap='gray') | |||
| plt.show() | |||
| print(f"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}") | |||
| print(f"Y in the first data example (their sum result): {Y_0}") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| X in the first data example (a list of two images): | |||
| .. image:: ../_static/img/mnist_add_datasets.png | |||
| :width: 200px | |||
| .. code:: none | |||
| :class: code-out | |||
| gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5] | |||
| Y in the first data example (their sum result): 12 | |||
| Building the Learning Part | |||
| -------------------------- | |||
| To build the learning part, we need to first build a machine learning | |||
| base model. We use a simple `LeNet-5 neural | |||
| network <https://en.wikipedia.org/wiki/LeNet>`__, and encapsulate it | |||
| within a ``BasicNN`` object to create the base model. ``BasicNN`` is a | |||
| class that encapsulates a PyTorch model, transforming it into a base | |||
| model with a scikit-learn-style interface. | |||
| .. code:: ipython3 | |||
| cls = LeNet5(num_classes=10) | |||
| loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1) | |||
| optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100) | |||
| base_model = BasicNN( | |||
| cls, | |||
| loss_fn, | |||
| optimizer, | |||
| scheduler=scheduler, | |||
| device=device, | |||
| batch_size=32, | |||
| num_epochs=1, | |||
| ) | |||
| ``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which | |||
| are used to predict the class index and the probabilities of each class | |||
| for images, as shown below: | |||
| .. code:: ipython3 | |||
| data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)] | |||
| pred_idx = base_model.predict(X=data_instances) | |||
| print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}") | |||
| pred_prob = base_model.predict_proba(X=data_instances) | |||
| print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Predicted class index for a batch of 32 instances: np.ndarray with shape (32,) | |||
| Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10) | |||
| However, the base model built above deals with instance-level data | |||
| (i.e., individual images), and cannot directly deal with example-level | |||
| data (i.e., a pair of images). Therefore, we wrap the base model into | |||
| ``ABLModel``, which enables the learning part to train, test, and | |||
| predict on example-level data. | |||
| .. code:: ipython3 | |||
| model = ABLModel(base_model) | |||
| As an illustration, consider this example of making predictions on | |||
| example-level data using the ``predict`` method in ``ABLModel``. In this | |||
| process, the method accepts data examples as input and outputs the class | |||
| labels and the probabilities of each class for all instances within | |||
| these data examples. | |||
| .. code:: ipython3 | |||
| from abl.data.structures import ListData | |||
| # ListData is a data structure provided by ABL-Package that can be used to organize data examples | |||
| data_examples = ListData() | |||
| # We use the first 100 data examples in the training set as an illustration | |||
| data_examples.X = train_X[:100] | |||
| data_examples.gt_pseudo_label = train_gt_pseudo_label[:100] | |||
| data_examples.Y = train_Y[:100] | |||
| # Perform prediction on the 100 data examples | |||
| prediction = model.predict(data_examples) | |||
| pred_label, pred_prob = prediction["label"], prediction["prob"] | |||
| print(f"Predicted class labels for the 100 data examples: \n" + | |||
| f"a list of length {len(pred_label)}, and each element is " + | |||
| f"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\n") | |||
| print(f"Predicted class probabilities for the 100 data examples: \n" + | |||
| f"a list of length {len(pred_prob)}, and each element is " + | |||
| f"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Predicted class labels for the 100 data examples: | |||
| a list of length 100, and each element is a ndarray of shape (2,). | |||
| Predicted class probabilities for the 100 data examples: | |||
| a list of length 100, and each element is a ndarray of shape (2, 10). | |||
| Building the Reasoning Part | |||
| --------------------------- | |||
| In the reasoning part, we first build a knowledge base which contains | |||
| information on how to perform addition operations. We build it by | |||
| creating a subclass of ``KBBase``. In the derived subclass, we | |||
| initialize the ``pseudo_label_list`` parameter, specifying the list of | |||
| possible pseudo-labels, and override the ``logic_forward`` function, | |||
| which defines how to perform (deductive) reasoning. | |||
| .. code:: ipython3 | |||
| class AddKB(KBBase): | |||
|     def __init__(self, pseudo_label_list=list(range(10))): | |||
|         super().__init__(pseudo_label_list) | |||
|     # Implement the deduction function | |||
|     def logic_forward(self, nums): | |||
|         return sum(nums) | |||
| kb = AddKB() | |||
| The knowledge base can perform logical reasoning (both deductive | |||
| reasoning and abductive reasoning). Below is an example of performing | |||
| (deductive) reasoning, and users can refer to :ref:`Performing abductive | |||
| reasoning in the knowledge base <kb-abd>` for details of abductive reasoning. | |||
| .. code:: ipython3 | |||
| pseudo_labels = [1, 2] | |||
| reasoning_result = kb.logic_forward(pseudo_labels) | |||
| print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Reasoning result of pseudo-labels [1, 2] is 3. | |||
| .. note:: | |||
| In addition to building a knowledge base based on ``KBBase``, we | |||
| can also establish a knowledge base with a ground KB using ``GroundKB``, | |||
| or a knowledge base implemented based on Prolog files using | |||
| ``PrologKB``. The corresponding code for these implementations can be | |||
| found in the ``examples/mnist_add/main.py`` file. Those interested are encouraged to | |||
| examine it for further insights. | |||
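| For instance, a Prolog-backed knowledge base might be created as follows | |||
| (a sketch only: the ``pl_file`` argument and the ``add.pl`` file, which is | |||
| assumed to encode the addition logic in Prolog, are both assumptions here): | |||
| .. code:: ipython3 | |||
| from abl.reasoning import PrologKB | |||
| # Hypothetical usage; see examples/mnist_add/main.py for the actual code | |||
| prolog_kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||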
| Then, we create a reasoner by instantiating the class ``Reasoner``. Due | |||
| to the indeterminism of abductive reasoning, there could be multiple | |||
| candidates compatible with the knowledge base. When this happens, the | |||
| reasoner can minimize inconsistencies between the knowledge base and | |||
| pseudo-labels predicted by the learning part, and then return only the | |||
| candidate that has the highest consistency. | |||
| .. code:: ipython3 | |||
| reasoner = Reasoner(kb) | |||
| .. note:: | |||
| When creating the reasoner, the definition of “consistency” can be | |||
| customized through the ``dist_func`` parameter. In the code above, we | |||
| employ a consistency measurement based on confidence, which calculates | |||
| the consistency between the data example and candidates based on the | |||
| confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we | |||
| provide options for utilizing other forms of consistency measurement. | |||
| Also, during the process of inconsistency minimization, we can leverage | |||
| the `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration. | |||
| Options for this are also available in ``examples/mnist_add/main.py``. Those interested are | |||
| encouraged to explore these features. | |||
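| As an illustration, a custom consistency measure can also be supplied as a | |||
| callable. The sketch below mirrors the four-argument signature used in the | |||
| Zoo example later in this documentation, and its use here is an assumption: | |||
| .. code:: ipython3 | |||
| from abl.utils import confidence_dist | |||
| # A sketch of a custom measure: smaller returned distances indicate | |||
| # higher consistency between predicted probabilities and a candidate | |||
| def confidence_consistency(data_example, candidates, candidate_idxs, reasoning_results): | |||
|     return confidence_dist(data_example.pred_prob, candidate_idxs) | |||
| reasoner = Reasoner(kb, dist_func=confidence_consistency) | |||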
| Building Evaluation Metrics | |||
| --------------------------- | |||
| Next, we set up evaluation metrics. These metrics will be used to | |||
| evaluate the model performance during training and testing. | |||
| Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are | |||
| used to evaluate the accuracy of the machine learning model’s | |||
| predictions and the accuracy of the final reasoning results, | |||
| respectively. | |||
| .. code:: ipython3 | |||
| metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
| Bridge Learning and Reasoning | |||
| ----------------------------- | |||
| Now, the last step is to bridge the learning and reasoning parts. We | |||
| proceed with this step by creating an instance of ``SimpleBridge``. | |||
| .. code:: ipython3 | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| Perform training and testing by invoking the ``train`` and ``test`` | |||
| methods of ``SimpleBridge``. | |||
| .. code:: ipython3 | |||
| # Build logger | |||
| print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir) | |||
| bridge.test(test_data) | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| abl - INFO - Abductive Learning on the MNIST Addition example. | |||
| abl - INFO - loop(train) [1/1] segment(train) [1/100] | |||
| abl - INFO - model loss: 2.23587 | |||
| abl - INFO - loop(train) [1/1] segment(train) [2/100] | |||
| abl - INFO - model loss: 2.23756 | |||
| abl - INFO - loop(train) [1/1] segment(train) [3/100] | |||
| abl - INFO - model loss: 2.04475 | |||
| abl - INFO - loop(train) [1/1] segment(train) [4/100] | |||
| abl - INFO - model loss: 2.01035 | |||
| abl - INFO - loop(train) [1/1] segment(train) [5/100] | |||
| abl - INFO - model loss: 1.97584 | |||
| abl - INFO - loop(train) [1/1] segment(train) [6/100] | |||
| abl - INFO - model loss: 1.91570 | |||
| abl - INFO - loop(train) [1/1] segment(train) [7/100] | |||
| abl - INFO - model loss: 1.90268 | |||
| abl - INFO - loop(train) [1/1] segment(train) [8/100] | |||
| abl - INFO - model loss: 1.77436 | |||
| abl - INFO - loop(train) [1/1] segment(train) [9/100] | |||
| abl - INFO - model loss: 1.73454 | |||
| abl - INFO - loop(train) [1/1] segment(train) [10/100] | |||
| abl - INFO - model loss: 1.62495 | |||
| abl - INFO - loop(train) [1/1] segment(train) [11/100] | |||
| abl - INFO - model loss: 1.58456 | |||
| abl - INFO - loop(train) [1/1] segment(train) [12/100] | |||
| abl - INFO - model loss: 1.62575 | |||
| ... | |||
| abl - INFO - Eval start: loop(val) [1] | |||
| abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973 | |||
| abl - INFO - Saving model: loop(save) [1] | |||
| abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_1.pth | |||
| abl - INFO - Test start: | |||
| abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967 | |||
| @@ -0,0 +1,255 @@ | |||
| Zoo | |||
| === | |||
| .. raw:: html | |||
| <p>For detailed code implementation, please view on <a class="reference external" href="https://github.com/AbductiveLearning/ABL-Package/tree/Dev/examples/zoo" target="_blank">GitHub</a>.</p> | |||
| Below is an implementation on the | |||
| `Zoo <https://archive.ics.uci.edu/dataset/111/zoo>`__ dataset. In this task, | |||
| attributes of animals (such as presence of hair, eggs, etc.) and their | |||
| targets (the animal class they belong to) are given, along with a | |||
| knowledge base which contains information about the relations between | |||
| attributes and targets, e.g., Implies(milk == 1, mammal == 1). | |||
| The goal of this task is to develop a learning model that can predict | |||
| the targets of animals based on their attributes. In the initial stages, | |||
| when the model is under-trained, it may produce incorrect predictions | |||
| that conflict with the relations contained in the knowledge base. When | |||
| this happens, abductive reasoning can be employed to adjust these | |||
| results and retrain the model accordingly. This process enables us to | |||
| further update the learning model. | |||
| .. code:: ipython3 | |||
| # Import necessary libraries and modules | |||
| import os.path as osp | |||
| import numpy as np | |||
| from sklearn.ensemble import RandomForestClassifier | |||
| from abl.bridge import SimpleBridge | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from abl.learning import ABLModel | |||
| from abl.reasoning import Reasoner | |||
| from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple | |||
| from get_dataset import load_and_preprocess_dataset, split_dataset | |||
| from kb import ZooKB | |||
| Working with Data | |||
| ----------------- | |||
| First, we load and preprocess the `Zoo | |||
| dataset <https://archive.ics.uci.edu/dataset/111/zoo>`__, and split it | |||
| into labeled/unlabeled/test data: | |||
| .. code:: ipython3 | |||
| X, y = load_and_preprocess_dataset(dataset_id=62) | |||
| X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
| The Zoo dataset consists of tabular data. The attributes contain 16 | |||
| values (15 boolean ones, e.g., hair, feathers, eggs, milk, airborne, | |||
| aquatic, etc., plus the numeric attribute legs) and the target is an | |||
| integer value in the range [0,6] representing 7 classes (e.g., mammal, | |||
| bird, reptile, fish, amphibian, insect, and other). Below is an | |||
| illustration: | |||
| .. code:: ipython3 | |||
| print("Shape of X and y:", X.shape, y.shape) | |||
| print("First five elements of X:") | |||
| print(X[:5]) | |||
| print("First five elements of y:") | |||
| print(y[:5]) | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Shape of X and y: (101, 16) (101,) | |||
| First five elements of X: | |||
| [[True False False True False False True True True True False False 4 | |||
| False False True] | |||
| [True False False True False False False True True True False False 4 | |||
| True False True] | |||
| [False False True False False True True True True False False True 0 | |||
| True False False] | |||
| [True False False True False False True True True True False False 4 | |||
| False False True] | |||
| [True False False True False False True True True True False False 4 | |||
| True False True]] | |||
| First five elements of y: | |||
| [0 0 3 0 0] | |||
| Next, we transform the tabular data to the format required by | |||
| ABL-Package, which is a tuple of (X, gt_pseudo_label, Y). In this task, | |||
| we treat the attributes as X and the targets as gt_pseudo_label (ground | |||
| truth pseudo-labels). Y (reasoning results) are expected to be 0, | |||
| indicating no rules are violated. | |||
| .. code:: ipython3 | |||
| label_data = tab_data_to_tuple(X_label, y_label, reasoning_result=0) | |||
| test_data = tab_data_to_tuple(X_test, y_test, reasoning_result=0) | |||
| train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result=0) | |||
| Building the Learning Part | |||
| -------------------------- | |||
| To build the learning part, we need to first build a machine learning | |||
| base model. We use a `Random | |||
| Forest <https://en.wikipedia.org/wiki/Random_forest>`__ as the base | |||
| model. | |||
| .. code:: ipython3 | |||
| base_model = RandomForestClassifier() | |||
| However, the base model built above deals with instance-level data, and | |||
| cannot directly deal with example-level data. Therefore, we wrap the | |||
| base model into ``ABLModel``, which enables the learning part to train, | |||
| test, and predict on example-level data. | |||
| .. code:: ipython3 | |||
| model = ABLModel(base_model) | |||
| Building the Reasoning Part | |||
| --------------------------- | |||
| In the reasoning part, we first build a knowledge base which contains | |||
| information about the relations between attributes (X) and targets | |||
| (pseudo-labels), e.g., Implies(milk == 1, mammal == 1). The knowledge | |||
| base is built in the ``ZooKB`` class within file ``examples/zoo/kb.py``, and is | |||
| derived from the ``KBBase`` class. | |||
| .. code:: ipython3 | |||
| kb = ZooKB() | |||
| As mentioned, for all attributes and targets in the dataset, the | |||
| reasoning results are expected to be 0, since there should be no | |||
| violations of the established knowledge in real data, as shown below: | |||
| .. code:: ipython3 | |||
| for idx, (x, y_item) in enumerate(zip(X[:5], y[:5])): | |||
|     print(f"Example {idx}: the attributes are: {x}, and the target is {y_item}.") | |||
|     print(f"Reasoning result is {kb.logic_forward([y_item], [x])}.") | |||
|     print() | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Example 0: the attributes are: [True False False True False False True True True True False False 4 False | |||
| False True], and the target is 0. | |||
| Reasoning result is 0. | |||
| Example 1: the attributes are: [True False False True False False False True True True False False 4 True | |||
| False True], and the target is 0. | |||
| Reasoning result is 0. | |||
| Example 2: the attributes are: [False False True False False True True True True False False True 0 True | |||
| False False], and the target is 3. | |||
| Reasoning result is 0. | |||
| Example 3: the attributes are: [True False False True False False True True True True False False 4 False | |||
| False True], and the target is 0. | |||
| Reasoning result is 0. | |||
| Example 4: the attributes are: [True False False True False False True True True True False False 4 True | |||
| False True], and the target is 0. | |||
| Reasoning result is 0. | |||
| Then, we create a reasoner by instantiating the class ``Reasoner``. Due | |||
| to the indeterminism of abductive reasoning, there could be multiple | |||
| candidates compatible with the knowledge base. When this happens, the | |||
| reasoner can minimize inconsistencies between the knowledge base and | |||
| pseudo-labels predicted by the learning part, and then return only the | |||
| candidate that has the highest consistency. | |||
| .. code:: ipython3 | |||
| def consistency(data_example, candidates, candidate_idxs, reasoning_results): | |||
|     pred_prob = data_example.pred_prob | |||
|     model_scores = confidence_dist(pred_prob, candidate_idxs) | |||
|     rule_scores = np.array(reasoning_results) | |||
|     scores = model_scores + rule_scores | |||
|     return scores | |||
| reasoner = Reasoner(kb, dist_func=consistency) | |||
| Building Evaluation Metrics | |||
| --------------------------- | |||
| Next, we set up evaluation metrics. These metrics will be used to | |||
| evaluate the model performance during training and testing. | |||
| Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which | |||
| are used to evaluate the accuracy of the machine learning model’s | |||
| predictions and the accuracy of the final reasoning results, | |||
| respectively. | |||
| .. code:: ipython3 | |||
| metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] | |||
| Bridging Learning and Reasoning | |||
| ------------------------------- | |||
| Now, the last step is to bridge the learning and reasoning parts. We | |||
| proceed with this step by creating an instance of ``SimpleBridge``. | |||
| .. code:: ipython3 | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| Perform training and testing by invoking the ``train`` and ``test`` | |||
| methods of ``SimpleBridge``. | |||
| .. code:: ipython3 | |||
| # Build logger | |||
| print_log("Abductive Learning on the Zoo example.", logger="current") | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| print_log("------- Use labeled data to pretrain the model -----------", logger="current") | |||
| base_model.fit(X_label, y_label) | |||
| print_log("------- Test the initial model -----------", logger="current") | |||
| bridge.test(test_data) | |||
| print_log("------- Use ABL to train the model -----------", logger="current") | |||
| bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir) | |||
| print_log("------- Test the final model -----------", logger="current") | |||
| bridge.test(test_data) | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| abl - INFO - Abductive Learning on the Zoo example. | |||
| abl - INFO - ------- Use labeled data to pretrain the model ----------- | |||
| abl - INFO - ------- Test the initial model ----------- | |||
| abl - INFO - Evaluation ended, zoo/character_accuracy: 0.903 zoo/reasoning_accuracy: 0.903 | |||
| abl - INFO - ------- Use ABL to train the model ----------- | |||
| abl - INFO - loop(train) [1/3] segment(train) [1/1] | |||
| abl - INFO - Evaluation start: loop(val) [1] | |||
| abl - INFO - Evaluation ended, zoo/character_accuracy: 1.000 zoo/reasoning_accuracy: 1.000 | |||
| abl - INFO - loop(train) [2/3] segment(train) [1/1] | |||
| abl - INFO - Evaluation start: loop(val) [2] | |||
| abl - INFO - Evaluation ended, zoo/character_accuracy: 1.000 zoo/reasoning_accuracy: 1.000 | |||
| abl - INFO - loop(train) [3/3] segment(train) [1/1] | |||
| abl - INFO - Evaluation start: loop(val) [3] | |||
| abl - INFO - Evaluation ended, zoo/character_accuracy: 1.000 zoo/reasoning_accuracy: 1.000 | |||
| abl - INFO - ------- Test the final model ----------- | |||
| abl - INFO - Evaluation ended, zoo/character_accuracy: 0.968 zoo/reasoning_accuracy: 0.968 | |||
| As we can see from the results, after training with ABL, the | |||
| model’s accuracy has improved. | |||
| @@ -0,0 +1,95 @@ | |||
| **Learn the Basics** || | |||
| `Quick Start <Quick-Start.html>`_ || | |||
| `Dataset & Data Structure <Datasets.html>`_ || | |||
| `Learning Part <Learning.html>`_ || | |||
| `Reasoning Part <Reasoning.html>`_ || | |||
| `Evaluation Metrics <Evaluation.html>`_ || | |||
| `Bridge <Bridge.html>`_ | |||
| Learn the Basics | |||
| ================ | |||
| Modules in ABL-Package | |||
| ---------------------- | |||
| ABL-Package is an efficient implementation of `Abductive Learning <../Overview/Abductive-Learning.html>`_ (ABL), | |||
| a paradigm which integrates machine learning and logical reasoning in a balanced loop. | |||
| The ABL-Package comprises three primary parts: **Data**, **Learning**, and | |||
| **Reasoning**, corresponding to the three pivotal components of current | |||
| AI: data, models, and knowledge. Below is an overview of the ABL-Package. | |||
| .. image:: ../_static/img/ABL-Package.png | |||
| **Data** part manages the storage, operation, and evaluation of data efficiently. | |||
| It includes the ``ListData`` class, which defines the data structures used in | |||
| ABL, and comprises common data operations like insertion, deletion, | |||
| retrieval, slicing, etc. Additionally, it contains a series of evaluation metrics | |||
| such as ``SymbolAccuracy`` and ``ReasoningMetric`` (both specialized metrics | |||
| inherited from the ``BaseMetric`` class), for evaluating model quality from a | |||
| data perspective. | |||
| :blue-bold:`Learning` part focuses on the construction, training, and | |||
| prediction of machine learning models. The ``ABLModel`` class is the | |||
| central class that encapsulates the machine learning model. This class is | |||
| compatible with various frameworks, including those based on Scikit-learn | |||
| or PyTorch neural networks constructed by the ``BasicNN`` class. | |||
| :green-bold:`Reasoning` part concentrates on constructing domain knowledge and | |||
| performing reasoning. The ``KBBase`` class allows users to define a | |||
| domain knowledge base. For diverse types of knowledge, we also offer | |||
| implementations like ``GroundKB`` and ``PrologKB`` (both inherited | |||
| from the ``KBBase`` class). The latter, for instance, enables | |||
| knowledge bases to be imported in the form of Prolog files. | |||
| Upon building the knowledge base, the ``Reasoner`` class is | |||
| responsible for minimizing the inconsistency between the knowledge base | |||
| and data. | |||
| The integration of these three parts is achieved through the | |||
| :yellow-bold:`Bridge` part, which features the ``SimpleBridge`` class (derived | |||
| from the ``BaseBridge`` class). The Bridge part synthesizes data, | |||
| learning, and reasoning, facilitating the training and testing | |||
| of the entire ABL framework. | |||
| Use ABL-Package Step by Step | |||
| ---------------------------- | |||
| In a typical ABL process, as illustrated below, | |||
| data inputs are first fed to the learning model, whose predictions (``ABLModel.predict``) are pseudo-labels. | |||
| These labels then pass through deductive reasoning over the domain knowledge base (``KBBase.logic_forward``) | |||
| to obtain the reasoning result. During training, | |||
| alongside the aforementioned forward flow (i.e., prediction --> deductive reasoning), | |||
| there also exists a reverse flow, which starts from the reasoning result and | |||
| involves abductive reasoning (``KBBase.abduce_candidates``) to generate possible revised pseudo-labels. | |||
| Subsequently, these pseudo-labels are processed to minimize inconsistencies with the learning part, | |||
| which in turn revises the outcomes of the learning model; they are then | |||
| fed back for further training (``ABLModel.train``). | |||
| .. image:: ../_static/img/usage.png | |||
| To implement this process, the following five steps are necessary: | |||
| 1. Prepare **datasets** | |||
| Prepare the data's input, ground truth for pseudo-labels (optional), and ground truth for reasoning results. | |||
| 2. :blue:`Build the` :blue-bold:`learning` :blue:`part` | |||
| Build a machine learning base model that maps inputs to pseudo-labels. | |||
| Then, use ``ABLModel`` to encapsulate the base model. | |||
| 3. :green:`Build the` :green-bold:`reasoning` :green:`part` | |||
| Define a knowledge base by building a subclass of ``KBBase``, specifying how to | |||
| map pseudo-labels to reasoning results. | |||
| Also, create a ``Reasoner`` for minimizing inconsistencies | |||
| between the knowledge base and data. | |||
| 4. Define evaluation metrics | |||
| Define the metrics by building a subclass of ``BaseMetric``. The metrics will | |||
| specify how to measure performance during the training and testing of the ABL framework. | |||
| 5. :yellow-bold:`Bridge` :yellow:`learning and reasoning` | |||
| Use ``SimpleBridge`` to bridge the learning and reasoning part | |||
| for integrated training and testing. | |||
| @@ -0,0 +1,95 @@ | |||
| `Learn the Basics <Basics.html>`_ || | |||
| `Quick Start <Quick-Start.html>`_ || | |||
| `Dataset & Data Structure <Datasets.html>`_ || | |||
| `Learning Part <Learning.html>`_ || | |||
| `Reasoning Part <Reasoning.html>`_ || | |||
| `Evaluation Metrics <Evaluation.html>`_ || | |||
| **Bridge** | |||
| Bridge | |||
| ====== | |||
| In this section, we will look at how to bridge the learning and reasoning parts to train the model, which is the fundamental idea of Abductive Learning. ABL-Package implements a set of bridge classes to achieve this. | |||
| .. code:: python | |||
| from abl.bridge import BaseBridge, SimpleBridge | |||
| ``BaseBridge`` is an abstract class with the following initialization parameters: | |||
| - ``model`` is an object of type ``ABLModel``. The learning part is wrapped in this object. | |||
| - ``reasoner`` is an object of type ``Reasoner``. The reasoning part is wrapped in this object. | |||
| ``BaseBridge`` has the following important methods that need to be overridden in subclasses: | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| | Method Signature | Description | | |||
| +=======================================+====================================================+ | |||
| | ``predict(data_examples)`` | Predicts class probabilities and indices | | |||
| | | for the given data examples. | | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| | ``abduce_pseudo_label(data_examples)``| Abduces pseudo-labels for the given data examples. | | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| | ``idx_to_pseudo_label(data_examples)``| Converts indices to pseudo-labels using | | |||
| | | the provided or default mapping. | | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| | ``pseudo_label_to_idx(data_examples)``| Converts pseudo-labels to indices | | |||
| | | using the provided or default remapping. | | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| | ``train(train_data)`` | Train the model. | | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| | ``test(test_data)`` | Test the model. | | |||
| +---------------------------------------+----------------------------------------------------+ | |||
| where ``train_data`` and ``test_data`` are both in the form of a tuple or a `ListData <../API/abl.data.html#structures.ListData>`_. Regardless of the form, they both need to include three components: ``X``, ``gt_pseudo_label`` and ``Y``. Since ``ListData`` is the underlying data structure used throughout the ABL-Package, tuple-formed data will first be transformed into ``ListData`` in the ``train`` and ``test`` methods, and such ``ListData`` instances are referred to as ``data_examples``. More details can be found in `preparing datasets <Datasets.html>`_. | |||
| ``SimpleBridge`` inherits from ``BaseBridge`` and provides a basic implementation. Besides the ``model`` and ``reasoner``, ``SimpleBridge`` has an extra initialization argument, ``metric_list``, which will be used to evaluate model performance. Its training process involves several Abductive Learning loops and each loop consists of the following five steps: | |||
| 1. Predict class probabilities and indices for the given data examples. | |||
| 2. Transform indices into pseudo-labels. | |||
| 3. Revise pseudo-labels based on abductive reasoning. | |||
| 4. Transform the revised pseudo-labels to indices. | |||
| 5. Train the model. | |||
| The fundamental part of the ``train`` method is as follows: | |||
| .. code-block:: python | |||
| def train(self, train_data, loops=50, segment_size=10000): | |||
|     """ | |||
|     Parameters | |||
|     ---------- | |||
|     train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]] | |||
|         Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` | |||
|         object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. | |||
|         - ``X`` is a list of sublists representing the input data. | |||
|         - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but not | |||
|           to train. ``gt_pseudo_label`` can be ``None``. | |||
|         - ``Y`` is a list representing the ground truth reasoning result for each sublist in ``X``. | |||
|     loops : int | |||
|         Machine Learning part and Reasoning part will be iteratively optimized | |||
|         for ``loops`` times. | |||
|     segment_size : Union[int, float] | |||
|         Data will be split into segments of this size and data in each segment | |||
|         will be used together to train the model. | |||
|     """ | |||
|     if isinstance(train_data, ListData): | |||
|         data_examples = train_data | |||
|     else: | |||
|         data_examples = self.data_preprocess(*train_data) | |||
|     if isinstance(segment_size, float): | |||
|         segment_size = int(segment_size * len(data_examples)) | |||
|     for loop in range(loops): | |||
|         for seg_idx in range((len(data_examples) - 1) // segment_size + 1): | |||
|             sub_data_examples = data_examples[ | |||
|                 seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
|             ] | |||
|             self.predict(sub_data_examples)  # 1 | |||
|             self.idx_to_pseudo_label(sub_data_examples)  # 2 | |||
|             self.abduce_pseudo_label(sub_data_examples)  # 3 | |||
|             self.pseudo_label_to_idx(sub_data_examples)  # 4 | |||
|             loss = self.model.train(sub_data_examples)  # 5, self.model is an ABLModel object | |||
| @@ -0,0 +1,89 @@ | |||
| `Learn the Basics <Basics.html>`_ || | |||
| `Quick Start <Quick-Start.html>`_ || | |||
| **Dataset & Data Structure** || | |||
| `Learning Part <Learning.html>`_ || | |||
| `Reasoning Part <Reasoning.html>`_ || | |||
| `Evaluation Metrics <Evaluation.html>`_ || | |||
| `Bridge <Bridge.html>`_ | |||
| Dataset & Data Structure | |||
| ======================== | |||
| In this section, we will look at the dataset and data structure in ABL-Package. | |||
| .. code:: python | |||
| import torch | |||
| from abl.data.structures import ListData | |||
| Dataset | |||
| ------- | |||
| ABL-Package requires user data to be structured either as a tuple ``(X, gt_pseudo_label, Y)`` or as a ``ListData`` (the underlying data structure utilized in ABL-Package, cf. the next section) object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Regardless of the chosen format, the data should encompass three essential components: | |||
| - ``X``: List[List[Any]] | |||
| A list of sublists representing the input data. We refer to each sublist in ``X`` as an **example**, and each example may contain several **instances**. | |||
| - ``gt_pseudo_label``: List[List[Any]], optional | |||
| A list of sublists with each sublist representing ground-truth pseudo-labels of an example. Each pseudo-label in the sublist serves as ground-truth for each **instance** within the example. | |||
| .. note:: | |||
| ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model. If the pseudo-labels of the instances in the datasets are unavailable, ``gt_pseudo_label`` should be ``None``. | |||
| - ``Y``: List[Any] | |||
| A list representing the ground-truth reasoning result for each **example** in ``X``. | |||
| .. warning:: | |||
| The length of ``X``, ``gt_pseudo_label`` (if not ``None``) and ``Y`` should be the same. Also, each sublist in ``gt_pseudo_label`` should have the same length as the sublist in ``X``. | |||
| As an illustration, in the MNIST Addition task, the data are organized as follows: | |||
| .. image:: ../_static/img/Datasets_1.png | |||
| :width: 350px | |||
| :align: center | |||
| .. |data_example| image:: ../_static/img/data_example.png | |||
| :alt: alternate text | |||
| :scale: 8% | |||
| .. |instance| image:: ../_static/img/instance.png | |||
| :alt: alternate text | |||
| :scale: 55% | |||
| where each sublist in ``X``, e.g., |data_example|, is a data example and each image in the sublist, e.g., |instance|, is an instance. | |||
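| In code, a toy dataset in this format (with random tensors standing in for the images above) might be assembled as follows: | |||
| .. code-block:: python | |||
| import torch | |||
| # Two examples, each containing two instances (stand-ins for images) | |||
| X = [[torch.randn(1, 28, 28), torch.randn(1, 28, 28)], | |||
|      [torch.randn(1, 28, 28), torch.randn(1, 28, 28)]] | |||
| gt_pseudo_label = [[7, 5], [3, 4]]  # optional; None if instances are unlabeled | |||
| Y = [12, 7]                         # ground-truth reasoning result per example | |||
| train_data = (X, gt_pseudo_label, Y) | |||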
| Data Structure | |||
| -------------- | |||
| Besides the user-provided dataset, various forms of data are utilized and dynamically generated throughout the training and testing process of the ABL framework. Examples include raw data, predicted pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.data.html#structure>`_ to encapsulate the different forms of data used throughout the learning process. | |||
| ``ListData`` is the underlying abstract data interface utilized in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, abductive reasoning, and training can utilize ``ListData`` as a unified input format. Before proceeding to other stages, user-provided datasets will first be converted into ``ListData``. | |||
| Besides providing a tuple of ``(X, gt_pseudo_label, Y)``, ABL-Package also allows users to directly supply data in ``ListData`` format, which similarly requires the inclusion of these three attributes. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.data.html#structure>`_. | |||
| .. code-block:: python | |||
| # Prepare data | |||
| X = [list(torch.randn(3, 28, 28)), list(torch.randn(3, 28, 28))] | |||
| gt_pseudo_label = [[1, 2, 3], [4, 5, 6]] | |||
| Y = [1, 2] | |||
| # Convert data into ListData | |||
| data = ListData(X=X, Y=Y, gt_pseudo_label=gt_pseudo_label) | |||
| # Get data | |||
| X = data.X | |||
| Y = data.Y | |||
| gt_pseudo_label = data.gt_pseudo_label | |||
| # Set data | |||
| data.X = X | |||
| data.Y = Y | |||
| data.gt_pseudo_label = gt_pseudo_label | |||
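| Since ``ListData`` implements common data manipulation methods, the examples it holds can also be counted and sliced, as is done inside ``SimpleBridge.train`` (see the `Bridge <Bridge.html>`_ section): | |||
| .. code-block:: python | |||
| # Continuing from the ListData built above | |||
| print(len(data))    # number of examples: 2 | |||
| subset = data[0:1]  # slicing returns a ListData holding the first example | |||
| print(subset.Y)     # [1] | |||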
| @@ -0,0 +1,52 @@ | |||
| `Learn the Basics <Basics.html>`_ || | |||
| `Quick Start <Quick-Start.html>`_ || | |||
| `Dataset & Data Structure <Datasets.html>`_ || | |||
| `Learning Part <Learning.html>`_ || | |||
| `Reasoning Part <Reasoning.html>`_ || | |||
| **Evaluation Metrics** || | |||
| `Bridge <Bridge.html>`_ | |||
| Evaluation Metrics | |||
| ================== | |||
| In this section, we will look at how to build evaluation metrics. | |||
| .. code:: python | |||
| from abl.data.evaluation import BaseMetric, SymbolAccuracy, ReasoningMetric | |||
| ABL-Package separates the evaluation process from model training and testing into an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing. | |||
| To customize our own metrics, we need to inherit from ``BaseMetric`` and implement the ``process`` and ``compute_metrics`` methods. | |||
| - The ``process`` method accepts a batch of model predictions and saves the information to the ``self.results`` property after processing this batch. | |||
| - The ``compute_metrics`` method uses all the information saved in ``self.results`` to calculate and return a dict that holds the evaluation results. | |||
| Besides, we can assign a ``str`` to the ``prefix`` argument of the ``__init__`` function. This string is automatically prefixed to the output metric names. For example, if we set ``prefix="mnist_add"``, the output metric name will be ``mnist_add/character_accuracy``. | |||
| We provide two basic metrics, namely ``SymbolAccuracy`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the final reasoning results, respectively. Using ``SymbolAccuracy`` as an example, the following code shows how to implement a custom metric. | |||
| .. code:: python | |||
| class SymbolAccuracy(BaseMetric): | |||
|     def __init__(self, prefix: Optional[str] = None) -> None: | |||
|         # prefix is used to distinguish different metrics | |||
|         super().__init__(prefix) | |||
|     def process(self, data_examples: Sequence[dict]) -> None: | |||
|         # pred_pseudo_label and gt_pseudo_label are both of type List[List[Any]] | |||
|         # and have the same length | |||
|         pred_pseudo_label = data_examples.pred_pseudo_label | |||
|         gt_pseudo_label = data_examples.gt_pseudo_label | |||
|         for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label): | |||
|             correct_num = 0 | |||
|             for pred_symbol, symbol in zip(pred_z, z): | |||
|                 if pred_symbol == symbol: | |||
|                     correct_num += 1 | |||
|             self.results.append(correct_num / len(z)) | |||
|     def compute_metrics(self, results: list) -> dict: | |||
|         metrics = dict() | |||
|         metrics["character_accuracy"] = sum(results) / len(results) | |||
|         return metrics | |||
| @@ -0,0 +1,86 @@ | |||
| `Learn the Basics <Basics.html>`_ || | |||
| `Quick Start <Quick-Start.html>`_ || | |||
| `Dataset & Data Structure <Datasets.html>`_ || | |||
| **Learning Part** || | |||
| `Reasoning Part <Reasoning.html>`_ || | |||
| `Evaluation Metrics <Evaluation.html>`_ || | |||
| `Bridge <Bridge.html>`_ | |||
| Learning Part | |||
| ============= | |||
| In this section, we will look at how to build the learning part. | |||
| In ABL-Package, building the learning part involves two steps: | |||
| 1. Build a machine learning base model used to make predictions on instance-level data. | |||
| 2. Instantiate an ``ABLModel`` with the base model, which enables the learning part to process example-level data. | |||
| .. code:: python | |||
| import sklearn.neighbors | |||
| import torch | |||
| import torchvision | |||
| from abl.learning import BasicNN, ABLModel | |||
| Building a base model | |||
| --------------------- | |||
| ABL-Package allows the base model to be one of the following forms: | |||
| 1. Any machine learning model conforming to the scikit-learn style, i.e., models that have implemented the ``fit`` and ``predict`` methods; | |||
| 2. A PyTorch-based neural network, provided it has defined the architecture and implemented the ``forward`` method. | |||
| For a scikit-learn model, we can directly use the model itself as a base model. For example, we can use a KNN classifier as our base model: | |||
| .. code:: python | |||
| base_model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3) | |||
| For a PyTorch-based neural network, we need to encapsulate it within a ``BasicNN`` object to create a base model. For example, we can build our base model from a pre-trained ResNet-18: | |||
| .. code:: python | |||
| # Load a PyTorch-based neural network | |||
| cls = torchvision.models.resnet18(pretrained=True) | |||
| # loss function and optimizer are used for training | |||
| loss_fn = torch.nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.Adam(cls.parameters()) | |||
| base_model = BasicNN(cls, loss_fn, optimizer) | |||
| BasicNN | |||
| ^^^^^^^ | |||
| ``BasicNN`` is a wrapper class for PyTorch-based neural networks, which enables them to work as scikit-learn models. It encapsulates the neural network, loss function, optimizer, and other elements into a single object, which can be used as a base model. | |||
| Besides the necessary methods required to instantiate an ``ABLModel``, i.e., ``fit`` and ``predict``, ``BasicNN`` also implements the following methods: | |||
| +-------------------------------+------------------------------------------+ | |||
| | Method | Function | | |||
| +===============================+==========================================+ | |||
| | ``train_epoch(data_loader)`` | Train the neural network for one epoch. | | |||
| +-------------------------------+------------------------------------------+ | |||
| | ``predict_proba(X)`` | Predict the class probabilities of ``X``.| | |||
| +-------------------------------+------------------------------------------+ | |||
| | ``score(X, y)`` | Calculate the accuracy of the model on | | |||
| | | test data. | | |||
| +-------------------------------+------------------------------------------+ | |||
| | ``save(epoch_id, save_path)`` | Save the model. | | |||
| +-------------------------------+------------------------------------------+ | |||
| | ``load(load_path)`` | Load the model. | | |||
| +-------------------------------+------------------------------------------+ | |||
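| As a minimal usage sketch of these methods (the data below are random stand-ins, and the exact ``fit`` signature is assumed to follow the scikit-learn style mentioned above): | |||
| .. code:: python | |||
| # base_model is the BasicNN built from the pre-trained ResNet-18 above | |||
| X = [torch.randn(3, 224, 224) for _ in range(8)]  # random stand-in images | |||
| y = [i % 10 for i in range(8)]                    # stand-in labels | |||
| base_model.fit(X=X, y=y)                          # scikit-learn-style training | |||
| prob = base_model.predict_proba(X=X)              # per-class probabilities | |||
| accuracy = base_model.score(X=X, y=y)             # accuracy on the given data | |||
| base_model.save(epoch_id=1, save_path="weights.pth") | |||
| base_model.load(load_path="weights.pth") | |||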
| Instantiating an ABLModel | |||
| ------------------------- | |||
| Typically, a base model is trained to make predictions on instance-level data, and cannot directly process example-level data, making it unsuitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` class to solve this problem. This class serves as a unified wrapper for all base models, which enables the learning part to train, test, and predict on example-level data. | |||
| Generally, we can simply instantiate an ``ABLModel`` by: | |||
| .. code:: python | |||
| # Instantiate an ABLModel | |||
| model = ABLModel(base_model) | |||
| @@ -0,0 +1,129 @@ | |||
| `Learn the Basics <Basics.html>`_ || | |||
| **Quick Start** || | |||
| `Dataset & Data Structure <Datasets.html>`_ || | |||
| `Learning Part <Learning.html>`_ || | |||
| `Reasoning Part <Reasoning.html>`_ || | |||
| `Evaluation Metrics <Evaluation.html>`_ || | |||
| `Bridge <Bridge.html>`_ | |||
| Quick Start | |||
| =========== | |||
| We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base containing information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. Refer to the links in each section to dive deeper. | |||
| Working with Data | |||
| ----------------- | |||
| ABL-Package requires data in the format of ``(X, gt_pseudo_label, Y)``, where ``X`` is a list of input examples containing instances, | |||
| ``gt_pseudo_label`` contains the ground-truth pseudo-labels of the instances in each example of ``X``, and ``Y`` is the ground-truth reasoning result of each example in ``X``. Note that ``gt_pseudo_label`` is only used to evaluate the machine learning model's performance but not to train it. | |||
| In the MNIST Addition task, the data loading looks like: | |||
| .. code:: python | |||
| # The 'datasets' module below is located in 'examples/mnist_add/' | |||
| from datasets import get_dataset | |||
| # train_data and test_data are tuples in the format of (X, gt_pseudo_label, Y) | |||
| train_data = get_dataset(train=True) | |||
| test_data = get_dataset(train=False) | |||
| Read more about `preparing datasets <Datasets.html>`_. | |||
| Building the Learning Part | |||
| -------------------------- | |||
| The learning part is constructed by first defining a base model for machine learning. ABL-Package offers considerable flexibility, supporting any base model that conforms to the scikit-learn style (which requires the implementation of the ``fit`` and ``predict`` methods), or a PyTorch-based neural network (which has defined the architecture and implemented the ``forward`` method). | |||
| In this example, we build a simple LeNet5 network as the base model. | |||
| .. code:: python | |||
| # The 'models' module below is located in 'examples/mnist_add/' | |||
| from models.nn import LeNet5 | |||
| cls = LeNet5(num_classes=10) | |||
| To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to convert a PyTorch-based neural network into a format compatible with scikit-learn models. To construct a ``BasicNN`` instance, aside from the network itself, we also need to define a loss function, an optimizer, and the computing device. | |||
| .. code:: python | |||
| import torch | |||
| from abl.learning import BasicNN | |||
| loss_fn = torch.nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001) | |||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| base_model = BasicNN(model=cls, loss_fn=loss_fn, optimizer=optimizer, device=device) | |||
| The base model built above is trained to make predictions on instance-level data (e.g., a single image), while ABL deals with example-level data. To bridge this gap, we wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on example-level data (e.g., a pair of images whose sum is given). | |||
| .. code:: python | |||
| from abl.learning import ABLModel | |||
| model = ABLModel(base_model) | |||
| Read more about `building the learning part <Learning.html>`_. | |||
| Building the Reasoning Part | |||
| --------------------------- | |||
| To build the reasoning part, we first define a knowledge base by creating a subclass of ``KBBase``. In the subclass, we initialize the ``pseudo_label_list`` parameter and override the ``logic_forward`` method, which specifies how to perform (deductive) reasoning, i.e., how to map pseudo-labels of an example to the corresponding reasoning result. Specifically for the MNIST Addition task, this ``logic_forward`` method is tailored to execute the sum operation. | |||
| .. code:: python | |||
| from abl.reasoning import KBBase | |||
| class AddKB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10))): | |||
| super().__init__(pseudo_label_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| kb = AddKB() | |||
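| We can sanity-check the knowledge base by calling its deductive reasoning directly: | |||
| .. code:: python | |||
| print(kb.logic_forward([1, 2]))  # expected output: 3 | |||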
| Next, we create a reasoner by instantiating the class ``Reasoner``, passing the knowledge base as a parameter. | |||
| Due to the indeterminism of abductive reasoning, there could be multiple candidate pseudo-labels compatible with the knowledge base. | |||
| In such scenarios, the reasoner can minimize inconsistency and return the pseudo-label with the highest consistency. | |||
| .. code:: python | |||
| from abl.reasoning import Reasoner | |||
| reasoner = Reasoner(kb) | |||
| Read more about `building the reasoning part <Reasoning.html>`_. | |||
| Building Evaluation Metrics | |||
| --------------------------- | |||
| ABL-Package provides two basic metrics, namely ``SymbolAccuracy`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively. | |||
| .. code:: python | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| metric_list = [SymbolAccuracy(), ReasoningMetric(kb=kb)] | |||
| Read more about `building evaluation metrics <Evaluation.html>`_. | |||
| Bridging Learning and Reasoning | |||
| --------------------------------------- | |||
| Now, we use ``SimpleBridge`` to combine learning and reasoning in a unified ABL framework. | |||
| .. code:: python | |||
| from abl.bridge import SimpleBridge | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| Finally, we proceed with training and testing. | |||
| .. code:: python | |||
| bridge.train(train_data, loops=1, segment_size=0.01) | |||
| bridge.test(test_data) | |||
| Read more about `bridging machine learning and reasoning <Bridge.html>`_. | |||
| @@ -0,0 +1,381 @@ | |||
| `Learn the Basics <Basics.html>`_ || | |||
| `Quick Start <Quick-Start.html>`_ || | |||
| `Dataset & Data Structure <Datasets.html>`_ || | |||
| `Learning Part <Learning.html>`_ || | |||
| **Reasoning Part** || | |||
| `Evaluation Metrics <Evaluation.html>`_ || | |||
| `Bridge <Bridge.html>`_ | |||
| Reasoning part | |||
| =============== | |||
| In this section, we will look at how to build the reasoning part, which | |||
| leverages domain knowledge and performs deductive or abductive reasoning. | |||
| In ABL-Package, building the reasoning part involves two steps: | |||
| 1. Build a knowledge base by creating a subclass of ``KBBase``, which | |||
| specifies how to map the pseudo-labels of an example to their reasoning result. | |||
| 2. Create a reasoner by instantiating the class ``Reasoner`` | |||
| to minimize inconsistencies between the knowledge base and pseudo | |||
| labels predicted by the learning part. | |||
| .. code:: python | |||
| from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||
| Building a knowledge base | |||
| ------------------------- | |||
| Generally, we can create a subclass derived from ``KBBase`` to build our own | |||
| knowledge base. In addition, ABL-Package also offers several predefined | |||
| subclasses of ``KBBase`` (e.g., ``PrologKB`` and ``GroundKB``), | |||
| which we can utilize to build our knowledge base more conveniently. | |||
| Building a knowledge base from ``KBBase`` | |||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
| For the user-built KB from ``KBBase`` (a derived subclass), it's only | |||
| required to pass the ``pseudo_label_list`` parameter in the ``__init__`` function | |||
| and override the ``logic_forward`` function: | |||
| - ``pseudo_label_list`` is the list of possible pseudo-labels (also, | |||
| the output of the machine learning model). | |||
| - ``logic_forward`` defines how to perform (deductive) reasoning, | |||
| i.e., mapping the pseudo-labels of an example to their reasoning result. | |||
| .. note:: | |||
| Generally, the overridden function ``logic_forward`` provided by the user accepts | |||
| only one parameter, ``pseudo_label`` (pseudo-labels of an example). However, for certain | |||
| scenarios, deductive reasoning in the knowledge base may necessitate information | |||
| from the input. In these scenarios, ``logic_forward`` can also accept two parameters: | |||
| ``pseudo_label`` and ``x``. See examples in `Zoo <../Examples/Zoo.html>`_. | |||
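| For instance, a hedged sketch of such a knowledge base (the class and the rule below are purely illustrative, not part of the package): | |||
| .. code:: python | |||
| class ContextAddKB(KBBase): | |||
|     def __init__(self, pseudo_label_list=list(range(10))): | |||
|         super().__init__(pseudo_label_list) | |||
|     def logic_forward(self, pseudo_label, x): | |||
|         # Hypothetical rule: the result also depends on the number | |||
|         # of input instances in the example. | |||
|         return sum(pseudo_label) + len(x) | |||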
| After that, other operations, including how to perform abductive | |||
| reasoning, will be **automatically** set up. | |||
| MNIST Addition example | |||
| ^^^^^^^^^^^^^^^^^^^^^^ | |||
| As an example, the ``pseudo_label_list`` passed in MNIST Addition is all the | |||
| possible digits, namely, ``[0,1,2,...,9]``, and the ``logic_forward`` | |||
| should be: “Add the two pseudo-labels to get the result.” Therefore, the | |||
| construction of the KB (``add_kb``) for MNIST Addition would be: | |||
| .. code:: python | |||
| class AddKB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10))): | |||
| super().__init__(pseudo_label_list) | |||
| def logic_forward(self, pseudo_labels): | |||
| return sum(pseudo_labels) | |||
| add_kb = AddKB() | |||
| and (deductive) reasoning in ``add_kb`` would be: | |||
| .. code:: python | |||
| pseudo_labels = [1, 2] | |||
| reasoning_result = add_kb.logic_forward(pseudo_labels) | |||
| print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| Reasoning result of pseudo-labels [1, 2] is 3. | |||
| .. _other-par: | |||
| Other optional parameters | |||
| ^^^^^^^^^^^^^^^^^^^^^^^^^ | |||
| We can also pass the following parameters in the ``__init__`` function when building our | |||
| knowledge base: | |||
| - ``max_err`` (float, optional), specifying the upper tolerance limit | |||
| on the difference between the reasoning result of pseudo-labels | |||
| and the ground truth during abductive reasoning. This is only | |||
| applicable when the reasoning result is of a numerical type. This is | |||
| particularly relevant for regression problems where exact matches | |||
| might not be feasible. Defaults to 1e-10. See :ref:`an example <kb-abd-2>`. | |||
| - ``use_cache`` (bool, optional), indicating whether to use a cache to store | |||
| previous candidates (pseudo-labels generated from abductive reasoning) | |||
| to speed up subsequent abductive reasoning operations. Defaults to True. | |||
| For more information on abductive reasoning, please refer to :ref:`this <kb-abd>`. | |||
| - ``cache_size`` (int, optional), specifying the maximum cache | |||
| size. This is only operational when ``use_cache`` is set to True. | |||
| Defaults to 4096. | |||
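| As a hedged sketch (assuming these options are forwarded as keyword arguments to ``KBBase.__init__``; the class name is illustrative), a KB with a looser tolerance could be built as: | |||
| .. code:: python | |||
| class LooseAddKB(KBBase): | |||
|     def __init__(self, pseudo_label_list=list(range(10))): | |||
|         super().__init__(pseudo_label_list, max_err=1, use_cache=True, cache_size=4096) | |||
|     def logic_forward(self, nums): | |||
|         return sum(nums) | |||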
| .. _prolog: | |||
| Building a knowledge base from Prolog file | |||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
| When aiming to leverage a knowledge base from an external Prolog file | |||
| (which contains how to perform reasoning), we can directly create an | |||
| instance of class ``PrologKB``. Upon instantiation of | |||
| ``PrologKB``, we are required to pass the ``pseudo_label_list`` (same as ``KBBase``) | |||
| and ``pl_file`` (the Prolog file) in the ``__init__`` function. | |||
| .. admonition:: What is a Prolog file? | |||
| A Prolog file (typically with the extension ``.pl``) is a script or source | |||
| code file written in the Prolog language. Prolog is a logic programming language | |||
| where the logic is represented as facts | |||
| (basic assertions about some world) and | |||
| rules (logical statements that describe the relationships between facts). | |||
| A computation is initiated by running a query over these facts and rules. | |||
| See some Prolog examples | |||
| in `SWISH <https://swish.swi-prolog.org/>`_. | |||
| After the instantiation, other operations, including how to perform | |||
| abductive reasoning, will also be **automatically** set up. | |||
| .. warning:: | |||
| Note that to use the default logic forward and abductive reasoning | |||
| methods in this class, the Prolog (.pl) file should contain a rule | |||
| with a strict format: ``logic_forward(Pseudo_labels, Res).`` | |||
| Otherwise, we might have to override ``logic_forward`` and | |||
| ``get_query_string`` to allow for more adaptable usage. | |||
| MNIST Addition example (cont.) | |||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |||
| As an example, we can first write a Prolog file for the MNIST Addition | |||
| example as the following code, and then save it as ``add.pl``. | |||
| .. code:: prolog | |||
| pseudo_label(N) :- between(0, 9, N). | |||
| logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2. | |||
| Afterwards, the construction of knowledge base from Prolog file | |||
| (``add_prolog_kb``) would be as follows: | |||
| .. code:: python | |||
| add_prolog_kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
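| Deductive reasoning then works just as it does for a KB built from ``KBBase`` (this sketch assumes SWI-Prolog and its Python bindings are installed): | |||
| .. code:: python | |||
| print(add_prolog_kb.logic_forward([1, 2]))  # expected output: 3 | |||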
| Building a knowledge base with GKB from ``GroundKB`` | |||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
| We can also inherit from class ``GroundKB`` to build our own | |||
| knowledge base. In this way, the knowledge base built will have a Ground KB | |||
| (GKB). | |||
| .. admonition:: What is Ground KB? | |||
| `Ground KB <https://www.ijcai.org/proceedings/2021/250>`_ is a knowledge base prebuilt upon class initialization, | |||
| storing all potential candidates along with their respective reasoning | |||
| result. The key advantage of having a Ground KB is that it may | |||
| accelerate abductive reasoning. | |||
| ``GroundKB`` is a subclass of ``KBBase``. Similar to ``KBBase``, we | |||
| are required to pass the ``pseudo_label_list`` parameter in the ``__init__`` function and | |||
| override the ``logic_forward`` function, and are allowed to pass other | |||
| :ref:`optional parameters <other-par>`. Additionally, we are required to pass the | |||
| ``GKB_len_list`` parameter in the ``__init__`` function. | |||
| - ``GKB_len_list`` is the list of possible lengths for pseudo-labels of an example. | |||
| After that, other operations, including the auto-construction of the GKB and | |||
| how to perform abductive reasoning, will be **automatically** set up. | |||
| MNIST Addition example (cont.) | |||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |||
| As an example, the ``GKB_len_list`` for MNIST Addition should be ``[2]``, | |||
| since all pseudo-labels in the example consist of two digits. Therefore, | |||
| the construction of KB with GKB (``add_ground_kb``) of MNIST Addition would be | |||
| as follows. As mentioned, the difference between this and the previously | |||
| built ``add_kb`` lies only in the base class from which it is derived | |||
| and whether an extra parameter ``GKB_len_list`` is passed. | |||
| .. code:: python | |||
| class AddGroundKB(GroundKB): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| GKB_len_list=[2]): | |||
| super().__init__(pseudo_label_list, GKB_len_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| add_ground_kb = AddGroundKB() | |||
| .. _kb-abd: | |||
| Performing abductive reasoning in the knowledge base | |||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
| As mentioned in :ref:`What is Abductive Reasoning? <abd>`, abductive reasoning | |||
| enables the inference of candidates (i.e., possible pseudo-labels) as potential | |||
| explanations for the reasoning result. Also, in Abductive Learning, where | |||
| an observation (pseudo-labels of an example predicted by the learning part) is | |||
| available, we prefer candidates that revise the previously | |||
| identified pseudo-labels as little as possible. | |||
| ``KBBase`` (as well as ``GroundKB`` and ``PrologKB``) implements the method | |||
| ``abduce_candidates(pseudo_label, y, x, max_revision_num, require_more_revision)`` | |||
| for performing abductive reasoning, where the parameters are: | |||
| - ``pseudo_label``, pseudo-labels of an example, usually generated by the learning | |||
| part. They are to be revised by abductive reasoning. | |||
| - ``y``, the ground truth of the reasoning result for the example. The | |||
| returned candidates should be compatible with it. | |||
| - ``x``, the corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| - ``max_revision_num``, an int value specifying the upper limit on the | |||
| number of revised labels for each example. | |||
| - ``require_more_revision``, an int value specifying additional number | |||
| of revisions permitted beyond the minimum required. (e.g., If we set | |||
| it to 0, even if ``max_revision_num`` is set to a high value, the | |||
| method will only output candidates with the minimum possible | |||
| revisions.) | |||
| The method returns a list of candidates (i.e., revised pseudo-labels of the example) | |||
| that are all compatible with ``y``. | |||
| MNIST Addition example (cont.) | |||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |||
| As an example, with MNIST Addition, the candidates returned by | |||
| ``add_kb.abduce_candidates`` would be as follows: | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | ``pseudo_label`` | ``y`` | ``max_revision_num`` | ``require_more_revision`` | Output | |||
| +==================+=======+======================+==========================+================+ | |||
| | [1,1] | 8 | 1 | 0 | [[1,7], [7,1]] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | [1,1] | 8 | 1 | 1 | [[1,7], [7,1]] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | [1,1] | 8 | 2 | 0 | [[1,7], [7,1]] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | [1,1] | 8 | 2 | 1 | [[1,7], | | |||
| | | | | | [7,1], [2,6], | | |||
| | | | | | [6,2], [3,5], | | |||
| | | | | | [5,3], [4,4]] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | [1,1] | 11 | 1 | 0 | [] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
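| As a hedged sketch following the parameter list above (treating the arguments as keywords, and assuming ``x`` may be ``None`` when the input is not needed), the fourth row of the table could be reproduced as: | |||
| .. code:: python | |||
| candidates = add_kb.abduce_candidates( | |||
|     pseudo_label=[1, 1], y=8, x=None, | |||
|     max_revision_num=2, require_more_revision=1 | |||
| ) | |||
| # expected: [[1,7], [7,1], [2,6], [6,2], [3,5], [5,3], [4,4]] | |||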
| .. _kb-abd-2: | |||
| As another example, if we set the ``max_err`` of ``AddKB`` to be 1 | |||
| instead of the default 1e-10, the tolerance limit for consistency will | |||
| be higher, hence the candidates returned would be: | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | ``pseudo_label`` | ``y`` | ``max_revision_num`` | ``require_more_revision`` | Output | |||
| +==================+=======+======================+==========================+================+ | |||
| | [1,1] | 8 | 1 | 0 | [[1,7], [7,1], | | |||
| | | | | | [1,6], [6,1], | | |||
| | | | | | [1,8], [8,1]] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| | [1,1] | 11 | 1 | 0 | [[1,9], [9,1]] | | |||
| +------------------+-------+----------------------+--------------------------+----------------+ | |||
| Creating a reasoner | |||
| ------------------- | |||
| After building our knowledge base, the next step is creating a | |||
| reasoner. Due to the indeterminism of abductive reasoning, there could | |||
| be multiple candidates compatible with the knowledge base. When this | |||
| happens, the reasoner can minimize inconsistencies between the knowledge | |||
| base and pseudo-labels predicted by the learning part, and then return **only | |||
| one** candidate that has the highest consistency. | |||
| We can create a reasoner simply by instantiating class | |||
| ``Reasoner`` and passing our knowledge base as a parameter. As an | |||
| example for MNIST Addition, the reasoner definition would be: | |||
| .. code:: python | |||
| reasoner_add = Reasoner(add_kb) | |||
| When instantiating, besides the required knowledge base, we may also | |||
| specify: | |||
| - ``max_revision`` (int or float, optional), specifies the upper limit | |||
| on the number of revisions for each example when performing | |||
| :ref:`abductive reasoning in the knowledge base <kb-abd>`. If float, denotes the | |||
| fraction of the total length that can be revised. A value of -1 | |||
| implies no restriction on the number of revisions. Defaults to -1. | |||
| - ``require_more_revision`` (int, optional), specifying the additional | |||
| number of revisions permitted beyond the minimum required when | |||
| performing :ref:`abductive reasoning in the knowledge base <kb-abd>`. Defaults to | |||
| 0. | |||
| - ``use_zoopt`` (bool, optional), indicating whether to use the `ZOOpt library <https://github.com/polixir/ZOOpt>`_, | |||
| which is a library for zeroth-order optimization that can be used to | |||
| accelerate consistency minimization. Defaults to False. | |||
| - ``dist_func`` (str, optional), specifying the distance function to be | |||
| used when determining consistency between the prediction and a | |||
| candidate returned from the knowledge base. Valid options include | |||
| “confidence” (default) and “hamming”. For “confidence”, it calculates | |||
| the distance between the prediction and candidate based on confidence | |||
| derived from the predicted probability in the data example. For | |||
| “hamming”, it directly calculates the Hamming distance between the | |||
| predicted pseudo-label in the data example and candidate. | |||
| - ``idx_to_label`` (dict, optional), a mapping from index in the base model to label. | |||
| If not provided, a default order-based index to label mapping is created. | |||
| Defaults to None. | |||
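| As a hedged sketch (using the parameter names listed above; the chosen values are only illustrative), a reasoner with customized options could be created as: | |||
| .. code:: python | |||
| reasoner_add = Reasoner( | |||
|     add_kb, | |||
|     max_revision=0.5,           # revise at most half of the labels in an example | |||
|     require_more_revision=1, | |||
|     dist_func="hamming", | |||
| ) | |||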
| The main method implemented by ``Reasoner`` is | |||
| ``abduce(data_example)``, which obtains the most consistent candidate | |||
| based on the distance function defined in ``dist_func``. | |||
| MNIST Addition example (cont.) | |||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
| As an example, consider these data examples for MNIST Addition: | |||
| .. code:: python | |||
| # favor "1" for the first label | |||
| prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| # favor "7" for the first label | |||
| prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| example1 = ListData() | |||
| example1.pred_pseudo_label = [1, 1] | |||
| example1.pred_prob = prob1 | |||
| example1.Y = 8 | |||
| example2 = ListData() | |||
| example2.pred_pseudo_label = [1, 1] | |||
| example2.pred_prob = prob2 | |||
| example2.Y = 8 | |||
| The compatible candidates after abductive reasoning for both examples | |||
| would be ``[[1,7], [7,1]]``. However, when the reasoner's ``abduce`` method is called | |||
| to select only one candidate based on the ``confidence`` distance function, | |||
| the output would differ for each example: | |||
| .. code:: python | |||
| reasoner_add = Reasoner(add_kb, dist_func="confidence") | |||
| candidate1 = reasoner_add.abduce(example1) | |||
| candidate2 = reasoner_add.abduce(example2) | |||
| print(f"The outputs for example1 and example2 are {candidate1} and {candidate2}, respectively.") | |||
| Out: | |||
| .. code:: none | |||
| :class: code-out | |||
| The outputs for example1 and example2 are [1, 7] and [7, 1], respectively. | |||
| Specifically, as mentioned before, ``confidence`` calculates the distance between the data | |||
| example and candidates based on the confidence derived from the predicted probability. | |||
| Taking ``example1`` as an example, its ``pred_prob`` indicates a higher | |||
| confidence that the first label should be "1" rather than "7". Therefore, among the | |||
| candidates [1,7] and [7,1], the example is closer to [1,7] (as its first label is "1"). | |||
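| For contrast, with ``dist_func="hamming"``, both candidates differ from the prediction ``[1,1]`` in exactly one position, so their Hamming distances tie and the predicted probabilities no longer break the tie. A minimal sketch (which candidate is returned on such a tie is implementation-dependent): | |||
| .. code:: python | |||
| reasoner_add = Reasoner(add_kb, dist_func="hamming") | |||
| # [1,7] and [7,1] are both at Hamming distance 1 from [1,1], | |||
| # so the returned candidate depends on the implementation. | |||
| candidate = reasoner_add.abduce(example1) | |||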
| @@ -0,0 +1,20 @@ | |||
| # Minimal makefile for Sphinx documentation | |||
| # | |||
| # You can set these variables from the command line. | |||
| SPHINXOPTS = | |||
| SPHINXBUILD = sphinx-build | |||
| SPHINXPROJ = ABL-Package | |||
| SOURCEDIR = . | |||
| BUILDDIR = build | |||
| # Put it first so that "make" without argument is like "make help". | |||
| help: | |||
| @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | |||
| .PHONY: help Makefile | |||
| # Catch-all target: route all unknown targets to Sphinx using the new | |||
| # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). | |||
| %: Makefile | |||
| @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | |||
| @@ -0,0 +1,80 @@ | |||
| Abductive Learning | |||
| ================== | |||
| Traditional supervised machine learning, e.g. classification, is | |||
| predominantly data-driven, as shown in the figure below. | |||
| Here, a set of data examples is given, including training instances | |||
| :math:`\{x_1,\dots,x_m\}` and corresponding ground-truth labels :math:`\{\text{label}(x_1),\dots,\text{label}(x_m)\}`. | |||
| These data are then used to train a classifier model :math:`f`, | |||
| aiming to make accurate predictions on unseen data instances. | |||
| .. image:: ../_static/img/ML.png | |||
| :align: center | |||
| :width: 280px | |||
| In **Abductive Learning (ABL)**, we assume that, in addition to data, | |||
| there is also a knowledge base :math:`\mathcal{KB}` containing | |||
| domain knowledge at our disposal. We aim for the classifier :math:`f` | |||
| to make correct predictions on data instances :math:`\{x_1,\dots,x_m\}`, | |||
| and meanwhile, the pseudo-groundings grounded by the prediction | |||
| :math:`\left\{f(\boldsymbol{x}_1), \ldots, f(\boldsymbol{x}_m)\right\}` | |||
| should be compatible with :math:`\mathcal{KB}`. | |||
| The process of ABL is as follows: | |||
| 1. Upon receiving data instances :math:`\left\{x_1,\dots,x_m\right\}` as input, | |||
| pseudo-labels | |||
| :math:`\left\{f(\boldsymbol{x}_1), \ldots, f(\boldsymbol{x}_m)\right\}` | |||
| are predicted by a data-driven classifier model. | |||
| 2. These pseudo-labels are then converted into pseudo-groundings | |||
| :math:`\mathcal{O}` that are acceptable for logical reasoning. | |||
| 3. Conduct joint reasoning with :math:`\mathcal{KB}` to find any | |||
| inconsistencies. If found, the pseudo-groundings that lead to minimal | |||
| inconsistency can be identified. | |||
| 4. Modify the identified facts through **abductive reasoning** (or, **abduction**), | |||
| returning revised pseudo-groundings :math:`\Delta(\mathcal{O})` which are | |||
| compatible with :math:`\mathcal{KB}`. | |||
| 5. These revised pseudo-groundings are converted back to the form of | |||
| pseudo-labels, and used like ground-truth labels in conventional | |||
| supervised learning to train a new classifier. | |||
| 6. The new classifier will then be adopted to replace the previous one | |||
| in the next iteration. | |||
| The above process repeats until the classifier is no longer updated, or | |||
| the pseudo-groundings :math:`\mathcal{O}` are compatible with the knowledge | |||
| base. | |||
| The following figure illustrates this process: | |||
| .. image:: ../_static/img/ABL.png | |||
| :width: 800px | |||
| We can observe that in the above figure, the left half involves machine | |||
| learning, while the right half involves logical reasoning. Thus, the | |||
| entire Abductive Learning process is a continuous cycle of machine | |||
| learning and logical reasoning. This effectively forms a paradigm that | |||
| is dual-driven by both data and domain knowledge, integrating and | |||
| balancing the use of machine learning and logical reasoning in a unified | |||
| model. | |||
| For more information about ABL, please refer to: `Zhou, 2019 <https://link.springer.com/epdf/10.1007/s11432-018-9801-4?author_access_token=jgJe1Ox3Mk-K7ORSnX7jtfe4RwlQNchNByi7wbcMAY7_PxTx-xNLP7Lp0mIZ04ORp3VG4wioIBHSCIAO3B_TBJkj87YzapmdnYVSQvgBIO3aEpQWppxZG25KolINetygc2W_Cj2gtoBdiG_J1hU3pA==>`_ | |||
| and `Zhou and Huang, 2022 <https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf>`_. | |||
| .. _abd: | |||
| .. admonition:: What is Abductive Reasoning? | |||
| Abductive reasoning, also known as abduction, refers to the process of | |||
| selectively inferring certain facts and hypotheses that explain | |||
| phenomena and observations based on background knowledge. Unlike | |||
| deductive reasoning, which leads to definitive conclusions, abductive | |||
| reasoning may arrive at conclusions that are plausible but not conclusively | |||
| proven. | |||
| In ABL, given :math:`\mathcal{KB}` (typically expressed | |||
| in first-order logic clauses), one can perform both deductive and | |||
| abductive reasoning. Deductive reasoning allows deriving | |||
| :math:`b` from :math:`a`, while abductive reasoning allows inferring | |||
| :math:`a` as an explanation of :math:`b`. In other words, | |||
| deductive reasoning and abductive reasoning differ in which end, | |||
| right or left, of the proposition “:math:`a\models b`” serves as conclusion. | |||
| @@ -0,0 +1,34 @@ | |||
| Installation | |||
| ================== | |||
| ABL is distributed on `PyPI <https://pypi.org/>`__ and can be installed with ``pip``: | |||
| .. code:: console | |||
| # (TODO) | |||
| $ pip install abl | |||
| For testing purposes, you can install it using: | |||
| .. code:: console | |||
| $ pip install -i https://test.pypi.org/simple/ --extra-index-url https://mirrors.nju.edu.cn/pypi/web/simple/ abl | |||
| Alternatively, to install ABL from source, | |||
| run the following commands sequentially in your terminal/command line. | |||
| .. code:: console | |||
| $ git clone https://github.com/AbductiveLearning/ABL-Package.git | |||
| $ cd ABL-Package | |||
| $ pip install -v -e . | |||
| (Optional) If the use of a :ref:`Prolog-based knowledge base <prolog>` is necessary, the installation of `SWI-Prolog <https://www.swi-prolog.org/>`_ is also required: | |||
| For Linux users: | |||
| .. code:: console | |||
| $ sudo apt-get install swi-prolog | |||
| For Windows and Mac users, please refer to the `SWI-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_. | |||
| @@ -0,0 +1,60 @@ | |||
| ABL-Package | |||
| =========== | |||
| **ABL-Package** is an open source library for **Abductive Learning (ABL)**. | |||
| ABL is a novel paradigm that integrates machine learning and | |||
| logical reasoning in a unified framework. It is suitable for tasks | |||
| where both data and (logical) domain knowledge are available. | |||
| Key Features of ABL-Package: | |||
| - **Great Flexibility**: Adaptable to various machine learning modules and logical reasoning components. | |||
| - **User-Friendly**: Provide data, model, and KB, and get started with just a few lines of code. | |||
| - **High-Performance**: Optimized for high accuracy and fast training speed. | |||
| ABL-Package encapsulates advanced ABL techniques, providing users with | |||
| an efficient and convenient package to develop dual-driven ABL systems, | |||
| which leverage the power of both data and knowledge. | |||
| .. image:: _static/img/ABL.png | |||
| Installation | |||
| ------------ | |||
| ABL is distributed on `PyPI <https://pypi.org/>`__ and can be installed with ``pip``: | |||
| .. code:: console | |||
| # (TODO) | |||
| $ pip install abl | |||
| For testing purposes, you can install it using: | |||
| .. code:: console | |||
| $ pip install -i https://test.pypi.org/simple/ --extra-index-url https://mirrors.nju.edu.cn/pypi/web/simple/ abl | |||
| Alternatively, to install ABL from source, | |||
| run the following commands sequentially in your terminal/command line. | |||
| .. code:: console | |||
| $ git clone https://github.com/AbductiveLearning/ABL-Package.git | |||
| $ cd ABL-Package | |||
| $ pip install -v -e . | |||
| (Optional) If the use of a :ref:`Prolog-based knowledge base <prolog>` is necessary, the installation of `SWI-Prolog <https://www.swi-prolog.org/>`_ is also required: | |||
| For Linux users: | |||
| .. code:: console | |||
| $ sudo apt-get install swi-prolog | |||
| For Windows and Mac users, please refer to the `SWI-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_. | |||
| References | |||
| ---------- | |||
| For more information about ABL, please refer to: `Zhou, 2019 <https://link.springer.com/epdf/10.1007/s11432-018-9801-4?author_access_token=jgJe1Ox3Mk-K7ORSnX7jtfe4RwlQNchNByi7wbcMAY7_PxTx-xNLP7Lp0mIZ04ORp3VG4wioIBHSCIAO3B_TBJkj87YzapmdnYVSQvgBIO3aEpQWppxZG25KolINetygc2W_Cj2gtoBdiG_J1hU3pA==>`_ | |||
| and `Zhou and Huang, 2022 <https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf>`_. | |||
| @@ -0,0 +1,10 @@ | |||
| References | |||
| ========== | |||
| Zhi-Hua Zhou. `Abductive learning: Towards bridging machine learning and logical reasoning. <https://link.springer.com/epdf/10.1007/s11432-018-9801-4?author_access_token=jgJe1Ox3Mk-K7ORSnX7jtfe4RwlQNchNByi7wbcMAY7_PxTx-xNLP7Lp0mIZ04ORp3VG4wioIBHSCIAO3B_TBJkj87YzapmdnYVSQvgBIO3aEpQWppxZG25KolINetygc2W_Cj2gtoBdiG_J1hU3pA==>`_. **Science China Information Sciences**, 2019, 62: 076101. | |||
| Zhi-Hua Zhou and Yu-Xuan Huang. `Abductive learning <https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf>`_. In P. Hitzler and M. K. Sarker eds., **Neuro-Symbolic Artificial Intelligence: The State of the Art**, IOP Press, Amsterdam, 2022, p.353-379 | |||
| @@ -0,0 +1,24 @@ | |||
| div.code-out > div.highlight > pre { | |||
| background-color: #d3effd !important; | |||
| } | |||
| .green-bold { | |||
| color: green; | |||
| font-weight: bold; | |||
| } | |||
| .blue-bold { | |||
| color: blue; | |||
| font-weight: bold; | |||
| } | |||
| .yellow-bold { | |||
| color: rgb(255, 192, 0); | |||
| font-weight: bold; | |||
| } | |||
| .green { | |||
| color: green; | |||
| } | |||
| .blue { | |||
| color: blue; | |||
| } | |||
| .yellow { | |||
| color: rgb(255, 192, 0); | |||
| } | |||
| @@ -0,0 +1,115 @@ | |||
| import os | |||
| import re | |||
| import sys | |||
| from docutils import nodes | |||
| from docutils.parsers.rst import roles | |||
| from sphinx.application import Sphinx | |||
| def remove_noqa(app: Sphinx, what: str, name: str, obj, options, lines): | |||
| new_lines = [] | |||
| for line in lines: | |||
| new_line = re.sub(r"\s*#\s*noqa.*$", "", line) | |||
| new_lines.append(new_line) | |||
| lines[:] = new_lines | |||
| def colored_text_role(role, rawtext, text, lineno, inliner, options={}, content=[]): | |||
| node = nodes.inline(rawtext, text, classes=[role]) | |||
| return [node], [] | |||
| roles.register_local_role("green-bold", colored_text_role) | |||
| roles.register_local_role("blue-bold", colored_text_role) | |||
| roles.register_local_role("yellow-bold", colored_text_role) | |||
| roles.register_local_role("green", colored_text_role) | |||
| roles.register_local_role("blue", colored_text_role) | |||
| roles.register_local_role("yellow", colored_text_role) | |||
| if "READTHEDOCS" not in os.environ: | |||
| sys.path.insert(0, os.path.abspath("..")) | |||
| sys.path.append(os.path.abspath("./ABL/")) | |||
| project = "ABL" | |||
| slug = re.sub(r"\W+", "-", project.lower()) | |||
| project = "ABL-Package" | |||
| copyright = "LAMDA, 2024" | |||
| author = "Author" | |||
| # -- General configuration --------------------------------------------------- | |||
| # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration | |||
| extensions = [ | |||
| "sphinx.ext.intersphinx", | |||
| "sphinx.ext.autodoc", | |||
| "sphinx.ext.autosummary", | |||
| "sphinx.ext.mathjax", | |||
| "sphinx.ext.viewcode", | |||
| "sphinx_rtd_theme", | |||
| "recommonmark", | |||
| "sphinx_markdown_tables", | |||
| "sphinx.ext.napoleon", | |||
| "sphinx_copybutton", | |||
| ] | |||
| templates_path = ["_templates"] | |||
| source_suffix = [".rst", ".md"] | |||
| exclude_patterns = [] | |||
| # locale_dirs = ['locale/'] | |||
| gettext_compact = False | |||
| master_doc = "index" | |||
| suppress_warnings = ["image.nonlocal_uri"] | |||
| pygments_style = "default" | |||
| # intersphinx_mapping = { | |||
| # 'rtd': ('https://docs.readthedocs.io/en/latest/', None), | |||
| # 'sphinx': ('http://www.sphinx-doc.org/en/stable/', None), | |||
| # } | |||
| html_theme = "sphinx_rtd_theme" | |||
| html_theme_options = {"display_version": True} | |||
| html_static_path = ["_static"] | |||
| html_css_files = ["custom.css"] | |||
| # html_theme_path = ["../.."] | |||
| # html_logo = "demo/static/logo-wordmark-light.svg" | |||
| # html_show_sourcelink = True | |||
| htmlhelp_basename = slug | |||
| # latex_documents = [ | |||
| # ('index', '{0}.tex'.format(slug), project, author, 'manual'), | |||
| # ] | |||
| man_pages = [("index", slug, project, [author], 1)] | |||
| texinfo_documents = [ | |||
| ("index", slug, project, author, slug, project, "Miscellaneous"), | |||
| ] | |||
| # Extensions to theme docs | |||
| def setup(app): | |||
| from sphinx.domains.python import PyField | |||
| from sphinx.util.docfields import Field | |||
| app.connect("autodoc-process-docstring", remove_noqa) | |||
| app.add_object_type( | |||
| "confval", | |||
| "confval", | |||
| objname="configuration value", | |||
| indextemplate="pair: %s; configuration value", | |||
| doc_field_types=[ | |||
| PyField("type", label=("Type"), has_arg=False, names=("type",), bodyrolename="class"), | |||
| Field( | |||
| "default", | |||
| label=("Default"), | |||
| has_arg=False, | |||
| names=("default",), | |||
| ), | |||
| ], | |||
| ) | |||
| @@ -0,0 +1,47 @@ | |||
| .. include:: README.rst | |||
| .. toctree:: | |||
| :maxdepth: 1 | |||
| :caption: Overview | |||
| Overview/Abductive-Learning | |||
| Overview/Installation | |||
| .. toctree:: | |||
| :maxdepth: 1 | |||
| :caption: Introduction to ABL-Package | |||
| Intro/Basics | |||
| Intro/Quick-Start | |||
| Intro/Datasets | |||
| Intro/Learning | |||
| Intro/Reasoning | |||
| Intro/Evaluation | |||
| Intro/Bridge | |||
| .. toctree:: | |||
| :maxdepth: 1 | |||
| :caption: Examples | |||
| Examples/MNISTAdd | |||
| Examples/HWF | |||
| Examples/HED | |||
| Examples/Zoo | |||
| .. toctree:: | |||
| :maxdepth: 1 | |||
| :caption: API | |||
| API/abl.data | |||
| API/abl.learning | |||
| API/abl.reasoning | |||
| API/abl.bridge | |||
| API/abl.utils | |||
| .. toctree:: | |||
| :maxdepth: 1 | |||
| :caption: References | |||
| References | |||
| @@ -0,0 +1,38 @@ | |||
| @ECHO OFF | |||
| pushd %~dp0 | |||
| REM Command file for Sphinx documentation | |||
| if "%SPHINXBUILD%" == "" ( | |||
| set SPHINXBUILD=python -msphinx | |||
| ) | |||
| set SPHINXOPTS= | |||
| set SPHINXBUILD=sphinx-build | |||
| set SOURCEDIR=. | |||
| set BUILDDIR=build | |||
| set SPHINXPROJ=ReadtheDocsSphinxTheme | |||
| if "%1" == "" goto help | |||
| %SPHINXBUILD% >NUL 2>NUL | |||
| if errorlevel 9009 ( | |||
| echo. | |||
| echo.The Sphinx module was not found. Make sure you have Sphinx installed, | |||
| echo.then set the SPHINXBUILD environment variable to point to the full | |||
| echo.path of the 'sphinx-build' executable. Alternatively you may add the | |||
| echo.Sphinx directory to PATH. | |||
| echo. | |||
| echo.If you don't have Sphinx installed, grab it from | |||
| echo.http://sphinx-doc.org/ | |||
| exit /b 1 | |||
| ) | |||
| %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% | |||
| goto end | |||
| :help | |||
| %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% | |||
| :end | |||
| popd | |||
| @@ -0,0 +1,5 @@ | |||
| sphinx | |||
| sphinx-rtd-theme | |||
| recommonmark | |||
| sphinx-markdown-tables | |||
| sphinx-copybutton | |||
| @@ -0,0 +1,39 @@ | |||
| # Handwritten Equation Decipherment | |||
| This notebook shows an implementation of [Handwritten Equation Decipherment](https://proceedings.neurips.cc/paper_files/paper/2019/file/9c19a2aa1d84e04b0bd4bc888792bd1e-Paper.pdf). In this task, handwritten equations are given, each consisting of sequential pictures of characters. The equations are generated with unknown operation rules from images of symbols ('0', '1', '+' and '='), and each equation is associated with a label indicating whether the equation is correct (i.e., positive) or not (i.e., negative). We are also given a knowledge base which involves the structure of the equations and a recursive definition of bit-wise operations. The task is to learn from a training set of the above-mentioned equations and then to predict the labels of unseen equations. | |||
| ## Run | |||
| ```bash | |||
| pip install -r requirements.txt | |||
| python main.py | |||
| ``` | |||
| ## Usage | |||
| ```bash | |||
| usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||
| [--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE] | |||
| [--segment_size SEGMENT_SIZE] [--save_interval SAVE_INTERVAL] | |||
| [--max-revision MAX_REVISION] | |||
| Handwritten Equation Decipherment example | |||
| optional arguments: | |||
| -h, --help            show this help message and exit | |||
| --no-cuda             disables CUDA training | |||
| --epochs EPOCHS       number of epochs in each learning loop iteration | |||
| (default : 1) | |||
| --lr LR               base model learning rate (default : 0.001) | |||
| --weight-decay WEIGHT_DECAY | |||
| weight decay (default : 0.0001) | |||
| --batch-size BATCH_SIZE | |||
| base model batch size (default : 32) | |||
| --segment_size SEGMENT_SIZE | |||
| segment size (default : 1000) | |||
| --save_interval SAVE_INTERVAL | |||
| save interval (default : 1) | |||
| --max-revision MAX_REVISION | |||
| maximum revision in reasoner (default : 10) | |||
| ``` | |||
| @@ -0,0 +1,273 @@ | |||
| import os | |||
| from collections import defaultdict | |||
| from typing import Any, List, Optional, Tuple, Union | |||
| import torch | |||
| from abl.bridge import SimpleBridge | |||
| from abl.data.evaluation import BaseMetric | |||
| from abl.data.structures import ListData | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.learning.torch_dataset import RegressionDataset | |||
| from abl.reasoning import Reasoner | |||
| from abl.utils import print_log | |||
| from datasets import get_pretrain_data | |||
| from models.nn import SymbolNetAutoencoder | |||
| from utils import InfiniteSampler, gen_mappings | |||
| class HedBridge(SimpleBridge): | |||
| def __init__( | |||
| self, | |||
| model: ABLModel, | |||
| reasoner: Reasoner, | |||
| metric_list: BaseMetric, | |||
| ) -> None: | |||
| super().__init__(model, reasoner, metric_list) | |||
| def pretrain(self, weights_dir): | |||
| if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")): | |||
| print_log("Pretrain Start", logger="current") | |||
| cls_autoencoder = SymbolNetAutoencoder( | |||
| num_classes=len(self.reasoner.kb.pseudo_label_list) | |||
| ) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| loss_fn = torch.nn.MSELoss() | |||
| optimizer = torch.optim.RMSprop( | |||
| cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 | |||
| ) | |||
| pretrain_model = BasicNN( | |||
| cls_autoencoder, | |||
| loss_fn, | |||
| optimizer, | |||
| device=device, | |||
| save_interval=1, | |||
| save_dir=weights_dir, | |||
| num_epochs=10, | |||
| ) | |||
| pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"]) | |||
| pretrain_data = RegressionDataset(pretrain_data_X, pretrain_data_Y) | |||
| pretrain_data_loader = torch.utils.data.DataLoader( | |||
| pretrain_data, batch_size=64, shuffle=True | |||
| ) | |||
| pretrain_model.fit(pretrain_data_loader) | |||
| save_param_dict = { | |||
| "model": cls_autoencoder.base_model.state_dict(), | |||
| } | |||
| torch.save(save_param_dict, os.path.join(weights_dir, "pretrain_weights.pth")) | |||
| self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth")) | |||
| def select_mapping_and_abduce(self, data_examples: ListData): | |||
| candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) | |||
| mapping_score = [] | |||
| abduced_pseudo_label_list = [] | |||
| for _mapping in candidate_mappings: | |||
| self.reasoner.idx_to_label = _mapping | |||
| self.reasoner.label_to_idx = dict(zip(_mapping.values(), _mapping.keys())) | |||
| self.idx_to_pseudo_label(data_examples) | |||
| abduced_pseudo_label = self.reasoner.abduce(data_examples) | |||
| mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([])) | |||
| abduced_pseudo_label_list.append(abduced_pseudo_label) | |||
| max_revisible_instances = max(mapping_score) | |||
| return_idx = mapping_score.index(max_revisible_instances) | |||
| self.reasoner.idx_to_label = candidate_mappings[return_idx] | |||
| self.reasoner.label_to_idx = dict( | |||
| zip(self.reasoner.idx_to_label.values(), self.reasoner.idx_to_label.keys()) | |||
| ) | |||
| self.idx_to_pseudo_label(data_examples) | |||
| data_examples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx] | |||
| return data_examples.abduced_pseudo_label | |||
| def abduce_pseudo_label(self, data_examples: ListData): | |||
| self.reasoner.abduce(data_examples) | |||
| return data_examples.abduced_pseudo_label | |||
| def check_training_impact(self, filtered_data_examples, data_examples): | |||
| character_accuracy = self.model.valid(filtered_data_examples) | |||
| revisible_ratio = len(filtered_data_examples.X) / len(data_examples.X) | |||
| log_string = ( | |||
| f"Revisible ratio is {revisible_ratio:.3f}, Character " | |||
| f"accuracy is {character_accuracy:.3f}" | |||
| ) | |||
| print_log(log_string, logger="current") | |||
| if character_accuracy >= 0.95 and revisible_ratio >= 0.95: | |||
| return True | |||
| return False | |||
| def check_rule_quality(self, rule, val_data, equation_len): | |||
| val_X_true = self.data_preprocess(val_data[1], equation_len) | |||
| val_X_false = self.data_preprocess(val_data[0], equation_len) | |||
| true_ratio = self.calc_consistent_ratio(val_X_true, rule) | |||
| false_ratio = self.calc_consistent_ratio(val_X_false, rule) | |||
| log_string = ( | |||
| f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio " | |||
| f"is {1 - false_ratio:.3f}" | |||
| ) | |||
| print_log(log_string, logger="current") | |||
| if true_ratio > 0.9 and false_ratio < 0.05: | |||
| return True | |||
| return False | |||
| def calc_consistent_ratio(self, data_examples, rule): | |||
| self.predict(data_examples) | |||
| pred_pseudo_label = self.idx_to_pseudo_label(data_examples) | |||
| consistent_num = sum( | |||
| [self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label] | |||
| ) | |||
| return consistent_num / len(data_examples.X) | |||
| def get_rules_from_data(self, data_examples, samples_per_rule, samples_num): | |||
| rules = [] | |||
| sampler = InfiniteSampler(len(data_examples), batch_size=samples_per_rule) | |||
| for _ in range(samples_num): | |||
| for select_idx in sampler: | |||
| sub_data_examples = data_examples[select_idx] | |||
| self.predict(sub_data_examples) | |||
| pred_pseudo_label = self.idx_to_pseudo_label(sub_data_examples) | |||
| consistent_instance = [] | |||
| for instance in pred_pseudo_label: | |||
| if self.reasoner.kb.logic_forward([instance]): | |||
| consistent_instance.append(instance) | |||
| if len(consistent_instance) != 0: | |||
| rule = self.reasoner.abduce_rules(consistent_instance) | |||
| if rule is not None: | |||
| rules.append(rule) | |||
| break | |||
| all_rule_dict = defaultdict(int) | |||
| for rule in rules: | |||
| for r in rule: | |||
| all_rule_dict[r] += 1 | |||
| rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5} | |||
| rules = self.select_rules(rule_dict) | |||
| return rules | |||
| @staticmethod | |||
| def filter_empty(data_examples: ListData): | |||
| consistent_idx = [ | |||
| i | |||
| for i in range(len(data_examples.abduced_pseudo_label)) | |||
| if len(data_examples.abduced_pseudo_label[i]) > 0 | |||
| ] | |||
| return data_examples[consistent_idx] | |||
| @staticmethod | |||
| def select_rules(rule_dict): | |||
| add_nums_dict = {} | |||
| for r in list(rule_dict): | |||
| add_nums = str(r.split("]")[0].split("[")[1]) + str( | |||
| r.split("]")[1].split("[")[1] | |||
| ) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' | |||
| if add_nums in add_nums_dict: | |||
| old_r = add_nums_dict[add_nums] | |||
| if rule_dict[r] >= rule_dict[old_r]: | |||
| rule_dict.pop(old_r) | |||
| add_nums_dict[add_nums] = r | |||
| else: | |||
| rule_dict.pop(r) | |||
| else: | |||
| add_nums_dict[add_nums] = r | |||
| return list(rule_dict) | |||
| def data_preprocess(self, data, equation_len) -> ListData: | |||
| data_examples = ListData() | |||
| data_examples.X = data[equation_len] + data[equation_len + 1] | |||
| data_examples.gt_pseudo_label = None | |||
| data_examples.Y = [None] * len(data_examples.X) | |||
| return data_examples | |||
| def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8, save_dir="./"): | |||
| for equation_len in range(min_len, max_len): | |||
| print_log( | |||
| f"============== equation_len: {equation_len}-{equation_len + 1} ================", | |||
| logger="current", | |||
| ) | |||
| condition_num = 0 | |||
| data_examples = self.data_preprocess(train_data[1], equation_len) | |||
| sampler = InfiniteSampler(len(data_examples), batch_size=segment_size) | |||
| for seg_idx, select_idx in enumerate(sampler): | |||
| print_log( | |||
| f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}]", | |||
| logger="current", | |||
| ) | |||
| sub_data_examples = data_examples[select_idx] | |||
| self.predict(sub_data_examples) | |||
| if equation_len == min_len: | |||
| self.select_mapping_and_abduce(sub_data_examples) | |||
| else: | |||
| self.idx_to_pseudo_label(sub_data_examples) | |||
| self.abduce_pseudo_label(sub_data_examples) | |||
| filtered_sub_data_examples = self.filter_empty(sub_data_examples) | |||
| self.pseudo_label_to_idx(filtered_sub_data_examples) | |||
| self.model.train(filtered_sub_data_examples) | |||
| if self.check_training_impact(filtered_sub_data_examples, sub_data_examples): | |||
| condition_num += 1 | |||
| else: | |||
| condition_num = 0 | |||
| if condition_num >= 5: | |||
| print_log("Now checking if we can go to next course", logger="current") | |||
| rules = self.get_rules_from_data( | |||
| data_examples, samples_per_rule=3, samples_num=50 | |||
| ) | |||
| print_log("Learned rules from data: " + str(rules), logger="current") | |||
| seems_good = self.check_rule_quality(rules, val_data, equation_len) | |||
| if seems_good: | |||
| self.reasoner.kb.learned_rules.update( | |||
| {equation_len: rules, equation_len + 1: rules} | |||
| ) | |||
| self.model.save( | |||
| save_path=os.path.join(save_dir, f"eq_len_{equation_len}.pth") | |||
| ) | |||
| break | |||
| else: | |||
| if equation_len == min_len: | |||
| print_log( | |||
| "Learned mapping is: " + str(self.reasoner.idx_to_label), | |||
| logger="current", | |||
| ) | |||
| self.model.load( | |||
| load_path=os.path.join(save_dir, "pretrain_weights.pth") | |||
| ) | |||
| else: | |||
| self.model.load( | |||
| load_path=os.path.join(save_dir, f"eq_len_{equation_len - 1}.pth") | |||
| ) | |||
| condition_num = 0 | |||
| print_log("Reload Model and retrain", logger="current") | |||
| def test( | |||
| self, | |||
| test_data: Union[ | |||
| ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] | |||
| ], | |||
| min_len=5, | |||
| max_len=8, | |||
| ) -> None: | |||
| for equation_len in range(min_len, max_len): | |||
| test_data_examples = self.data_preprocess(test_data[1], equation_len) | |||
| print_log(f"Test on true equations with length {equation_len}", logger="current") | |||
| self._valid(test_data_examples) | |||
| test_data_examples = self.data_preprocess(test_data[0], equation_len) | |||
| print_log(f"Test on false equations with length {equation_len}", logger="current") | |||
| self._valid(test_data_examples) | |||
| @@ -0,0 +1,28 @@ | |||
| from typing import Optional | |||
| from abl.data.evaluation.base_metric import BaseMetric | |||
| from abl.data.structures import ListData | |||
| from abl.reasoning import KBBase | |||
| class ConsistencyMetric(BaseMetric): | |||
| def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | |||
| super().__init__(prefix) | |||
| self.kb = kb | |||
| def process(self, data_examples: ListData) -> None: | |||
| pred_pseudo_label = data_examples.pred_pseudo_label | |||
| learned_rules = self.kb.learned_rules | |||
| consistent_num = sum( | |||
| [ | |||
| self.kb.consist_rule(instance, learned_rules[len(instance)]) | |||
| for instance in pred_pseudo_label | |||
| ] | |||
| ) | |||
| self.results.append((consistent_num, len(pred_pseudo_label))) | |||
| def compute_metrics(self) -> dict: | |||
| results = self.results | |||
| metrics = dict() | |||
| metrics["consistency"] = sum(t[0] for t in results) / sum(t[1] for t in results) | |||
| return metrics | |||
| @@ -0,0 +1,3 @@ | |||
| from .get_dataset import get_dataset, get_pretrain_data, split_equation | |||
| __all__ = ["get_dataset", "get_pretrain_data", "split_equation"] | |||
| @@ -0,0 +1,123 @@ | |||
| import os | |||
| import os.path as osp | |||
| import pickle | |||
| import random | |||
| import zipfile | |||
| from collections import defaultdict | |||
| from PIL import Image | |||
| import gdown | |||
| import numpy as np | |||
| from torchvision.transforms import transforms | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| def download_and_unzip(url, zip_file_name): | |||
| try: | |||
| gdown.download(url, zip_file_name) | |||
| with zipfile.ZipFile(zip_file_name, "r") as zip_ref: | |||
| zip_ref.extractall(CURRENT_DIR) | |||
| os.remove(zip_file_name) | |||
| except Exception as e: | |||
| if os.path.exists(zip_file_name): | |||
| os.remove(zip_file_name) | |||
| raise Exception( | |||
| f"An error occurred during download or unzip: {e}. Instead, you can download " | |||
| + f"the dataset from {url} and unzip it in 'examples/hed/datasets' folder" | |||
| ) | |||
| def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||
| transform = transforms.Compose([transforms.ToTensor()]) | |||
| X = [] | |||
| img_dir = osp.join(CURRENT_DIR, "mnist_images") | |||
| for label in labels: | |||
| label_path = osp.join(img_dir, label) | |||
| img_path_list = os.listdir(label_path) | |||
| for img_path in img_path_list: | |||
| with Image.open(osp.join(label_path, img_path)) as img: | |||
| img = img.convert("L") | |||
| img = img.resize((image_size[1], image_size[0])) | |||
| img_array = np.array(img, dtype=np.float32) | |||
| normalized_img = (img_array - 127) / 128.0 | |||
| X.append(normalized_img) | |||
| Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] | |||
| X = [transform(img[:, :, np.newaxis]) for img in X] | |||
| return X, Y | |||
| def divide_equations_by_len(equations, labels): | |||
| equations_by_len = {1: defaultdict(list), 0: defaultdict(list)} | |||
| for i, equation in enumerate(equations): | |||
| equations_by_len[labels[i]][len(equation)].append(equation) | |||
| return equations_by_len | |||
| def split_equation(equations_by_len, prop_train, prop_val): | |||
| """ | |||
| Split the equations in each length to training and validation data according to the proportion | |||
| """ | |||
| train_equations_by_len = {1: dict(), 0: dict()} | |||
| val_equations_by_len = {1: dict(), 0: dict()} | |||
| for label in range(2): | |||
| for equation_len, equations in equations_by_len[label].items(): | |||
| random.shuffle(equations) | |||
| train_equations_by_len[label][equation_len] = equations[ | |||
| : len(equations) // (prop_train + prop_val) * prop_train | |||
| ] | |||
| val_equations_by_len[label][equation_len] = equations[ | |||
| len(equations) // (prop_train + prop_val) * prop_train : | |||
| ] | |||
| return train_equations_by_len, val_equations_by_len | |||
| def get_dataset(dataset="mnist", train=True): | |||
| data_dir = CURRENT_DIR + "/mnist_images" | |||
| if not os.path.exists(data_dir): | |||
| print("Dataset not exist, downloading it...") | |||
| url = "https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download" | |||
| download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip")) | |||
| print("Download and extraction complete.") | |||
| if dataset == "mnist": | |||
| file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | |||
| elif dataset == "random": | |||
| file = osp.join(CURRENT_DIR, "random_equation_data_train_len_26_test_len_26_sys_2_.pk") | |||
| else: | |||
| raise ValueError("Undefined dataset") | |||
| with open(file, "rb") as f: | |||
| img_dataset = pickle.load(f) | |||
| X, Y = [], [] | |||
| if train: | |||
| positive = img_dataset["train:positive"] | |||
| negative = img_dataset["train:negative"] | |||
| else: | |||
| positive = img_dataset["test:positive"] | |||
| negative = img_dataset["test:negative"] | |||
| for equation in positive: | |||
| equation = equation.astype(np.float32) | |||
| img_list = np.vsplit(equation, equation.shape[0]) | |||
| X.append(img_list) | |||
| Y.append(1) | |||
| for equation in negative: | |||
| equation = equation.astype(np.float32) | |||
| img_list = np.vsplit(equation, equation.shape[0]) | |||
| X.append(img_list) | |||
| Y.append(0) | |||
| equations_by_len = divide_equations_by_len(X, Y) | |||
| return equations_by_len | |||
| @@ -0,0 +1,111 @@ | |||
| import argparse | |||
| import os.path as osp | |||
| import torch | |||
| import torch.nn as nn | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.utils import ABLLogger, print_log | |||
| from bridge import HedBridge | |||
| from consistency_metric import ConsistencyMetric | |||
| from datasets import get_dataset, split_equation | |||
| from models.nn import SymbolNet | |||
| from reasoning import HedKB, HedReasoner | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="Handwritten Equation Decipherment example") | |||
| parser.add_argument( | |||
| "--no-cuda", action="store_true", default=False, help="disables CUDA training" | |||
| ) | |||
| parser.add_argument( | |||
| "--epochs", | |||
| type=int, | |||
| default=1, | |||
| help="number of epochs in each learning loop iteration (default : 1)", | |||
| ) | |||
| parser.add_argument( | |||
| "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | |||
| ) | |||
| parser.add_argument( | |||
| "--weight-decay", type=float, default=1e-4, help="weight decay (default : 0.0001)" | |||
| ) | |||
| parser.add_argument( | |||
| "--batch-size", type=int, default=32, help="base model batch size (default : 32)" | |||
| ) | |||
| parser.add_argument( | |||
| "--segment_size", type=int, default=1000, help="segment size (default : 1000)" | |||
| ) | |||
| parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") | |||
| parser.add_argument( | |||
| "--max-revision", | |||
| type=int, | |||
| default=10, | |||
| help="maximum revision in reasoner (default : 10)", | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| total_train_data = get_dataset(train=True) | |||
| train_data, val_data = split_equation(total_train_data, 3, 1) | |||
| test_data = get_dataset(train=False) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=4) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_decay) | |||
| use_cuda = not args.no_cuda and torch.cuda.is_available() | |||
| device = torch.device("cuda" if use_cuda else "cpu") | |||
| # Build BasicNN | |||
| base_model = BasicNN( | |||
| cls, | |||
| loss_fn, | |||
| optimizer, | |||
| device=device, | |||
| batch_size=args.batch_size, | |||
| num_epochs=args.epochs, | |||
| stop_loss=None, | |||
| ) | |||
| # Build ABLModel | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = HedKB() | |||
| # Create reasoner | |||
| reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [ConsistencyMetric(kb=kb)] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = HedBridge(model, reasoner, metric_list) | |||
| # Retrieve the directory of the log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| bridge.pretrain(weights_dir) | |||
| bridge.train(train_data, val_data, save_dir=weights_dir) | |||
| bridge.test(test_data) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,49 @@ | |||
| import torch | |||
| from torch import nn | |||
| class SymbolNet(nn.Module): | |||
| def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
| super(SymbolNet, self).__init__() | |||
| self.conv1 = nn.Sequential( | |||
| nn.Conv2d(1, 32, 5, stride=1), | |||
| nn.ReLU(), | |||
| nn.MaxPool2d(kernel_size=2, stride=2), | |||
| nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
| ) | |||
| self.conv2 = nn.Sequential( | |||
| nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
| nn.ReLU(), | |||
| nn.MaxPool2d(kernel_size=2, stride=2), | |||
| nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
| ) | |||
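| # For a 28x28 input: conv1 (5x5, no padding) yields 24x24 and pooling 12x12; | |||
| # conv2 (5x5, padding 2) keeps 12x12 and pooling yields 6x6, matching | |||
| # image_size[0] // 4 - 1 = 6 cells per side (64 * 6 * 6 input features). | |||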
| num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
| self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
| self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
| self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.conv2(x) | |||
| x = torch.flatten(x, 1) | |||
| x = self.fc1(x) | |||
| x = self.fc2(x) | |||
| x = self.fc3(x) | |||
| return x | |||
| class SymbolNetAutoencoder(nn.Module): | |||
| def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
| super(SymbolNetAutoencoder, self).__init__() | |||
| self.base_model = SymbolNet(num_classes, image_size) | |||
| self.softmax = nn.Softmax(dim=1) | |||
| self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
| self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
| def forward(self, x): | |||
| x = self.base_model(x) | |||
| # x = self.softmax(x) | |||
| x = self.fc1(x) | |||
| x = self.fc2(x) | |||
| return x | |||
| @@ -0,0 +1,83 @@ | |||
| :- use_module(library(apply)). | |||
| :- use_module(library(lists)). | |||
| % :- use_module(library(tabling)). | |||
| % :- table valid_rules/2, op_rule/2. | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% DCG parser for equations | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% symbols to be mapped | |||
| digit(1). | |||
| digit(0). | |||
| % digits | |||
| digits([D]) --> [D], { digit(D) }. % empty list [] is not a digit | |||
| digits([D | T]) --> [D], !, digits(T), { digit(D) }. | |||
| digits(X):- | |||
| phrase(digits(X), X). | |||
| % More integrity constraints 1: | |||
| % These two clauses forbid the first digit from being 0. | |||
| % You may uncomment them to prune the search space | |||
| % length(X, L), | |||
| % (L > 1 -> X \= [0 | _]; true). | |||
| % Equation definition | |||
| eq_arg([D]) --> [D], { \+ D == '+', \+ D == '=' }. | |||
| eq_arg([D | T]) --> [D], !, eq_arg(T), { \+ D == '+', \+ D == '=' }. | |||
| equation(eq(X, Y, Z)) --> | |||
| eq_arg(X), [+], eq_arg(Y), [=], eq_arg(Z). | |||
| % More integrity constraints 2: | |||
| % This clause restricts the lengths of the arguments to sane values. | |||
| % You may uncomment it to prune the search space | |||
| % { length(X, LX), length(Y, LY), length(Z, LZ), | |||
| % LZ =< max(LX, LY) + 1, LZ >= max(LX, LY) }. | |||
| parse_eq(List_of_Terms, Eq) :- | |||
| phrase(equation(Eq), List_of_Terms). | |||
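| % Example: the query | |||
| %   ?- parse_eq([1, 0, +, 1, =, 1, 1], Eq). | |||
| % should bind Eq = eq([1, 0], [1], [1, 1]), i.e. binary 10 + 1 = 11. | |||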
| %%%%%%%%%%%%%%%%%%%%%% | |||
| %% Bit-wise operation | |||
| %%%%%%%%%%%%%%%%%%%%%% | |||
| % Abductive calculation on given pseudo-labels; abduces pseudo-labels as well as operation rules | |||
| calc(Rules, Pseudo) :- | |||
| calc([], Rules, Pseudo). | |||
| calc(Rules0, Rules1, Pseudo) :- | |||
| parse_eq(Pseudo, eq(X,Y,Z)), | |||
| bitwise_calc(Rules0, Rules1, X, Y, Z). | |||
| % Bit-wise calculation that handles carrying | |||
| bitwise_calc(Rules, Rules1, X, Y, Z) :- | |||
| reverse(X, X1), reverse(Y, Y1), reverse(Z, Z1), | |||
| bitwise_calc_r(Rules, Rules1, X1, Y1, Z1), | |||
| maplist(digits, [X,Y,Z]). | |||
| bitwise_calc_r(Rs, Rs, [], Y, Y). | |||
| bitwise_calc_r(Rs, Rs, X, [], X). | |||
| bitwise_calc_r(Rules, Rules1, [D1 | X], [D2 | Y], [D3 | Z]) :- | |||
| abduce_op_rule(my_op([D1],[D2],Sum), Rules, Rules2), | |||
| ((Sum = [D3], Carry = []); (Sum = [C, D3], Carry = [C])), | |||
| bitwise_calc_r(Rules2, Rules3, X, Carry, X_carried), | |||
| bitwise_calc_r(Rules3, Rules1, X_carried, Y, Z). | |||
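| % Example (a sketch): calc(Rules, [1, +, 1, =, 1, 0]) should succeed with | |||
| % Rules = [my_op([1], [1], [1, 0])], abducing the carry rule 1 + 1 = 10. | |||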
| %%%%%%%%%%%%%%%%%%%%%%%%% | |||
| % Abduce operation rules | |||
| %%%%%%%%%%%%%%%%%%%%%%%%% | |||
| % Get an existing rule | |||
| abduce_op_rule(R, Rules, Rules) :- | |||
| member(R, Rules). | |||
| % Add a new rule | |||
| abduce_op_rule(R, Rules, [R|Rules]) :- | |||
| op_rule(R), | |||
| valid_rules(Rules, R). | |||
| % Integrity Constraints | |||
| valid_rules([], _). | |||
| valid_rules([my_op([X1],[Y1],_)|Rs], my_op([X],[Y],Z)) :- | |||
| op_rule(my_op([X],[Y],Z)), | |||
| [X,Y] \= [X1,Y1], | |||
| [X,Y] \= [Y1,X1], | |||
| valid_rules(Rs, my_op([X],[Y],Z)). | |||
| valid_rules([my_op([Y],[X],Z)|Rs], my_op([X],[Y],Z)) :- | |||
| X \= Y, | |||
| valid_rules(Rs, my_op([X],[Y],Z)). | |||
| op_rule(my_op([X],[Y],[Z])) :- digit(X), digit(Y), digit(Z). | |||
| op_rule(my_op([X],[Y],[Z1,Z2])) :- digit(X), digit(Y), digits([Z1,Z2]). | |||
| @@ -0,0 +1,3 @@ | |||
| from .reasoning import HedKB, HedReasoner | |||
| __all__ = ["HedKB", "HedReasoner"] | |||
| @@ -0,0 +1,84 @@ | |||
| :- ensure_loaded(['BK.pl']). | |||
| :- thread_setconcurrency(_, 8). | |||
| :- use_module(library(thread)). | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% For propositionalisation | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| eval_inst_feature(Ex, Feature):- | |||
| eval_eq(Ex, Feature). | |||
| %% Evaluate instance given feature | |||
| eval_eq(Ex, Feature):- | |||
| parse_eq(Ex, eq(X,Y,Z)), | |||
| bitwise_calc(Feature,_,X,Y,Z), !. | |||
| %%%%%%%%%%%%%% | |||
| %% Abduction | |||
| %%%%%%%%%%%%%% | |||
| % Perform abduction on samples that have been interpreted as pseudo-labels | |||
| abduce(Exs, Delta_C) :- | |||
| abduce(Exs, [], Delta_C). | |||
| abduce([], Delta_C, Delta_C). | |||
| abduce([E|Exs], Delta_C0, Delta_C1) :- | |||
| calc(Delta_C0, Delta_C2, E), | |||
| abduce(Exs, Delta_C2, Delta_C1). | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% Abduce pseudo-labels only | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| abduce_consistent_insts(Exs):- | |||
| abduce(Exs, _), !. | |||
| % (Experimental) Uncomment to use parallel abduction | |||
| % abduce_consistent_exs_concurrent(Exs), !. | |||
| logic_forward(Exs, X) :- abduce_consistent_insts(Exs) -> X = true ; X = false. | |||
| logic_forward(Exs) :- abduce_consistent_insts(Exs). | |||
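| % Example (a sketch): logic_forward([[1, +, 1, =, 1, 0]], X) binds X = true | |||
| % when some set of operation rules makes the pseudo-label equation consistent. | |||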
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% Abduce Delta_C given pseudo-labels | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| consistent_inst_feature(Exs, Delta_C):- | |||
| abduce(Exs, Delta_C), !. | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% (Experimental) Parallel abduction | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| abduce_consistent_exs_concurrent(Exs) :- | |||
| % Split the current data batch into ground samples and variable samples (which need to be revised) | |||
| split_exs(Exs, Ground_Exs, Var_Exs), | |||
| % Find the simplest Delta_C for the ground samples. | |||
| abduce(Ground_Exs, Ground_Delta_C), !, | |||
| % Extend Ground Delta_C into all possible variations | |||
| extend_op_rule(Ground_Delta_C, Possible_Deltas), | |||
| % Concurrently abduce the variable samples | |||
| maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List), | |||
| maplist(=.., Goals, Call_List), | |||
| % writeln(Goals), | |||
| first_solution(Var_Exs, Goals, [local(inf)]). | |||
| split_exs([],[],[]). | |||
| split_exs([E | Exs], [E | G_Exs], V_Exs):- | |||
| ground(E), !, | |||
| split_exs(Exs, G_Exs, V_Exs). | |||
| split_exs([E | Exs], G_Exs, [E | V_Exs]):- | |||
| split_exs(Exs, G_Exs, V_Exs). | |||
| :- table extend_op_rule/2. | |||
| extend_op_rule(Rules, Rules) :- | |||
| length(Rules, 4). | |||
| extend_op_rule(Rules, Ext) :- | |||
| op_rule(R), | |||
| valid_rules(Rules, R), | |||
| extend_op_rule([R|Rules], Ext). | |||
| % Abduction without learning new Delta_C (because the rules have already been extended) | |||
| abduce2([], _, _). | |||
| abduce2([E|Exs], Ground_Exs, Delta_C) :- | |||
| % abduce by finding ground samples | |||
| member(E, Ground_Exs), | |||
| abduce2(Exs, Ground_Exs, Delta_C). | |||
| abduce2([E|Exs], Ground_Exs, Delta_C) :- | |||
| eval_inst_feature(E, Delta_C), | |||
| abduce2(Exs, Ground_Exs, Delta_C). | |||
| @@ -0,0 +1,102 @@ | |||
| import math | |||
| import os | |||
| import numpy as np | |||
| from abl.reasoning import PrologKB, Reasoner | |||
| from abl.utils import reform_list | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| class HedKB(PrologKB): | |||
| def __init__( | |||
| self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl") | |||
| ): | |||
| pl_file = pl_file.replace("\\", "/") | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| self.learned_rules = {} | |||
| def consist_rule(self, exs, rules): | |||
| rules = str(rules).replace("'", "") | |||
| return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||
| def abduce_rules(self, pred_res): | |||
| prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| prolog_rules = prolog_result[0]["X"] | |||
| rules = [rule.value for rule in prolog_rules] | |||
| return rules | |||
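| # A hedged usage sketch (assumes SWI-Prolog is installed and that the query | |||
| # format matches the predicates in learn_add.pl; the equation below is | |||
| # illustrative): | |||
| #   kb = HedKB() | |||
| #   rules = kb.abduce_rules([[1, "+", 1, "=", 1, 0]]) | |||
| #   # e.g. rules == ["my_op([1],[1],[1,0])"] if abduction succeeds | |||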
| class HedReasoner(Reasoner): | |||
| def revise_at_idx(self, data_example): | |||
| revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0] | |||
| candidate = self.kb.revise_at_idx( | |||
| data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx | |||
| ) | |||
| return candidate | |||
| def zoopt_budget(self, symbol_num): | |||
| # Use a fixed evaluation budget for ZOOpt, regardless of the symbol count | |||
| return 200 | |||
| def zoopt_score(self, symbol_num, data_example, sol, get_score=True): | |||
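| # A descriptive note: this scorer greedily grows index sets that can be | |||
| # revised into a consistent state, records the size of each consistent | |||
| # group, and folds the sizes into a single score for ZOOpt to minimize; | |||
| # when get_score is False it returns the largest consistent index set. | |||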
| revision_flag = reform_list( | |||
| list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label | |||
| ) | |||
| data_example.revision_flag = revision_flag | |||
| lefted_idxs = [i for i in range(len(data_example.pred_idx))] | |||
| candidate_size = [] | |||
| max_consistent_idxs = [] | |||
| while lefted_idxs: | |||
| idxs = [] | |||
| idxs.append(lefted_idxs.pop(0)) | |||
| max_candidate_idxs = [] | |||
| found = False | |||
| for idx in range(-1, len(data_example.pred_idx)): | |||
| if (idx not in idxs) and (idx >= 0): | |||
| idxs.append(idx) | |||
| candidates, _ = self.revise_at_idx(data_example[idxs]) | |||
| if len(candidates) == 0: | |||
| if len(idxs) > 1: | |||
| idxs.pop() | |||
| else: | |||
| if len(idxs) > len(max_candidate_idxs): | |||
| found = True | |||
| max_candidate_idxs = idxs.copy() | |||
| removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
| if found: | |||
| removed.insert(0, idxs[0]) | |||
| candidate_size.append(len(removed)) | |||
| max_consistent_idxs = max_candidate_idxs.copy() | |||
| lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||
| candidate_size.sort() | |||
| score = 0 | |||
| for i in range(0, len(candidate_size)): | |||
| score -= math.exp(-i) * candidate_size[i] | |||
| if get_score: | |||
| return score | |||
| else: | |||
| return max_consistent_idxs | |||
| def abduce(self, data_example): | |||
| symbol_num = data_example.elements_num("pred_pseudo_label") | |||
| max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||
| solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||
| max_candidate_idxs = self.zoopt_score(symbol_num, data_example, solution, get_score=False) | |||
| abduced_pseudo_label = [[] for _ in range(len(data_example))] | |||
| if len(max_candidate_idxs) > 0: | |||
| candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs]) | |||
| for i, idx in enumerate(max_candidate_idxs): | |||
| abduced_pseudo_label[idx] = candidates[0][i] | |||
| data_example.abduced_pseudo_label = abduced_pseudo_label | |||
| return abduced_pseudo_label | |||
| def abduce_rules(self, pred_res): | |||
| return self.kb.abduce_rules(pred_res) | |||
| @@ -0,0 +1,2 @@ | |||
| abl | |||
| gdown | |||
| @@ -0,0 +1,65 @@ | |||
| from itertools import permutations | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.utils.data.sampler as sampler | |||
| class InfiniteSampler(sampler.Sampler): | |||
| def __init__(self, num_examples, batch_size=1): | |||
| self.num_examples = num_examples | |||
| self.batch_size = batch_size | |||
| def __iter__(self): | |||
| while True: | |||
| order = np.random.permutation(self.num_examples) | |||
| # Yield successive index batches from a fresh permutation each pass | |||
| for i in range(0, self.num_examples, self.batch_size): | |||
| yield order[i : i + self.batch_size] | |||
| def __len__(self): | |||
| # An infinite sampler has no finite length | |||
| return None | |||
| def gen_mappings(chars, symbs): | |||
| n_char = len(chars) | |||
| n_symbs = len(symbs) | |||
| if n_char != n_symbs: | |||
| print("Characters and symbols size dosen't match.") | |||
| return | |||
| mappings = []  # returned mappings | |||
| perms = permutations(symbs) | |||
| for p in perms: | |||
| if p.index(1) < p.index(0): | |||
| continue | |||
| mappings.append(dict(zip(chars, list(p)))) | |||
| return mappings | |||
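| # Example (a sketch): gen_mappings(["a", "b", "c", "d"], [1, 0, "+", "="]) | |||
| # returns 12 of the 24 possible mappings, e.g. {"a": 0, "b": 1, "c": "+", "d": "="}; | |||
| # permutations where 1 precedes 0 are skipped, presumably because swapping | |||
| # the images of 0 and 1 yields a symmetric candidate mapping. | |||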
| def mapping_res(original_pred_res, m): | |||
| return [[m[symbol] for symbol in formula] for formula in original_pred_res] | |||
| def remapping_res(pred_res, m): | |||
| remapping = {} | |||
| for key, value in m.items(): | |||
| remapping[value] = key | |||
| return [[remapping[symbol] for symbol in formula] for formula in pred_res] | |||
| def extract_feature(img): | |||
| extractor = nn.AvgPool2d(2, stride=2) | |||
| feature_map = np.array(extractor(torch.Tensor(img))) | |||
| return feature_map.reshape((-1,)) | |||
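| # Example (a sketch): a 1x28x28 symbol image is average-pooled to 14x14 and | |||
| # flattened into a 196-dimensional feature vector. | |||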
| def reduce_dimension(data): | |||
| for truth_value in [0, 1]: | |||
| for equation_len in range(5, 27): | |||
| equations = data[truth_value][equation_len] | |||
| reduced_equations = [ | |||
| [extract_feature(symbol_img) for symbol_img in equation] for equation in equations | |||
| ] | |||
| data[truth_value][equation_len] = reduced_equations | |||
| @@ -0,0 +1,44 @@ | |||
| # Handwritten Formula | |||
| This example shows a simple implementation of the [Handwritten Formula](https://arxiv.org/abs/2006.06649) task, where handwritten images of decimal formulas and their computed results are given, along with a domain knowledge base containing information on how to compute a decimal formula. The task is to recognize the symbols (which can be digits or the operators '+', '-', '×', '÷') in the handwritten images and accurately determine their results. | |||
| ## Run | |||
| ```bash | |||
| pip install -r requirements.txt | |||
| python main.py | |||
| ``` | |||
| ## Usage | |||
| ```bash | |||
| usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||
| [--batch-size BATCH_SIZE] | |||
| [--loops LOOPS] [--segment_size SEGMENT_SIZE] | |||
| [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION] | |||
| [--require-more-revision REQUIRE_MORE_REVISION] | |||
| [--ground] [--max-err MAX_ERR] | |||
| Handwritten Formula example | |||
| optional arguments: | |||
| -h, --help show this help message and exit | |||
| --no-cuda disables CUDA training | |||
| --epochs EPOCHS number of epochs in each learning loop iteration | |||
| (default : 1) | |||
| --lr LR base model learning rate (default : 0.001) | |||
| --batch-size BATCH_SIZE | |||
| base model batch size (default : 32) | |||
| --loops LOOPS number of loop iterations (default : 5) | |||
| --segment_size SEGMENT_SIZE | |||
| segment size (default : 1000) | |||
| --save_interval SAVE_INTERVAL | |||
| save interval (default : 1) | |||
| --max-revision MAX_REVISION | |||
| maximum revision in reasoner (default : -1) | |||
| --require-more-revision REQUIRE_MORE_REVISION | |||
| require more revision in reasoner (default : 0) | |||
| --ground use GroundKB (default: False) | |||
| --max-err MAX_ERR max tolerance during abductive reasoning (default : 1e-10) | |||
| ``` | |||
| @@ -0,0 +1,3 @@ | |||
| from .get_dataset import get_dataset | |||
| __all__ = ["get_dataset"] | |||
| @@ -0,0 +1,67 @@ | |||
| import json | |||
| import os | |||
| import zipfile | |||
| import gdown | |||
| from PIL import Image | |||
| from torchvision.transforms import transforms | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]) | |||
| def download_and_unzip(url, zip_file_name): | |||
| try: | |||
| gdown.download(url, zip_file_name) | |||
| with zipfile.ZipFile(zip_file_name, "r") as zip_ref: | |||
| zip_ref.extractall(CURRENT_DIR) | |||
| os.remove(zip_file_name) | |||
| except Exception as e: | |||
| if os.path.exists(zip_file_name): | |||
| os.remove(zip_file_name) | |||
| raise Exception( | |||
| f"An error occurred during download or unzip: {e}. Instead, you can download " | |||
| + f"the dataset from {url} and unzip it in 'examples/hwf/datasets' folder" | |||
| ) | |||
| def get_dataset(train=True, get_pseudo_label=False): | |||
| data_dir = os.path.join(CURRENT_DIR, "data") | |||
| if not os.path.exists(data_dir): | |||
| print("Dataset not exist, downloading it...") | |||
| url = "https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download" | |||
| download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip")) | |||
| print("Download and extraction complete.") | |||
| if train: | |||
| file = os.path.join(data_dir, "expr_train.json") | |||
| else: | |||
| file = os.path.join(data_dir, "expr_test.json") | |||
| X = [] | |||
| pseudo_label = [] if get_pseudo_label else None | |||
| Y = [] | |||
| img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") | |||
| with open(file) as f: | |||
| data = json.load(f) | |||
| for idx in range(len(data)): | |||
| imgs = [] | |||
| if get_pseudo_label: | |||
| imgs_pseudo_label = [] | |||
| for img_path in data[idx]["img_paths"]: | |||
| img = Image.open(img_dir + img_path).convert("L") | |||
| img = img_transform(img) | |||
| imgs.append(img) | |||
| if get_pseudo_label: | |||
| label_mappings = {"times": "*", "div": "/"} | |||
| label = img_path.split("/")[0] | |||
| label = label_mappings.get(label, label) | |||
| imgs_pseudo_label.append(label) | |||
| X.append(imgs) | |||
| if get_pseudo_label: | |||
| pseudo_label.append(imgs_pseudo_label) | |||
| Y.append(data[idx]["res"]) | |||
| return X, pseudo_label, Y | |||
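| # A hedged usage sketch: | |||
| #   X, pseudo_label, Y = get_dataset(train=True, get_pseudo_label=True) | |||
| # X[i] is a list of transformed symbol images, pseudo_label[i] the matching | |||
| # symbol strings (e.g. ["5", "+", "3"]), and Y[i] the formula's computed result. | |||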
| @@ -0,0 +1,454 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "# Handwritten Formula (HWF)\n", | |||
| "\n", | |||
| "This notebook shows an implementation of [Handwritten Formula](https://arxiv.org/abs/2006.06649). In this task, handwritten images of decimal formulas and their computed results are given, alongwith a domain knowledge base containing information on how to compute the decimal formula. The task is to recognize the symbols (which can be digits or operators '+', '-', '×', '÷') of handwritten images and accurately determine their results.\n", | |||
| "\n", | |||
| "Intuitively, we first use a machine learning model (learning part) to convert the input images to symbols (we call them pseudo-labels), and then use the knowledge base (reasoning part) to calculate the results of these symbols. Since we do not have ground-truth of the symbols, in Abductive Learning, the reasoning part will leverage domain knowledge and revise the initial symbols yielded by the learning part through abductive reasoning. This process enables us to further update the machine learning model." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Import necessary libraries and modules\n", | |||
| "import os.path as osp\n", | |||
| "\n", | |||
| "import matplotlib.pyplot as plt\n", | |||
| "import numpy as np\n", | |||
| "import torch\n", | |||
| "import torch.nn as nn\n", | |||
| "\n", | |||
| "from abl.bridge import SimpleBridge\n", | |||
| "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
| "from abl.learning import ABLModel, BasicNN\n", | |||
| "from abl.reasoning import KBBase, Reasoner\n", | |||
| "from abl.utils import ABLLogger, print_log\n", | |||
| "\n", | |||
| "from datasets import get_dataset\n", | |||
| "from models.nn import SymbolNet" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "## Working with Data\n", | |||
| "\n", | |||
| "First, we get the training and testing datasets:" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "train_data = get_dataset(train=True, get_pseudo_label=True)\n", | |||
| "test_data = get_dataset(train=False, get_pseudo_label=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Both `train_data` and `test_data` have the same structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and structures of datasets are illustrated as follows.\n", | |||
| "\n", | |||
| "Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", | |||
| "print()\n", | |||
| "train_X, train_gt_pseudo_label, train_Y = train_data\n", | |||
| "print(\n", | |||
| " f\"Length of X, gt_pseudo_label, Y in train_data: \"\n", | |||
| " + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n", | |||
| ")\n", | |||
| "test_X, test_gt_pseudo_label, test_Y = test_data\n", | |||
| "print(\n", | |||
| " f\"Length of X, gt_pseudo_label, Y in test_data: \"\n", | |||
| " + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n", | |||
| ")\n", | |||
| "print()\n", | |||
| "\n", | |||
| "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n", | |||
| "print(\n", | |||
| " f\"X is a {type(train_X).__name__}, \"\n", | |||
| " + f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\"\n", | |||
| ")\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n", | |||
| " + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n", | |||
| " + f\"of {type(gt_pseudo_label_0[0]).__name__}.\"\n", | |||
| ")\n", | |||
| "print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "The ith element of X, gt_pseudo_label, and Y together constitute the ith data example. Here we use two of them (the 1001st and the 3001st) as illstrations:" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]\n", | |||
| "print(f\"X in the 1001st data example (a list of images):\")\n", | |||
| "for i, x in enumerate(X_1000):\n", | |||
| " plt.subplot(1, len(X_1000), i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\"\n", | |||
| ")\n", | |||
| "print(f\"Y in the 1001st data example (the computed result): {Y_1000}\")\n", | |||
| "print()\n", | |||
| "X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000]\n", | |||
| "print(f\"X in the 3001st data example (a list of images):\")\n", | |||
| "for i, x in enumerate(X_3000):\n", | |||
| " plt.subplot(1, len(X_3000), i + 1)\n", | |||
| " plt.axis(\"off\")\n", | |||
| " plt.imshow(x.squeeze(), cmap=\"gray\")\n", | |||
| "plt.show()\n", | |||
| "print(\n", | |||
| " f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\"\n", | |||
| ")\n", | |||
| "print(f\"Y in the 3001st data example (the computed result): {Y_3000}\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Note: The symbols in the HWF dataset can be one of digits or operators '+', '-', '×', '÷'. \n", | |||
| "\n", | |||
| "Note: We may see that, in the 1001st data example, the length of the formula is 3, while in the 3001st data example, the length of the formula is 5. In the HWF dataset, the length of the formula varies from 1 to 7." | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "## Building the Learning Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "To build the learning part, we need to first build a machine learning base model. We use SymbolNet, and encapsulate it within a `BasicNN` object to create the base model. `BasicNN` is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface. " | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# class of symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], total of 13 classes\n", | |||
| "cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))\n", | |||
| "loss_fn = nn.CrossEntropyLoss()\n", | |||
| "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)\n", | |||
| "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||
| "\n", | |||
| "base_model = BasicNN(\n", | |||
| " model=cls,\n", | |||
| " loss_fn=loss_fn,\n", | |||
| " optimizer=optimizer,\n", | |||
| " device=device,\n", | |||
| " batch_size=128,\n", | |||
| " num_epochs=3,\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "`BasicNN` offers methods like `predict` and `predict_prob`, which are used to predict the class index and the probabilities of each class for images. As shown below:" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]\n", | |||
| "pred_idx = base_model.predict(X=data_instances)\n", | |||
| "print(\n", | |||
| " f\"Predicted class index for a batch of 32 instances: \"\n", | |||
| " + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n", | |||
| ")\n", | |||
| "pred_prob = base_model.predict_proba(X=data_instances)\n", | |||
| "print(\n", | |||
| " f\"Predicted class probabilities for a batch of 32 instances: \"\n", | |||
| " + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "However, the base model built above deals with instance-level data (i.e., individual images), and can not directly deal with example-level data (i.e., a list of images comprising the formula). Therefore, we wrap the base model into `ABLModel`, which enables the learning part to train, test, and predict on example-level data." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "model = ABLModel(base_model)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "As an illustration, consider this example of training on example-level data using the `predict` method in `ABLModel`. In this process, the method accepts data examples as input and outputs the class labels and the probabilities of each class for all instances within these data examples." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "from abl.data.structures import ListData\n", | |||
| "\n", | |||
| "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", | |||
| "data_examples = ListData()\n", | |||
| "# We use the first 1001st and 3001st data examples in the training set as an illustration\n", | |||
| "data_examples.X = [X_1000, X_3000]\n", | |||
| "data_examples.gt_pseudo_label = [gt_pseudo_label_1000, gt_pseudo_label_3000]\n", | |||
| "data_examples.Y = [Y_1000, Y_3000]\n", | |||
| "\n", | |||
| "# Perform prediction on the two data examples\n", | |||
| "# Remind that, in the 1001st data example, the length of the formula is 3,\n", | |||
| "# while in the 3001st data example, the length of the formula is 5.\n", | |||
| "pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n", | |||
| "print(\n", | |||
| " f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\"\n", | |||
| " + f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"\n", | |||
| " + f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\"\n", | |||
| ")\n", | |||
| "print(\n", | |||
| " f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n", | |||
| " f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \"\n", | |||
| " + f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\"\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "## Building the Reasoning Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "In the reasoning part, we first build a knowledge base which contain information on how to perform addition operations. We build it by creating a subclass of `KBBase`. In the derived subclass, we initialize the `pseudo_label_list` parameter specifying list of possible pseudo-labels, and override the `logic_forward` function defining how to perform (deductive) reasoning." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "class HwfKB(KBBase):\n", | |||
| " def __init__(\n", | |||
| " self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]\n", | |||
| " ):\n", | |||
| " super().__init__(pseudo_label_list)\n", | |||
| "\n", | |||
| " def _valid_candidate(self, formula):\n", | |||
| " if len(formula) % 2 == 0:\n", | |||
| " return False\n", | |||
| " for i in range(len(formula)):\n", | |||
| " if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n", | |||
| " return False\n", | |||
| " if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"*\", \"/\"]:\n", | |||
| " return False\n", | |||
| " return True\n", | |||
| "\n", | |||
| " # Implement the deduction function\n", | |||
| " def logic_forward(self, formula):\n", | |||
| " if not self._valid_candidate(formula):\n", | |||
| " return np.inf\n", | |||
| " return eval(\"\".join(formula))\n", | |||
| "\n", | |||
| "\n", | |||
| "kb = HwfKB()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "The knowledge base can perform logical reasoning (both deductive reasoning and abductive reasoning). Below is an example of performing (deductive) reasoning, and users can refer to [Documentation]() for details of abductive reasoning." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "pseudo_labels = [\"1\", \"-\", \"2\", \"*\", \"5\"]\n", | |||
| "reasoning_result = kb.logic_forward(pseudo_labels)\n", | |||
| "print(f\"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Note: In addition to building a knowledge base based on `KBBase`, we can also establish a knowledge base with a ground KB using `GroundKB`. The corresponding code can be found in the `main.py` file. Those interested are encouraged to examine it for further insights.\n", | |||
| "\n", | |||
| "Note: Also, when building the knowledge base, we can also set the `max_err` parameter during initialization, which is shown in the `main.py` file. This parameter specifies the upper tolerance limit when comparing the similarity between the reasoning result of pseudo-labels and the ground truth during abductive reasoning, with a default value of 1e-10." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Then, we create a reasoner by instantiating the class ``Reasoner``. Due to the indeterminism of abductive reasoning, there could be multiple candidates compatible to the knowledge base. When this happens, reasoner can minimize inconsistencies between the knowledge base and pseudo-labels predicted by the learning part, and then return only one candidate that has the highest consistency." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "reasoner = Reasoner(kb)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Note: During creating reasoner, the definition of \"consistency\" can be customized within the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data example and candidates based on the confidence derived from the predicted probability. In `main.py`, we provide options for utilizing other forms of consistency measurement.\n", | |||
| "\n", | |||
| "Note: Also, during process of inconsistency minimization, we can leverage [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration. Options for this are also available in `main.py`. Those interested are encouraged to explore these features." | |||
| ] | |||
| }, | |||
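| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# A hedged sketch of customizing the reasoner: `dist_func` and `max_revision`\n", | |||
| "# mirror the parameters used elsewhere in this repository (see the HED\n", | |||
| "# example and main.py); treat the values below as illustrative, not definitive.\n", | |||
| "reasoner = Reasoner(kb, dist_func=\"hamming\", max_revision=-1)" | |||
| ] | |||
| }, | |||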
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "## Building Evaluation Metrics" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolAccuracy` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "metric_list = [SymbolAccuracy(prefix=\"hwf\"), ReasoningMetric(kb=kb, prefix=\"hwf\")]" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "## Bridge Learning and Reasoning\n", | |||
| "\n", | |||
| "Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "bridge = SimpleBridge(model, reasoner, metric_list)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`." | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Build logger\n", | |||
| "print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", | |||
| "log_dir = ABLLogger.get_current_instance().log_dir\n", | |||
| "weights_dir = osp.join(log_dir, \"weights\")\n", | |||
| "\n", | |||
| "bridge.train(train_data, train_data, loops=3, segment_size=1000, save_dir=weights_dir)\n", | |||
| "bridge.test(test_data)" | |||
| ] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "abl", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.18" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| "interpreter": { | |||
| "hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e" | |||
| } | |||
| } | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||
| @@ -0,0 +1,187 @@ | |||
| import argparse | |||
| import os.path as osp | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from abl.bridge import SimpleBridge | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.reasoning import GroundKB, KBBase, Reasoner | |||
| from abl.utils import ABLLogger, print_log | |||
| from datasets import get_dataset | |||
| from models.nn import SymbolNet | |||
| class HwfKB(KBBase): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
| max_err=1e-10, | |||
| ): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: | |||
| return False | |||
| return True | |||
| # Implement the deduction function | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return np.inf | |||
| return eval("".join(formula)) | |||
| class HwfGroundKB(GroundKB): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
| GKB_len_list=[1, 3, 5, 7], | |||
| max_err=1e-10, | |||
| ): | |||
| super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: | |||
| return False | |||
| return True | |||
| # Implement the deduction function | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return np.inf | |||
| return eval("".join(formula)) | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="Handwritten Formula example") | |||
| parser.add_argument( | |||
| "--no-cuda", action="store_true", default=False, help="disables CUDA training" | |||
| ) | |||
| parser.add_argument( | |||
| "--epochs", | |||
| type=int, | |||
| default=3, | |||
| help="number of epochs in each learning loop iteration (default : 3)", | |||
| ) | |||
| parser.add_argument( | |||
| "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | |||
| ) | |||
| parser.add_argument( | |||
| "--batch-size", type=int, default=128, help="base model batch size (default : 128)" | |||
| ) | |||
| parser.add_argument( | |||
| "--loops", type=int, default=5, help="number of loop iterations (default : 5)" | |||
| ) | |||
| parser.add_argument( | |||
| "--segment_size", type=int, default=1000, help="segment size (default : 1000)" | |||
| ) | |||
| parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") | |||
| parser.add_argument( | |||
| "--max-revision", | |||
| type=int, | |||
| default=-1, | |||
| help="maximum revision in reasoner (default : -1)", | |||
| ) | |||
| parser.add_argument( | |||
| "--require-more-revision", | |||
| type=int, | |||
| default=0, | |||
| help="require more revision in reasoner (default : 0)", | |||
| ) | |||
| parser.add_argument( | |||
| "--ground", action="store_true", default=False, help="use GroundKB (default: False)" | |||
| ) | |||
| parser.add_argument( | |||
| "--max-err", | |||
| type=float, | |||
| default=1e-10, | |||
| help="max tolerance during abductive reasoning (default : 1e-10)", | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr) | |||
| use_cuda = not args.no_cuda and torch.cuda.is_available() | |||
| device = torch.device("cuda" if use_cuda else "cpu") | |||
| # Build BasicNN | |||
| base_model = BasicNN( | |||
| cls, | |||
| loss_fn, | |||
| optimizer, | |||
| device=device, | |||
| batch_size=args.batch_size, | |||
| num_epochs=args.epochs, | |||
| ) | |||
| # Build ABLModel | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.ground: | |||
| kb = HwfGroundKB() | |||
| else: | |||
| kb = HwfKB() | |||
| # Create reasoner | |||
| reasoner = Reasoner( | |||
| kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision | |||
| ) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Retrieve the directory of the log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| # Train and Test | |||
| bridge.train( | |||
| train_data, | |||
| loops=args.loops, | |||
| segment_size=args.segment_size, | |||
| save_interval=args.save_interval, | |||
| save_dir=weights_dir, | |||
| ) | |||
| bridge.test(test_data) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,33 @@ | |||
| import torch | |||
| from torch import nn | |||
| class SymbolNet(nn.Module): | |||
| def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
| super(SymbolNet, self).__init__() | |||
| self.conv1 = nn.Sequential( | |||
| nn.Conv2d(1, 32, 5, stride=1), | |||
| nn.ReLU(), | |||
| nn.MaxPool2d(kernel_size=2, stride=2), | |||
| nn.BatchNorm2d(32, momentum=0.99, eps=0.001), | |||
| ) | |||
| self.conv2 = nn.Sequential( | |||
| nn.Conv2d(32, 64, 5, padding=2, stride=1), | |||
| nn.ReLU(), | |||
| nn.MaxPool2d(kernel_size=2, stride=2), | |||
| nn.BatchNorm2d(64, momentum=0.99, eps=0.001), | |||
| ) | |||
| num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
| self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
| self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
| self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.conv2(x) | |||
| x = torch.flatten(x, 1) | |||
| x = self.fc1(x) | |||
| x = self.fc2(x) | |||
| x = self.fc3(x) | |||
| return x | |||