{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HED example: training a symbol classifier with a Prolog knowledge base\n",
    "\n",
    "This notebook connects a neural network to a logic component using the `abl` package:\n",
    "\n",
    "1. **Logic part** - build a `HED_prolog_KB` knowledge base from `./datasets/hed/learn_add.pl` and wrap it in a `HED_Abducer`.\n",
    "2. **Machine learning part** - build a `SymbolNet` classifier, pretrain it, and wrap it in a `BasicModel`.\n",
    "3. **Joining the two** - wrap the `BasicModel` in a `WABLBasicModel` that is given the knowledge base's pseudo-label list.\n",
    "4. **Train and test** - call `framework_hed.train_with_rule` and `framework_hed.hed_test`, then dump the recorder.\n",
    "\n",
    "All paths are relative to this notebook's directory, so start the kernel there and use *Restart Kernel and Run All*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Presumably makes the `abl` package one directory up importable;\n",
    "# this must run before the `abl` imports below.\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from abl.abducer.abducer_base import HED_Abducer\n",
    "from abl.abducer.kb import HED_prolog_KB\n",
    "\n",
    "from abl.utils.plog import logger\n",
    "from abl.models.nn import SymbolNet\n",
    "from abl.models.basic_model import BasicModel\n",
    "from abl.models.wabl_models import WABLBasicModel\n",
    "\n",
    "from datasets.hed.get_hed import get_hed, split_equation\n",
    "from abl import framework_hed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings a reader might want to tune; the later cells only reference these names.\n",
    "SEED = 0\n",
    "SELECT_NUM = 10  # passed as select_num to train_with_rule\n",
    "MIN_LEN = 5  # passed as min_len to both train_with_rule and hed_test\n",
    "MAX_LEN = 8  # passed as max_len to both train_with_rule and hed_test\n",
    "\n",
    "# Best-effort reproducibility: seed the RNGs this notebook can reach.\n",
    "# NOTE(review): `abl` internals may draw randomness from elsewhere - confirm.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize logger\n",
    "recorder = logger()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logic Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize knowledge base and abducer\n",
    "# NOTE(review): a previously saved run of this cell logged a Prolog\n",
    "# \"Syntax error: Operator expected\" at learn_add.pl:67:9 - confirm the file loads cleanly.\n",
    "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n",
    "abducer = HED_Abducer(kb)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Machine Learning Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize necessary components for machine learning part\n",
    "cls = SymbolNet(\n",
    "    num_classes=len(kb.pseudo_label_list),  # one output class per pseudo label\n",
    "    image_size=(28, 28, 1),\n",
    ")\n",
    "# Use the first CUDA device when available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrain NN classifier\n",
    "framework_hed.hed_pretrain(kb, cls, recorder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize BasicModel\n",
    "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
    "base_model = BasicModel(\n",
    "    cls,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    save_interval=1,\n",
    "    save_dir=recorder.save_dir,\n",
    "    batch_size=32,\n",
    "    num_epochs=1,\n",
    "    recorder=recorder,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use WABL model to join two parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = WABLBasicModel(base_model, kb.pseudo_label_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_train_data = get_hed(train=True)\n",
    "# presumably a 3:1 train/validation split - TODO confirm against split_equation\n",
    "train_data, val_data = split_equation(total_train_data, 3, 1)\n",
    "test_data = get_hed(train=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the result to a new name so that `model` still refers to the wrapper built above.\n",
    "trained_model, mapping = framework_hed.train_with_rule(\n",
    "    model,\n",
    "    abducer,\n",
    "    train_data,\n",
    "    val_data,\n",
    "    select_num=SELECT_NUM,\n",
    "    min_len=MIN_LEN,\n",
    "    max_len=MAX_LEN,\n",
    ")\n",
    "framework_hed.hed_test(\n",
    "    trained_model,\n",
    "    abducer,\n",
    "    mapping,\n",
    "    train_data,\n",
    "    test_data,\n",
    "    min_len=MIN_LEN,\n",
    "    max_len=MAX_LEN,\n",
    ")\n",
    "\n",
    "recorder.dump()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ABL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}