Browse Source

[ENH] run mnist add example successfully after reformating the code

pull/3/head
Gao Enhao 3 years ago
parent
commit
98cadd3511
2 changed files with 34 additions and 28 deletions
  1. +1
    -1
      abl/bridge/simple_bridge.py
  2. +33
    -27
      examples/mnist_add/mnist_add_example.ipynb

+ 1
- 1
abl/bridge/simple_bridge.py View File

@@ -76,7 +76,7 @@ class SimpleBridge(BaseBridge):
min_loss = self.model.train(X, abduced_label)

print_log(
f"Epoch(train) [{epoch + 1}] [{seg_idx:3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}",
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}",
logger="current",
)



+ 33
- 27
examples/mnist_add/mnist_add_example.ipynb View File

@@ -9,16 +9,15 @@
"import torch.nn as nn\n",
"import torch\n",
"\n",
"from abl.reasoning.reasoner import ReasonerBase\n",
"from abl.reasoning.kb import KBBase, prolog_KB\n",
"from abl.reasoning import ReasonerBase, KBBase\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.learning.basic_nn import BasicNN\n",
"from abl.learning.abl_model import ABLModel\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric, ABLMetric\n",
"from abl.utils import ABLLogger\n",
"\n",
"from models.nn import LeNet5\n",
"from datasets.get_mnist_add import get_mnist_add\n",
"from abl import framework"
"from datasets.get_mnist_add import get_mnist_add"
]
},
{
@@ -28,7 +27,7 @@
"outputs": [],
"source": [
"# Initialize logger\n",
"recorder = logger()"
"logger = ABLLogger.get_instance(\"abl\")"
]
},
{
@@ -94,10 +93,9 @@
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" save_dir=recorder.save_dir,\n",
" save_dir=logger.save_dir,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
" recorder=recorder,\n",
")"
]
},
@@ -118,7 +116,25 @@
"# Initialize ABL model\n",
"# The main function of the ABL model is to serialize data and \n",
"# provide a unified interface for different machine learning models\n",
"model = ABLModel(base_model, kb.pseudo_label_list)"
"model = ABLModel(base_model)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add metric\n",
"metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]"
]
},
{
@@ -145,7 +161,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and save"
"### Bridge Machine Learning and Logic Reasoning"
]
},
{
@@ -154,18 +170,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Train model\n",
"model = framework.train(\n",
" model,\n",
" abducer,\n",
" train_data,\n",
" epochs=5,\n",
" sample=12000,\n",
" verbose=1,\n",
")\n",
"\n",
"# Save results\n",
"recorder.dump()"
"bridge = SimpleBridge(model, abducer, metric)"
]
},
{
@@ -173,7 +178,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### TODO: Test"
"### Train and Test"
]
},
{
@@ -182,7 +187,8 @@
"metadata": {},
"outputs": [],
"source": [
"framework.test(model, abducer, test_data)"
"bridge.train(train_data, epochs=5, batch_size=10000)\n",
"bridge.test(test_data)"
]
}
],
@@ -202,7 +208,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.16"
},
"orig_nbformat": 4,
"vscode": {


Loading…
Cancel
Save