diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 79ed94f..806733e 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -76,7 +76,7 @@ class SimpleBridge(BaseBridge): min_loss = self.model.train(X, abduced_label) print_log( - f"Epoch(train) [{epoch + 1}] [{seg_idx:3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}", + f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}", logger="current", ) diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index f4cadff..ade4e55 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -9,16 +9,15 @@ "import torch.nn as nn\n", "import torch\n", "\n", - "from abl.reasoning.reasoner import ReasonerBase\n", - "from abl.reasoning.kb import KBBase, prolog_KB\n", + "from abl.reasoning import ReasonerBase, KBBase\n", "\n", - "from abl.utils.plog import logger\n", - "from abl.learning.basic_nn import BasicNN\n", - "from abl.learning.abl_model import ABLModel\n", + "from abl.learning import BasicNN, ABLModel\n", + "from abl.bridge import SimpleBridge\n", + "from abl.evaluation import SymbolMetric, ABLMetric\n", + "from abl.utils import ABLLogger\n", "\n", "from models.nn import LeNet5\n", - "from datasets.get_mnist_add import get_mnist_add\n", - "from abl import framework" + "from datasets.get_mnist_add import get_mnist_add" ] }, { @@ -28,7 +27,7 @@ "outputs": [], "source": [ "# Initialize logger\n", - "recorder = logger()" + "logger = ABLLogger.get_instance(\"abl\")" ] }, { @@ -94,10 +93,9 @@ " optimizer,\n", " device,\n", " save_interval=1,\n", - " save_dir=recorder.save_dir,\n", + " save_dir=logger.save_dir,\n", " batch_size=32,\n", " num_epochs=1,\n", - " recorder=recorder,\n", ")" ] }, @@ -118,7 +116,25 @@ "# Initialize ABL model\n", "# The main function of the ABL model is to serialize data and \n", "# provide a unified interface for different machine learning models\n", - "model = ABLModel(base_model, kb.pseudo_label_list)" + "model = ABLModel(base_model)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Metric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add metric\n", + "metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" ] }, { @@ -145,7 +161,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Train and save" + "### Bridge Machine Learning and Logic Reasoning" ] }, { @@ -154,18 +170,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Train model\n", - "model = framework.train(\n", - " model,\n", - " abducer,\n", - " train_data,\n", - " epochs=5,\n", - " sample=12000,\n", - " verbose=1,\n", - ")\n", - "\n", - "# Save results\n", - "recorder.dump()" + "bridge = SimpleBridge(model, abducer, metric)" ] }, { @@ -173,7 +178,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### TODO: Test" + "### Train and Test" ] }, { @@ -182,7 +187,8 @@ "metadata": {}, "outputs": [], "source": [ - "framework.test(model, abducer, test_data)" + "bridge.train(train_data, epochs=5, batch_size=10000)\n", + "bridge.test(test_data)" ] } ], @@ -202,7 +208,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" }, "orig_nbformat": 4, "vscode": {