| @@ -49,12 +49,12 @@ | |||
| " def __init__(\n", | |||
| " self, \n", | |||
| " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", | |||
| " len_list=[1, 3, 5, 7],\n", | |||
| " GKB_flag=False,\n", | |||
| " prebuild_GKB=False,\n", | |||
| " GKB_len_list=[1, 3, 5, 7],\n", | |||
| " max_err=1e-3,\n", | |||
| " use_cache=True\n", | |||
| " ):\n", | |||
| " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", | |||
| " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||
| "\n", | |||
| " def _valid_candidate(self, formula):\n", | |||
| " if len(formula) % 2 == 0:\n", | |||
| @@ -74,7 +74,7 @@ | |||
| " formula = [mapping[f] for f in formula]\n", | |||
| " return eval(''.join(formula))\n", | |||
| "\n", | |||
| "kb = HWF_KB(GKB_flag=True)\n", | |||
| "kb = HWF_KB(prebuild_GKB=True)\n", | |||
| "abducer = ReasonerBase(kb, dist_func='confidence')" | |||
| ] | |||
| }, | |||
| @@ -220,7 +220,7 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| "version": "3.8.13" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| @@ -2,7 +2,7 @@ | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 3, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -17,12 +17,12 @@ | |||
| "from abl.utils import ABLLogger\n", | |||
| "\n", | |||
| "from models.nn import LeNet5\n", | |||
| "from datasets.get_mnist_add import get_mnist_add" | |||
| "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 4, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -40,19 +40,19 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "class add_KB(KBBase):\n", | |||
| " def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):\n", | |||
| " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", | |||
| " def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", | |||
| " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||
| "\n", | |||
| " def logic_forward(self, nums):\n", | |||
| " return sum(nums)\n", | |||
| "\n", | |||
| "kb = add_KB(GKB_flag=True)\n", | |||
| "kb = add_KB(prebuild_GKB=True)\n", | |||
| "\n", | |||
| "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", | |||
| "abducer = ReasonerBase(kb, dist_func=\"confidence\")" | |||
| @@ -68,7 +68,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -81,7 +81,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 7, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -109,7 +109,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 8, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -129,7 +129,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 9, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -147,7 +147,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "execution_count": 10, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -208,7 +208,7 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| "version": "3.8.13" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||