Browse Source

[MNT] change kb parameter name in examples

pull/3/head
troyyyyy 2 years ago
parent
commit
643dc64f57
2 changed files with 18 additions and 18 deletions
  1. +5
    -5
      examples/hwf/hwf_example.ipynb
  2. +13
    -13
      examples/mnist_add/mnist_add_example.ipynb

+ 5
- 5
examples/hwf/hwf_example.ipynb View File

@@ -49,12 +49,12 @@
" def __init__(\n",
" self, \n",
" pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n",
" len_list=[1, 3, 5, 7],\n",
" GKB_flag=False,\n",
" prebuild_GKB=False,\n",
" GKB_len_list=[1, 3, 5, 7],\n",
" max_err=1e-3,\n",
" use_cache=True\n",
" ):\n",
" super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n",
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
"\n",
" def _valid_candidate(self, formula):\n",
" if len(formula) % 2 == 0:\n",
@@ -74,7 +74,7 @@
" formula = [mapping[f] for f in formula]\n",
" return eval(''.join(formula))\n",
"\n",
"kb = HWF_KB(GKB_flag=True)\n",
"kb = HWF_KB(prebuild_GKB=True)\n",
"abducer = ReasonerBase(kb, dist_func='confidence')"
]
},
@@ -220,7 +220,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {


+ 13
- 13
examples/mnist_add/mnist_add_example.ipynb View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -17,12 +17,12 @@
"from abl.utils import ABLLogger\n",
"\n",
"from models.nn import LeNet5\n",
"from datasets.get_mnist_add import get_mnist_add"
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -40,19 +40,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"class add_KB(KBBase):\n",
" def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):\n",
" super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n",
" def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n",
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
"\n",
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"\n",
"kb = add_KB(GKB_flag=True)\n",
"kb = add_KB(prebuild_GKB=True)\n",
"\n",
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n",
"abducer = ReasonerBase(kb, dist_func=\"confidence\")"
@@ -68,7 +68,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -81,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -109,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -129,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -147,7 +147,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -208,7 +208,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {


Loading…
Cancel
Save