| @@ -53,7 +53,7 @@ For Linux users: | |||
| $ sudo apt-get install swi-prolog | |||
| ``` | |||
| For Windows and Mac users, please refer to the [Swi-Prolog Download Page](https://www.swi-prolog.org/Download.html). | |||
| For Windows and Mac users, please refer to the [Swi-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). | |||
| ## Examples | |||
| @@ -66,9 +66,9 @@ class BasicNN: | |||
| num_workers: int = 0, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| train_transform: Callable[..., Any] = None, | |||
| test_transform: Callable[..., Any] = None, | |||
| collate_fn: Callable[[List[Any]], Any] = None, | |||
| train_transform: Optional[Callable[..., Any]] = None, | |||
| test_transform: Optional[Callable[..., Any]] = None, | |||
| collate_fn: Optional[Callable[[List[Any]], Any]] = None, | |||
| ) -> None: | |||
| if not isinstance(model, torch.nn.Module): | |||
| raise TypeError("model must be an instance of torch.nn.Module") | |||
| @@ -471,7 +471,7 @@ class PrologKB(KBBase): | |||
| except (IndexError, ImportError): | |||
| print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \ | |||
| using the command 'sudo apt-get install swi-prolog' for Linux users, \ | |||
| or download it from https://www.swi-prolog.org/Download.html for Windows and Mac users.") | |||
| or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.") | |||
| self.prolog = pyswip.Prolog() | |||
| self.pl_file = pl_file | |||
| @@ -1,6 +1,6 @@ | |||
| from .cache import Cache, abl_cache | |||
| from .logger import ABLLogger, print_log | |||
| from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable | |||
| from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple | |||
| __all__ = [ | |||
| "Cache", | |||
| @@ -12,4 +12,5 @@ __all__ = [ | |||
| "reform_list", | |||
| "to_hashable", | |||
| "abl_cache", | |||
| "tab_data_to_tuple", | |||
| ] | |||
| @@ -154,4 +154,13 @@ def restore_from_hashable(x): | |||
| return [restore_from_hashable(item) for item in x] | |||
| return x | |||
| def tab_data_to_tuple(X, y, reasoning_result = 0): | |||
| ''' | |||
| Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result. | |||
| If X is None, return None. | |||
| ''' | |||
| if X is None: | |||
| return None | |||
| if len(X) != len(y): | |||
| raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) | |||
| return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) | |||
| @@ -31,4 +31,4 @@ For Linux users: | |||
| $ sudo apt-get install swi-prolog | |||
| For Windows and Mac users, please refer to the `Swi-Prolog Download Page <https://www.swi-prolog.org/Download.html>`_. | |||
| For Windows and Mac users, please refer to the `Swi-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_. | |||
| @@ -51,7 +51,7 @@ For Linux users: | |||
| $ sudo apt-get install swi-prolog | |||
| For Windows and Mac users, please refer to the `Swi-Prolog Download Page <https://www.swi-prolog.org/Download.html>`_. | |||
| For Windows and Mac users, please refer to the `Swi-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_. | |||
| References | |||
| ---------- | |||
| @@ -46,13 +46,20 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| total_train_data = get_dataset(train=True) | |||
| train_data, val_data = split_equation(total_train_data, 3, 1) | |||
| test_data = get_dataset(train=False) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=4) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| @@ -75,6 +82,8 @@ def main(): | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = HedKB() | |||
| @@ -82,14 +91,13 @@ def main(): | |||
| reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [ConsistencyMetric(kb=kb)] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = HedBridge(model, reasoner, metric_list) | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| @@ -113,12 +113,19 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| @@ -140,6 +147,8 @@ def main(): | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.ground: | |||
| kb = HwfGroundKB() | |||
| @@ -152,14 +161,13 @@ def main(): | |||
| ) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| @@ -78,11 +78,17 @@ def main(): | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = LeNet5(num_classes=10) | |||
| loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) | |||
| @@ -112,6 +118,8 @@ def main(): | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.prolog: | |||
| kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
| @@ -126,14 +134,13 @@ def main(): | |||
| ) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Build logger | |||
| print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| @@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge | |||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from abl.learning import ABLModel | |||
| from abl.reasoning import Reasoner | |||
| from abl.utils import ABLLogger, confidence_dist, print_log | |||
| from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple | |||
| from get_dataset import load_and_preprocess_dataset, split_dataset | |||
| from kb import ZooKB | |||
| def transform_tab_data(X, y): | |||
| return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) | |||
| def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
| pred_prob = data_example.pred_prob | |||
| @@ -30,21 +28,31 @@ def main(): | |||
| "--loops", type=int, default=3, help="number of loop iterations (default : 3)" | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the ZOO example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| X, y = load_and_preprocess_dataset(dataset_id=62) | |||
| X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
| label_data = transform_tab_data(X_label, y_label) | |||
| test_data = transform_tab_data(X_test, y_test) | |||
| train_data = transform_tab_data(X_unlabel, y_unlabel) | |||
| label_data = tab_data_to_tuple(X_label, y_label) | |||
| test_data = tab_data_to_tuple(X_test, y_test) | |||
| train_data = tab_data_to_tuple(X_unlabel, y_unlabel) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build base model | |||
| base_model = RandomForestClassifier() | |||
| # Build ABLModel | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = ZooKB() | |||
| @@ -52,16 +60,17 @@ def main(): | |||
| reasoner = Reasoner(kb, dist_func=consitency) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] | |||
| # Build logger | |||
| print_log("Abductive Learning on the ZOO example.", logger="current") | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| ### Bridging learning and reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| # Performing training and testing | |||
| print_log("------- Use labeled data to pretrain the model -----------", logger="current") | |||
| base_model.fit(X_label, y_label) | |||
| @@ -27,7 +27,7 @@ | |||
| "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
| "from abl.learning import ABLModel\n", | |||
| "from abl.reasoning import Reasoner\n", | |||
| "from abl.utils import ABLLogger, confidence_dist, print_log\n", | |||
| "from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n", | |||
| "\n", | |||
| "from get_dataset import load_and_preprocess_dataset, split_dataset\n", | |||
| "from kb import ZooKB" | |||
| @@ -106,11 +106,9 @@ | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "def transform_tab_data(X, y):\n", | |||
| " return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", | |||
| "label_data = transform_tab_data(X_label, y_label)\n", | |||
| "test_data = transform_tab_data(X_test, y_test)\n", | |||
| "train_data = transform_tab_data(X_unlabel, y_unlabel)" | |||
| "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", | |||
| "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", | |||
| "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" | |||
| name = "abl" | |||
| version = "0.1.4" | |||
| authors = [ | |||
| { name="LAMDA 2023" }, | |||
| { name="LAMDA 2024" }, | |||
| ] | |||
| description = "Abductive learning package project" | |||
| readme = "README.md" | |||
| @@ -25,12 +25,12 @@ classifiers = [ | |||
| "Programming Language :: Python :: 3.9", | |||
| ] | |||
| dependencies = [ | |||
| "numpy", | |||
| "pyswip==0.2.9", | |||
| "torch", | |||
| "torchvision", | |||
| "zoopt", | |||
| "termcolor" | |||
| "numpy>=1.15.0", | |||
| "pyswip>=0.2.9", | |||
| "torch>=1.11.0", | |||
| "torchvision>=0.12.0", | |||
| "zoopt>=0.3.0", | |||
| "termcolor>=2.3.0" | |||
| ] | |||
| [project.urls] | |||
| @@ -1,6 +1,6 @@ | |||
| numpy | |||
| pyswip==0.2.9 | |||
| torch | |||
| torchvision | |||
| zoopt | |||
| termcolor | |||
| numpy>=1.15.0, | |||
| pyswip>=0.2.9, | |||
| torch>=1.11.0, | |||
| torchvision>=0.12.0, | |||
| zoopt>=0.3.0, | |||
| termcolor>=2.3.0 | |||