{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "\n",
    "from abl.abducer.abducer_base import AbducerBase\n",
    "from abl.abducer.kb import HWF_KB\n",
    "\n",
    "from abl.utils.plog import logger\n",
    "from abl.models.nn import SymbolNet\n",
    "from abl.models.basic_model import BasicModel\n",
    "from abl.models.wabl_models import WABLBasicModel\n",
    "\n",
    "from datasets.hwf.get_hwf import get_hwf\n",
    "from abl import framework_hed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix random seeds so that repeated runs of this notebook are reproducible\n",
    "# (network initialization, data shuffling and abduction sampling may all draw random numbers)\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize logger\n",
    "recorder = logger()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logic Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize knowledge base and abducer\n",
    "kb = HWF_KB(GKB_flag=True)\n",
    "abducer = AbducerBase(kb)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Machine Learning Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the necessary components for the machine learning part\n",
    "cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize BasicModel\n",
    "# BasicModel wraps the NN model so that it can be used like an sklearn estimator\n",
    "base_model = BasicModel(\n",
    "    cls,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    save_interval=1,\n",
    "    save_dir=recorder.save_dir,\n",
    "    batch_size=32,\n",
    "    num_epochs=1,\n",
    "    recorder=recorder,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the WABL model to join the two parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize WABL model\n",
    "# The main function of the WABL model is to serialize data and\n",
    "# provide a unified interface for different machine learning models\n",
    "model = WABLBasicModel(base_model, kb.pseudo_label_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get training and testing data\n",
    "train_data = get_hwf(train=True, get_pseudo_label=True)\n",
    "test_data = get_hwf(train=False, get_pseudo_label=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train model\n",
    "framework_hed.train(\n",
    "    model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n",
    ")\n",
    "\n",
    "# Save results\n",
    "recorder.dump()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ABL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}