| Author | SHA1 | Message | Date |
|---|---|---|---|
| | 559e2ac91d | [DOC] update readme doc | 7 months ago |
| | d0df086e6e | [FIX] add emoji in title | 7 months ago |
| | 51274d316d | [FIX] change font size | 7 months ago |
| | e86f12e1f1 | [FIX] fix placeholder | 7 months ago |
| | 695f9d2176 | [ENH] add contributors | 7 months ago |
| | 8db8c1b4a5 | [FIX] pass flake8 | 7 months ago |
| | 79a26a0dc5 | Merge pull request #12 from wnqn1597/examples<br>add BDD-OIA example | 7 months ago |
| | c9c82ed8e5 | update readme | 7 months ago |
| | c4c85dd02a | add BDD-OIA example | 7 months ago |

| @@ -2,13 +2,13 @@ | |||||
| <img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/logo.png" width="180"> | <img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/logo.png" width="180"> | ||||
| [](https://pypi.org/project/ablkit/) [](https://pypi.org/project/ablkit/) [](https://ablkit.readthedocs.io/en/latest/?badge=latest) [](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [](https://github.com/psf/black) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) | |||||
| [](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit) [](https://pypi.org/project/ablkit/) [](https://pypi.org/project/ablkit/) [](https://ablkit.readthedocs.io/en/latest/?badge=latest) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [](https://github.com/psf/black) [](https://pypi.org/project/ablkit/) | |||||
| [📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [📚Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new) | |||||
| [📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [🧪Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new) | |||||
| </div> | </div> | ||||
| # ABLkit: A Toolkit for Abductive Learning | |||||
| # 🧰 ABLkit: A Toolkit for Abductive Learning 📊📐 | |||||
| **ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available. | **ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available. | ||||
| @@ -28,7 +28,7 @@ ABLkit encapsulates advanced ABL techniques, providing users with an efficient a | |||||
| <img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/ABLkit.png" alt="ABLkit" style="width: 80%;"/> | <img src="https://raw.githubusercontent.com/AbductiveLearning/ABLkit/main/docs/_static/img/ABLkit.png" alt="ABLkit" style="width: 80%;"/> | ||||
| </p> | </p> | ||||
| ## Installation | |||||
| ## 🛠️ Installation | |||||
| ### Install from PyPI | ### Install from PyPI | ||||
| @@ -60,7 +60,7 @@ sudo apt-get install swi-prolog | |||||
| For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). | For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). | ||||
| ## Quick Start | |||||
| ## ⚡ Quick Start | |||||
| We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. | We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. | ||||
| @@ -184,7 +184,7 @@ bridge.test(test_data) | |||||
| To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html). | To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html). | ||||
| ## Examples | |||||
| ## 🧪 Examples | |||||
| We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file. | We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file. | ||||
| @@ -192,8 +192,9 @@ We provide several examples in `examples/`. Each example is stored in a separate | |||||
| + [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf) | + [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf) | ||||
| + [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed) | + [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed) | ||||
| + [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo) | + [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo) | ||||
| + [BDD-OIA](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia) | |||||
| ## References | |||||
| ## 📚 References | |||||
| For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). | For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). | ||||
| @@ -220,7 +221,7 @@ For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichi | |||||
| } | } | ||||
| ``` | ``` | ||||
| ## Citation | |||||
| ## 📝 Citation | |||||
| To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7). | To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7). | ||||
| @@ -234,4 +235,46 @@ To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://j | |||||
| pages = {186354}, | pages = {186354}, | ||||
| year = {2024} | year = {2024} | ||||
| } | } | ||||
| ``` | |||||
| ``` | |||||
| ## ✨ Contributors | |||||
| We would like to thank the following contributors for their efforts on this project: <sub><i>(*: current maintainer)</i></sub> | |||||
| <table> | |||||
| <tr> | |||||
| <td align="center"> | |||||
| <a href="https://github.com/Tony-HYX"> | |||||
| <img src="https://avatars.githubusercontent.com/u/34394824?v=4" width="100px;" alt=""/> | |||||
| <br /> | |||||
| Yu-Xuan Huang | |||||
| </a> | |||||
| </td> | |||||
| <td align="center"> | |||||
| <a href="https://github.com/troyyyyy"> | |||||
| <img src="https://avatars.githubusercontent.com/u/49091847?v=4" width="100px;" alt=""/> | |||||
| <br /> | |||||
| Wen-Chao Hu | |||||
| </a>* | |||||
| </td> | |||||
| <td align="center"> | |||||
| <a href="https://github.com/WaTerminator"> | |||||
| <img src="https://avatars.githubusercontent.com/u/58843099?v=4" width="100px;" alt=""/> | |||||
| <br /> | |||||
| En-Hao Gao | |||||
| </a> | |||||
| </td> | |||||
| <td align="center"> | |||||
| <a href="https://github.com/wnqn1597"> | |||||
| <img src="https://avatars.githubusercontent.com/u/98020642?v=4" width="100px;" alt=""/> | |||||
| <br /> | |||||
| Qi-Jie Li | |||||
| </a> | |||||
| </td> | |||||
| </tr> | |||||
| </table> | |||||
| We also thank the following users for their helpful suggestions and feedback: | |||||
| - [Hao-Yuan He](https://github.com/Hao-Yuan-He) | |||||
| - [Wang-Zhou Dai](https://github.com/haldai) | |||||
| @@ -180,8 +180,8 @@ class Reasoner: | |||||
| candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | ||||
| return avg_confidence_dist(data_example.pred_prob, candidates_idxs) | return avg_confidence_dist(data_example.pred_prob, candidates_idxs) | ||||
| else: | else: | ||||
| candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||||
| cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) | |||||
| candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||||
| cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results) | |||||
| if len(cost_list) != len(candidates): | if len(cost_list) != len(candidates): | ||||
| raise ValueError( | raise ValueError( | ||||
| "The length of the array returned by dist_func must be equal to the number " | "The length of the array returned by dist_func must be equal to the number " | ||||
| @@ -0,0 +1,50 @@ | |||||
| # BDD-OIA | |||||
| This example shows an implementation of the [BDD-OIA](https://twizwei.github.io/bddoia_project/) task. The BDD-OIA dataset comprises frames extracted from driving-scene videos and is used for autonomous driving prediction. Each frame is annotated with 4 binary labels indicating the possible actions: $\textsf{move forward}$, $\textsf{stop}$, $\textsf{turn left}$, and $\textsf{turn right}$. Each frame is also annotated with 21 intermediate binary concepts, such as $\textsf{red light}$ and $\textsf{road clear}$, which give the reasons for the possible actions. | |||||
| The objective is to predict possible actions for each frame. During training, we only make use of the label supervision, along with a knowledge base, which comprises information about the relations between concepts and actions, e.g., $\textsf{red light} \lor \textsf{traffic sign} \lor \textsf{obstacle} \implies \textsf{stop}$. The training set consists of 16,000 frames, while the test set contains 4,500 annotated data points. | |||||
| Before usage, the dataset was pre-processed by [Marconato et al. (2023)](https://proceedings.neurips.cc/paper_files/paper/2023/file/e560202b6e779a82478edb46c6f8f4dd-Paper-Conference.pdf) using a pretrained Faster-RCNN model on BDD-100k, in conjunction with the first module in CBM-AUC [(Sawada & Nakamura, 2022)](https://arxiv.org/abs/2202.01459), resulting in embeddings of dimension 2048. | |||||
| ## Run | |||||
| ```bash | |||||
| pip install -r requirements.txt | |||||
| cd dataset | |||||
| unzip dataset.zip | |||||
| cd .. | |||||
| python main.py | |||||
| ``` | |||||
| ## Usage | |||||
| ```bash | |||||
| usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||||
| [--batch-size BATCH_SIZE] [--loops LOOPS] | |||||
| [--segment_size SEGMENT_SIZE] | |||||
| [--save_interval SAVE_INTERVAL] | |||||
| [--max-revision MAX_REVISION] | |||||
| [--require-more-revision REQUIRE_MORE_REVISION] | |||||
| BDD-OIA example | |||||
| optional arguments: | |||||
| -h, --help show this help message and exit | |||||
| --no-cuda disables CUDA training | |||||
| --epochs EPOCHS number of epochs in each learning loop iteration | |||||
| (default : 1) | |||||
| --lr LR base model learning rate (default : 0.002) | |||||
| --batch-size BATCH_SIZE | |||||
| base model batch size (default : 32) | |||||
| --loops LOOPS | |||||
| number of loop iterations (default : 2) | |||||
| --segment_size SEGMENT_SIZE | |||||
| segment size (default : 0.01) | |||||
| --save_interval SAVE_INTERVAL | |||||
| save interval (default : 1) | |||||
| --max-revision MAX_REVISION | |||||
| maximum revision in reasoner (default : 3) | |||||
| --require-more-revision REQUIRE_MORE_REVISION | |||||
| require more revision in reasoner (default : 3) | |||||
| ``` | |||||
| @@ -0,0 +1,24 @@ | |||||
| import numpy as np | |||||
| from typing import List, Any | |||||
| from ablkit.data import ListData | |||||
| from ablkit.bridge import SimpleBridge | |||||
| class BDDBridge(SimpleBridge): | |||||
| def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||||
| pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] | |||||
| pred_pseudo_label = [] | |||||
| for sub_list in pred_idx: | |||||
| sub_list = sub_list.squeeze() # 1 x nc -> nc | |||||
| pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list]) | |||||
| data_examples.pred_pseudo_label = pred_pseudo_label | |||||
| return data_examples.pred_pseudo_label | |||||
| def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]: | |||||
| abduced_pseudo_label = data_examples.abduced_pseudo_label | |||||
| abduced_idx = [] | |||||
| for sub_list in abduced_pseudo_label: | |||||
| sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) | |||||
| abduced_idx.append(sub_list) | |||||
| data_examples.abduced_idx = abduced_idx | |||||
| return data_examples.abduced_idx | |||||
| @@ -0,0 +1,19 @@ | |||||
| import os | |||||
| import numpy as np | |||||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||||
| def get_dataset(fname, get_pseudo_label=True): | |||||
| fname = os.path.join(CURRENT_DIR, fname) | |||||
| data = np.load(fname) | |||||
| X = data["X"] | |||||
| X = [[emb.astype(np.float32)] for emb in X] | |||||
| pseudo_label = data["pseudo_label"].astype(int).tolist() if get_pseudo_label else None | |||||
| Y = data["Y"][:, :4].astype(int).tolist() | |||||
| Y = [tuple(y) for y in Y] | |||||
| return X, pseudo_label, Y | |||||
| if __name__ == "__main__": | |||||
| dataset = get_dataset("val.npz") | |||||
| @@ -0,0 +1,149 @@ | |||||
| import argparse | |||||
| import os.path as osp | |||||
| import numpy as np | |||||
| import torch | |||||
| from torch import optim | |||||
| from ablkit.data.evaluation import SymbolAccuracy | |||||
| from ablkit.reasoning import Reasoner | |||||
| from ablkit.utils import ABLLogger, print_log | |||||
| from models.nn import ConceptNet | |||||
| from models.bdd_nn import BDDNN | |||||
| from models.bdd_model import BDDABLModel | |||||
| from reasoning.bddkb import BDDKB | |||||
| from dataset.data_util import get_dataset | |||||
| from bridge import BDDBridge | |||||
| from metric import BDDReasoningMetric | |||||
| def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): | |||||
| pred_prob = data_example.pred_prob.T # nc x 1 | |||||
| pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 | |||||
| cols = np.arange(len(candidates_idxs[0]))[None, :] | |||||
| corr_prob = pred_prob[cols, candidates_idxs] | |||||
| costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) | |||||
| return costs | |||||
| def get_args(): | |||||
| parser = argparse.ArgumentParser(description="BDD-OIA example") | |||||
| parser.add_argument( | |||||
| "--no-cuda", action="store_true", default=False, help="disables CUDA training" | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--epochs", | |||||
| type=int, | |||||
| default=1, | |||||
| help="number of epochs in each learning loop iteration (default : 1)", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--lr", type=float, default=2e-3, help="base model learning rate (default : 0.002)" | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--batch-size", type=int, default=32, help="base model batch size (default : 32)" | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--loops", type=int, default=2, help="number of loop iterations (default : 2)" | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)" | |||||
| ) | |||||
| parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") | |||||
| parser.add_argument( | |||||
| "--max-revision", type=int, default=3, help="maximum revision in reasoner (default : 3)" | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--require-more-revision", | |||||
| type=int, | |||||
| default=3, | |||||
| help="require more revision in reasoner (default : 3)", | |||||
| ) | |||||
| args = parser.parse_args() | |||||
| return args | |||||
| def main(): | |||||
| args = get_args() | |||||
| # Build logger | |||||
| print_log("Abductive Learning on the BDD-OIA example.", logger="current") | |||||
| # -- Working with Data ------------------------------ | |||||
| print_log("Working with Data.", logger="current") | |||||
| train_data = get_dataset(fname="train.npz", get_pseudo_label=True) | |||||
| val_data = get_dataset(fname="val.npz", get_pseudo_label=True) | |||||
| test_data = get_dataset(fname="test.npz", get_pseudo_label=True) | |||||
| # -- Building the Learning Part --------------------- | |||||
| print_log("Building the Learning Part.", logger="current") | |||||
| # Build necessary components for BDDNN | |||||
| cls = ConceptNet() | |||||
| loss_fn = torch.nn.BCEWithLogitsLoss() | |||||
| optimizer = optim.Adam(cls.parameters(), lr=args.lr) | |||||
| use_cuda = not args.no_cuda and torch.cuda.is_available() | |||||
| device = torch.device("cuda" if use_cuda else "cpu") | |||||
| scheduler = optim.lr_scheduler.OneCycleLR( | |||||
| optimizer, | |||||
| max_lr=args.lr, | |||||
| pct_start=0.15, | |||||
| epochs=args.loops, | |||||
| steps_per_epoch=int(1 / args.segment_size) + 1, | |||||
| ) | |||||
| # Build BDDNN | |||||
| base_model = BDDNN( | |||||
| cls, | |||||
| loss_fn, | |||||
| optimizer, | |||||
| scheduler=scheduler, | |||||
| device=device, | |||||
| batch_size=args.batch_size, | |||||
| num_epochs=args.epochs, | |||||
| ) | |||||
| # Build ABLModel | |||||
| model = BDDABLModel(base_model) | |||||
| # -- Building the Reasoning Part -------------------- | |||||
| print_log("Building the Reasoning Part.", logger="current") | |||||
| # Build knowledge base | |||||
| kb = BDDKB() | |||||
| # Create reasoner | |||||
| reasoner = Reasoner( | |||||
| kb, | |||||
| dist_func=multi_label_confidence_dist, | |||||
| max_revision=args.max_revision, | |||||
| require_more_revision=args.require_more_revision, | |||||
| ) | |||||
| # -- Building Evaluation Metrics -------------------- | |||||
| print_log("Building Evaluation Metrics.", logger="current") | |||||
| metric_list = [SymbolAccuracy(prefix="bdd_oia"), BDDReasoningMetric(kb=kb, prefix="bdd_oia")] | |||||
| # -- Bridging Learning and Reasoning ---------------- | |||||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||||
| bridge = BDDBridge(model, reasoner, metric_list) | |||||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||||
| log_dir = ABLLogger.get_current_instance().log_dir | |||||
| weights_dir = osp.join(log_dir, "weights") | |||||
| # Train and Test | |||||
| bridge.train( | |||||
| train_data=train_data, | |||||
| val_data=val_data, | |||||
| loops=args.loops, | |||||
| segment_size=args.segment_size, | |||||
| save_interval=args.save_interval, | |||||
| save_dir=weights_dir, | |||||
| ) | |||||
| bridge.test(test_data) | |||||
| if __name__ == "__main__": | |||||
| main() | |||||
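
The `multi_label_confidence_dist` function above scores each candidate concept assignment by the negative log-likelihood of its bits under the predicted per-concept probabilities. The following is a minimal pure-Python restatement of that idea for a single example, written without NumPy so the arithmetic is easy to follow. The helper name `confidence_cost` and the toy probabilities are assumptions made for illustration and are not part of ABLkit.

```python
import math


def confidence_cost(pred_prob, candidate, eps=1e-6):
    """Negative log-likelihood of one binary candidate.

    pred_prob[i] is the predicted probability that concept i equals 1;
    candidate[i] is the 0/1 value proposed for concept i.
    (Hypothetical helper, mirroring the NumPy version in the diff.)
    """
    cost = 0.0
    for p, c in zip(pred_prob, candidate):
        chosen = p if c == 1 else 1.0 - p  # probability of the proposed bit
        cost -= math.log(chosen + eps)     # eps avoids log(0)
    return cost


# Toy data: three concepts, two candidate assignments.
pred_prob = [0.9, 0.2, 0.6]
candidates = [[1, 0, 1], [0, 1, 0]]

costs = [confidence_cost(pred_prob, c) for c in candidates]
best = candidates[costs.index(min(costs))]
print(costs, best)
```

The candidate that agrees with the model's more confident predictions receives the lower cost, so a reasoner that minimizes this distance would prefer it among the candidates consistent with the knowledge base.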
| @@ -0,0 +1,27 @@ | |||||
| from typing import Optional | |||||
| from ablkit.reasoning import KBBase | |||||
| from ablkit.data import BaseMetric, ListData | |||||
| class BDDReasoningMetric(BaseMetric): | |||||
| def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | |||||
| super().__init__(prefix) | |||||
| self.kb = kb | |||||
| def process(self, data_examples: ListData) -> None: | |||||
| pred_pseudo_label_list = data_examples.pred_pseudo_label | |||||
| y_list = data_examples.Y | |||||
| x_list = data_examples.X | |||||
| for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): | |||||
| pred_y = self.kb.logic_forward( | |||||
| pred_pseudo_label, *(x,) if self.kb._num_args == 2 else () | |||||
| ) | |||||
| for py, yy in zip(pred_y, y): | |||||
| self.results.append(int(py == yy)) | |||||
| def compute_metrics(self) -> dict: | |||||
| results = self.results | |||||
| metrics = dict() | |||||
| metrics["reasoning_accuracy"] = sum(results) / len(results) | |||||
| return metrics | |||||
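
Note that `BDDReasoningMetric` above appends one result per action label rather than one per frame, so the reported `reasoning_accuracy` is a label-level average. A small standalone sketch of the same counting, assuming a hypothetical helper `reasoning_accuracy` and made-up predictions:

```python
def reasoning_accuracy(pred_actions, true_actions):
    """Fraction of individual action labels that match (label-level, not frame-level)."""
    results = []
    for pred_y, y in zip(pred_actions, true_actions):
        for py, yy in zip(pred_y, y):
            results.append(int(py == yy))
    return sum(results) / len(results)


# Two frames with four action labels each:
# (move_forward, stop, turn_left, turn_right)
pred = [(1, 0, 0, 0), (0, 1, 1, 0)]
true = [(1, 0, 0, 1), (0, 1, 1, 0)]

acc = reasoning_accuracy(pred, true)  # 7 of 8 labels match
print(acc)
```

Under this definition, a frame with one wrong label out of four still contributes three correct results, which is why the value can be higher than a strict per-frame accuracy would be.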
| @@ -0,0 +1,25 @@ | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| from ablkit.data import ListData | |||||
| from ablkit.learning import ABLModel | |||||
| from ablkit.utils import reform_list | |||||
| class BDDABLModel(ABLModel): | |||||
| def predict(self, data_examples: ListData) -> Dict: | |||||
| model = self.base_model | |||||
| data_X = data_examples.flatten("X") | |||||
| if hasattr(model, "predict_proba"): | |||||
| prob = model.predict_proba(X=data_X) | |||||
| label = np.where(prob > 0.5, 1, 0).astype(int) | |||||
| prob = reform_list(prob, data_examples.X) | |||||
| else: | |||||
| prob = None | |||||
| label = model.predict(X=data_X) | |||||
| label = reform_list(label, data_examples.X) | |||||
| data_examples.pred_idx = label | |||||
| data_examples.pred_prob = prob | |||||
| return {"label": label, "prob": prob} | |||||
| @@ -0,0 +1,93 @@ | |||||
| import logging | |||||
| from typing import Any, Callable, List, Optional | |||||
| import numpy | |||||
| import torch | |||||
| from torch.utils.data import DataLoader | |||||
| from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset | |||||
| from ablkit.utils.logger import print_log | |||||
| class MultiLabelClassificationDataset(ClassificationDataset): | |||||
| def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None): | |||||
| if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||||
| raise ValueError("X and Y should be of type list.") | |||||
| self.X = X | |||||
| self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss | |||||
| self.transform = transform | |||||
| class BDDNN(BasicNN): | |||||
| def predict( | |||||
| self, | |||||
| data_loader: Optional[DataLoader] = None, | |||||
| X: Optional[List[Any]] = None, | |||||
| ) -> numpy.ndarray: | |||||
| if data_loader is not None and X is not None: | |||||
| print_log( | |||||
| "Predict the class of input data in data_loader instead of X.", | |||||
| logger="current", | |||||
| level=logging.WARNING, | |||||
| ) | |||||
| if data_loader is None: | |||||
| dataset = PredictionDataset(X, self.test_transform) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size=self.batch_size, | |||||
| num_workers=self.num_workers, | |||||
| collate_fn=self.collate_fn, | |||||
| pin_memory=torch.cuda.is_available(), | |||||
| ) | |||||
| pred_probs = self._predict(data_loader).sigmoid() | |||||
| pred = torch.where(pred_probs > 0.5, 1, 0).int() | |||||
| return pred.cpu().numpy() | |||||
| def predict_proba( | |||||
| self, | |||||
| data_loader: Optional[DataLoader] = None, | |||||
| X: Optional[List[Any]] = None, | |||||
| ) -> numpy.ndarray: | |||||
| if data_loader is not None and X is not None: | |||||
| print_log( | |||||
| "Predict the class probability of input data in data_loader instead of X.", | |||||
| logger="current", | |||||
| level=logging.WARNING, | |||||
| ) | |||||
| if data_loader is None: | |||||
| dataset = PredictionDataset(X, self.test_transform) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size=self.batch_size, | |||||
| num_workers=self.num_workers, | |||||
| collate_fn=self.collate_fn, | |||||
| pin_memory=torch.cuda.is_available(), | |||||
| ) | |||||
| pred_probs = self._predict(data_loader).sigmoid() # B x NC | |||||
| return pred_probs.cpu().numpy() | |||||
| def _data_loader( | |||||
| self, | |||||
| X: Optional[List[Any]], | |||||
| y: Optional[List[int]] = None, | |||||
| shuffle: Optional[bool] = True, | |||||
| ) -> DataLoader: | |||||
| if X is None: | |||||
| raise ValueError("X should not be None.") | |||||
| if y is None: | |||||
| y = [0] * len(X) | |||||
| if not len(y) == len(X): | |||||
| raise ValueError("X and y should have equal length.") | |||||
| dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size=self.batch_size, | |||||
| shuffle=shuffle, | |||||
| num_workers=self.num_workers, | |||||
| collate_fn=self.collate_fn, | |||||
| pin_memory=torch.cuda.is_available(), | |||||
| ) | |||||
| return data_loader | |||||
| @@ -0,0 +1,24 @@ | |||||
| from torch import nn | |||||
| class SimpleNet(nn.Module): | |||||
| def __init__(self, num_features=2048, num_concepts=21): | |||||
| super(SimpleNet, self).__init__() | |||||
| self.fc = nn.Linear(num_features, num_concepts) | |||||
| def forward(self, x): | |||||
| return self.fc(x) | |||||
| class ConceptNet(nn.Module): | |||||
| def __init__(self, num_features=2048, num_concepts=21): | |||||
| super(ConceptNet, self).__init__() | |||||
| intermediate_dim = 256 | |||||
| self.fc = nn.Sequential( | |||||
| nn.Linear(num_features, intermediate_dim), | |||||
| nn.SiLU(), | |||||
| nn.Linear(intermediate_dim, num_concepts), | |||||
| ) | |||||
| def forward(self, x): | |||||
| return self.fc(x) | |||||
| @@ -0,0 +1,67 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| from ablkit.reasoning import KBBase | |||||
| class BDDKB(KBBase): | |||||
| def __init__(self, pseudo_label_list=None): | |||||
| if pseudo_label_list is None: | |||||
| pseudo_label_list = [0, 1] | |||||
| super().__init__(pseudo_label_list) | |||||
| def logic_forward(self, attrs): | |||||
| """ | |||||
| Abduction space | |||||
| (0, 1, 0, 0) 610812 | |||||
| (0, 1, 0, 1) 75012 | |||||
| (0, 1, 1, 0) 75012 | |||||
| (0, 1, 1, 1) 9212 | |||||
| (1, 0, 0, 0) 12996 | |||||
| (1, 0, 0, 1) 1596 | |||||
| (1, 0, 1, 0) 1596 | |||||
| (1, 0, 1, 1) 196 | |||||
| """ | |||||
| assert len(attrs) == 21 | |||||
| ( | |||||
| green_light, | |||||
| follow, | |||||
| road_clear, | |||||
| red_light, | |||||
| traffic_sign, | |||||
| car, | |||||
| person, | |||||
| rider, | |||||
| other_obstacle, | |||||
| left_lane, | |||||
| left_green_light, | |||||
| left_follow, | |||||
| no_left_lane, | |||||
| left_obstacle, | |||||
| left_solid_line, | |||||
| right_lane, | |||||
| right_green_light, | |||||
| right_follow, | |||||
| no_right_lane, | |||||
| right_obstacle, | |||||
| right_solid_line, | |||||
| ) = attrs | |||||
| illegal_return = (0, 0, 0, 0) | |||||
| if red_light == green_light == 1: | |||||
| return illegal_return | |||||
| obstacle = car or person or rider or other_obstacle | |||||
| if road_clear == obstacle: | |||||
| return illegal_return | |||||
| move_forward = green_light or follow or road_clear | |||||
| stop = red_light or traffic_sign or obstacle | |||||
| if stop: | |||||
| move_forward = 0 | |||||
| can_turn_left = left_lane or left_green_light or left_follow | |||||
| cannot_turn_left = no_left_lane or left_obstacle or left_solid_line | |||||
| turn_left = can_turn_left and int(not cannot_turn_left) | |||||
| can_turn_right = right_lane or right_green_light or right_follow | |||||
| cannot_turn_right = no_right_lane or right_obstacle or right_solid_line | |||||
| turn_right = can_turn_right and int(not cannot_turn_right) | |||||
| return move_forward, stop, turn_left, turn_right | |||||
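
To make the role of `logic_forward` in abduction more concrete, here is a deliberately reduced sketch with three concepts instead of 21. It enumerates every concept assignment and keeps those whose deduced actions equal a target label, which is the brute-force view of an abduction space. The names `toy_logic_forward` and `abduce`, and the simplified rules, are assumptions for illustration only; this is not the search ABLkit performs internally.

```python
from itertools import product


def toy_logic_forward(attrs):
    """Reduced rule set: returns (move_forward, stop) for three binary concepts."""
    green_light, red_light, obstacle = attrs
    if red_light == green_light == 1:
        return (0, 0)  # contradictory concepts, treated as illegal
    stop = int(bool(red_light or obstacle))
    move_forward = 0 if stop else int(bool(green_light))
    return (move_forward, stop)


def abduce(target):
    """All concept assignments whose deduced actions equal the target label."""
    return [a for a in product([0, 1], repeat=3) if toy_logic_forward(a) == target]


stop_space = abduce((0, 1))
forward_space = abduce((1, 0))
print(len(stop_space), forward_space)
```

In this toy setting the label "stop" is explained by several concept assignments while "move forward" is explained by only one, which illustrates why a distance function is needed to choose among candidates.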
| @@ -0,0 +1 @@ | |||||
| ablkit | |||||
| @@ -81,7 +81,7 @@ def main(): | |||||
| "--label-smoothing", | "--label-smoothing", | ||||
| type=float, | type=float, | ||||
| default=0.2, | default=0.2, | ||||
| help="label smoothing in cross entropy loss (default : 0.2)" | |||||
| help="label smoothing in cross entropy loss (default : 0.2)", | |||||
| ) | ) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | ||||
| @@ -138,7 +138,12 @@ def main(): | |||||
| # Build BasicNN | # Build BasicNN | ||||
| base_model = BasicNN( | base_model = BasicNN( | ||||
| cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs, | |||||
| cls, | |||||
| loss_fn, | |||||
| optimizer, | |||||
| device=device, | |||||
| batch_size=args.batch_size, | |||||
| num_epochs=args.epochs, | |||||
| ) | ) | ||||
| # Build ABLModel | # Build ABLModel | ||||