From 643dc64f57d295359dfdd8084a2bc0cd6f2e6183 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Sun, 8 Oct 2023 15:37:47 +0800 Subject: [PATCH] [MNT] change kb parameter name in examples --- examples/hwf/hwf_example.ipynb | 10 ++++----- examples/mnist_add/mnist_add_example.ipynb | 26 +++++++++++----------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index f9de708..b1ba550 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -49,12 +49,12 @@ " def __init__(\n", " self, \n", " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", - " len_list=[1, 3, 5, 7],\n", - " GKB_flag=False,\n", + " prebuild_GKB=False,\n", + " GKB_len_list=[1, 3, 5, 7],\n", " max_err=1e-3,\n", " use_cache=True\n", " ):\n", - " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", + " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", "\n", " def _valid_candidate(self, formula):\n", " if len(formula) % 2 == 0:\n", @@ -74,7 +74,7 @@ " formula = [mapping[f] for f in formula]\n", " return eval(''.join(formula))\n", "\n", - "kb = HWF_KB(GKB_flag=True)\n", + "kb = HWF_KB(prebuild_GKB=True)\n", "abducer = ReasonerBase(kb, dist_func='confidence')" ] }, @@ -220,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index ade4e55..146bd88 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -17,12 +17,12 @@ "from abl.utils import ABLLogger\n", "\n", "from models.nn import LeNet5\n", - "from datasets.get_mnist_add import get_mnist_add" + "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -40,19 +40,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Initialize knowledge base and abducer\n", "class add_KB(KBBase):\n", - " def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):\n", - " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", + " def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", + " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", "\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = add_KB(GKB_flag=True)\n", + "kb = add_KB(prebuild_GKB=True)\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", "abducer = ReasonerBase(kb, dist_func=\"confidence\")" @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -208,7 +208,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.13" }, "orig_nbformat": 4, "vscode": {