|
|
|
@@ -9,16 +9,15 @@ |
|
|
|
"import torch.nn as nn\n", |
|
|
|
"import torch\n", |
|
|
|
"\n", |
|
|
|
"from abl.reasoning.reasoner import ReasonerBase\n", |
|
|
|
"from abl.reasoning.kb import KBBase, prolog_KB\n", |
|
|
|
"from abl.reasoning import ReasonerBase, KBBase\n", |
|
|
|
"\n", |
|
|
|
"from abl.utils.plog import logger\n", |
|
|
|
"from abl.learning.basic_nn import BasicNN\n", |
|
|
|
"from abl.learning.abl_model import ABLModel\n", |
|
|
|
"from abl.learning import BasicNN, ABLModel\n", |
|
|
|
"from abl.bridge import SimpleBridge\n", |
|
|
|
"from abl.evaluation import SymbolMetric, ABLMetric\n", |
|
|
|
"from abl.utils import ABLLogger\n", |
|
|
|
"\n", |
|
|
|
"from models.nn import LeNet5\n", |
|
|
|
"from datasets.get_mnist_add import get_mnist_add\n", |
|
|
|
"from abl import framework" |
|
|
|
"from datasets.get_mnist_add import get_mnist_add" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -28,7 +27,7 @@ |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Initialize logger\n", |
|
|
|
"recorder = logger()" |
|
|
|
"logger = ABLLogger.get_instance(\"abl\")" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -94,10 +93,9 @@ |
|
|
|
" optimizer,\n", |
|
|
|
" device,\n", |
|
|
|
" save_interval=1,\n", |
|
|
|
" save_dir=recorder.save_dir,\n", |
|
|
|
" save_dir=logger.save_dir,\n", |
|
|
|
" batch_size=32,\n", |
|
|
|
" num_epochs=1,\n", |
|
|
|
" recorder=recorder,\n", |
|
|
|
")" |
|
|
|
] |
|
|
|
}, |
|
|
|
@@ -118,7 +116,25 @@ |
|
|
|
"# Initialize ABL model\n", |
|
|
|
"# The main function of the ABL model is to serialize data and \n", |
|
|
|
"# provide a unified interface for different machine learning models\n", |
|
|
|
"model = ABLModel(base_model, kb.pseudo_label_list)" |
|
|
|
"model = ABLModel(base_model)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"attachments": {}, |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Metric" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Add metric\n", |
|
|
|
"metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -145,7 +161,7 @@ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Train and save" |
|
|
|
"### Bridge Machine Learning and Logic Reasoning" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -154,18 +170,7 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Train model\n", |
|
|
|
"model = framework.train(\n", |
|
|
|
" model,\n", |
|
|
|
" abducer,\n", |
|
|
|
" train_data,\n", |
|
|
|
" epochs=5,\n", |
|
|
|
" sample=12000,\n", |
|
|
|
" verbose=1,\n", |
|
|
|
")\n", |
|
|
|
"\n", |
|
|
|
"# Save results\n", |
|
|
|
"recorder.dump()" |
|
|
|
"bridge = SimpleBridge(model, abducer, metric)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -173,7 +178,7 @@ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### TODO: Test" |
|
|
|
"### Train and Test" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -182,7 +187,8 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"framework.test(model, abducer, test_data)" |
|
|
|
"bridge.train(train_data, epochs=5, batch_size=10000)\n", |
|
|
|
"bridge.test(test_data)" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
@@ -202,7 +208,7 @@ |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.8.13" |
|
|
|
"version": "3.8.16" |
|
|
|
}, |
|
|
|
"orig_nbformat": 4, |
|
|
|
"vscode": { |
|
|
|
|