| @@ -2,7 +2,7 @@ name: ABL-Package-CI | |||
| on: | |||
| push: | |||
| branches: [ main, Dev ] | |||
| branches: [ main ] | |||
| pull_request: | |||
| branches: [ main ] | |||
| @@ -1,6 +1,9 @@ | |||
| *.pyc | |||
| /results | |||
| raw/ | |||
| *.jpg | |||
| *.png | |||
| *.pk | |||
| *.pk | |||
| *.pth | |||
| *.json | |||
| *.ckpt | |||
| examples/results | |||
| raw/ | |||
| @@ -61,7 +61,7 @@ def filter_data(X, abduced_Z): | |||
| def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbose=-1): | |||
| def train(model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1): | |||
| train_X, train_Z, train_Y = train_data | |||
| test_X, test_Z, test_Y = test_data | |||
| @@ -70,7 +70,7 @@ def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbo | |||
| sample_num = len(train_X) | |||
| if verbose < 1: | |||
| verbose = epochs | |||
| verbose = loop_num | |||
| char_acc_flag = 1 | |||
| if train_Z == None: | |||
| @@ -81,14 +81,14 @@ def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbo | |||
| train_func = clocker(model.train) | |||
| abduce_func = clocker(abducer.batch_abduce) | |||
| for epoch_idx in range(epochs): | |||
| X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx) | |||
| for loop_idx in range(loop_num): | |||
| X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx) | |||
| preds_res = predict_func(X) | |||
| abduced_Z = abduce_func(preds_res, Y) | |||
| if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1): | |||
| if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1): | |||
| res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag) | |||
| INFO('epoch: ', epoch_idx + 1, ' ', res) | |||
| INFO('loop: ', loop_idx + 1, ' ', res) | |||
| finetune_X, finetune_Z = filter_data(X, abduced_Z) | |||
| if len(finetune_X) > 0: | |||
| @@ -11,23 +11,6 @@ | |||
| # ================================================================# | |||
| from itertools import chain | |||
| from sklearn.tree import DecisionTreeClassifier | |||
| from sklearn.model_selection import cross_val_score | |||
| from sklearn.svm import LinearSVC | |||
| from sklearn.pipeline import make_pipeline | |||
| from sklearn.preprocessing import StandardScaler | |||
| from sklearn.svm import SVC | |||
| from sklearn.gaussian_process import GaussianProcessClassifier | |||
| from sklearn.gaussian_process.kernels import RBF | |||
| import pickle as pk | |||
| import random | |||
| from sklearn.neighbors import KNeighborsClassifier | |||
| import numpy as np | |||
| def get_part_data(X, i): | |||
| return list(map(lambda x: x[i], X)) | |||
| @@ -84,87 +67,3 @@ class WABLBasicModel: | |||
| _data_Y, _ = merge_data(Y) | |||
| data_Y = list(map(lambda y: self.mapping[y], _data_Y)) | |||
| self.cls_list[0].fit(X=data_X, y=data_Y) | |||
| class DecisionTree(WABLBasicModel): | |||
| def __init__(self, code_len, label_lists, share=False): | |||
| self.code_len = code_len | |||
| self._set_label_lists(label_lists) | |||
| self.cls_list = [] | |||
| self.share = share | |||
| if share: | |||
| # 本质上是同一个分类器 | |||
| self.cls_list.append( | |||
| DecisionTreeClassifier(random_state=0, min_samples_leaf=3) | |||
| ) | |||
| self.cls_list = self.cls_list * self.code_len | |||
| else: | |||
| for _ in range(code_len): | |||
| self.cls_list.append( | |||
| DecisionTreeClassifier(random_state=0, min_samples_leaf=3) | |||
| ) | |||
| class KNN(WABLBasicModel): | |||
| def __init__(self, code_len, label_lists, share=False, k=3): | |||
| self.code_len = code_len | |||
| self._set_label_lists(label_lists) | |||
| self.cls_list = [] | |||
| self.share = share | |||
| if share: | |||
| # 本质上是同一个分类器 | |||
| self.cls_list.append(KNeighborsClassifier(n_neighbors=k)) | |||
| self.cls_list = self.cls_list * self.code_len | |||
| else: | |||
| for _ in range(code_len): | |||
| self.cls_list.append(KNeighborsClassifier(n_neighbors=k)) | |||
| class CNN(WABLBasicModel): | |||
| def __init__(self, base_model, code_len, label_lists, share=True): | |||
| assert share == True, "Not implemented" | |||
| label_lists = [sorted(list(set(label_list))) for label_list in label_lists] | |||
| self.label_lists = label_lists | |||
| self.code_len = code_len | |||
| self.cls_list = [] | |||
| self.share = share | |||
| if share: | |||
| self.cls_list.append(base_model) | |||
| def train(self, X, Y, n_epoch=100): | |||
| # self.label_lists = [] | |||
| if self.share: | |||
| # 因为是同一个分类器,所以只需要把数据放在一起,然后训练其中任意一个即可 | |||
| data_X, _ = merge_data(X) | |||
| data_Y, _ = merge_data(Y) | |||
| self.cls_list[0].fit(X=data_X, y=data_Y, n_epoch=n_epoch) | |||
| # self.label_lists = [sorted(list(set(data_Y)))] * self.code_len | |||
| else: | |||
| for i in range(self.code_len): | |||
| data_X = get_part_data(X, i) | |||
| data_Y = get_part_data(Y, i) | |||
| self.cls_list[i].fit(data_X, data_Y) | |||
| # self.label_lists.append(sorted(list(set(data_Y)))) | |||
| if __name__ == "__main__": | |||
| # data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk" | |||
| data_path = "datasets/generated_data/0_code_7_2_0.00.pk" | |||
| codes, data, labels = pk.load(open(data_path, "rb")) | |||
| cls = KNN(7, False, k=3) | |||
| cls.train(data, labels) | |||
| print(cls.valid(data, labels)) | |||
| for res in cls.predict_proba(data): | |||
| print(res) | |||
| break | |||
| for res in cls.predict(data): | |||
| print(res) | |||
| break | |||
| print("Trained") | |||
| @@ -0,0 +1,199 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 4, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import sys\n", | |||
| "\n", | |||
| "sys.path.append(\"../\")\n", | |||
| "\n", | |||
| "import torch.nn as nn\n", | |||
| "import torch\n", | |||
| "\n", | |||
| "from abl.abducer.abducer_base import HED_Abducer\n", | |||
| "from abl.abducer.kb import HED_prolog_KB\n", | |||
| "\n", | |||
| "from abl.utils.plog import logger\n", | |||
| "from abl.models.nn import SymbolNet\n", | |||
| "from abl.models.basic_model import BasicModel\n", | |||
| "from abl.models.wabl_models import WABLBasicModel\n", | |||
| "\n", | |||
| "from datasets.hed.get_hed import get_hed, split_equation\n", | |||
| "from abl import framework_hed" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize logger\n", | |||
| "recorder = logger()" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Logic Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stderr", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "ERROR: /home/gaoeh/ABL-Package/examples/datasets/hed/learn_add.pl:67:9: Syntax error: Operator expected\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n", | |||
| "abducer = HED_Abducer(kb)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Machine Learning Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize necessary component for machine learning part\n", | |||
| "cls = SymbolNet(\n", | |||
| " num_classes=len(kb.pseudo_label_list),\n", | |||
| " image_size=(28, 28, 1),\n", | |||
| ")\n", | |||
| "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||
| "criterion = nn.CrossEntropyLoss()\n", | |||
| "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Pretrain NN classifier\n", | |||
| "framework_hed.hed_pretrain(kb, cls, recorder)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize BasicModel\n", | |||
| "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n", | |||
| "base_model = BasicModel(\n", | |||
| " cls,\n", | |||
| " criterion,\n", | |||
| " optimizer,\n", | |||
| " device,\n", | |||
| " save_interval=1,\n", | |||
| " save_dir=recorder.save_dir,\n", | |||
| " batch_size=32,\n", | |||
| " num_epochs=1,\n", | |||
| " recorder=recorder,\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Use WABL model to join two parts" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "model = WABLBasicModel(base_model, kb.pseudo_label_list)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Dataset" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "total_train_data = get_hed(train=True)\n", | |||
| "train_data, val_data = split_equation(total_train_data, 3, 1)\n", | |||
| "test_data = get_hed(train=False)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Train and save" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)\n", | |||
| "framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)\n", | |||
| "\n", | |||
| "recorder.dump()" | |||
| ] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "ABL", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| }, | |||
| "orig_nbformat": 4 | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||
| @@ -0,0 +1,184 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import sys\n", | |||
| "\n", | |||
| "sys.path.append(\"../\")\n", | |||
| "\n", | |||
| "import torch.nn as nn\n", | |||
| "import torch\n", | |||
| "\n", | |||
| "from abl.abducer.abducer_base import AbducerBase\n", | |||
| "from abl.abducer.kb import HWF_KB\n", | |||
| "\n", | |||
| "from abl.utils.plog import logger\n", | |||
| "from abl.models.nn import SymbolNet\n", | |||
| "from abl.models.basic_model import BasicModel\n", | |||
| "from abl.models.wabl_models import WABLBasicModel\n", | |||
| "\n", | |||
| "from datasets.hwf.get_hwf import get_hwf\n", | |||
| "from abl import framework_hed" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize logger\n", | |||
| "recorder = logger()" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Logic Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "kb = HWF_KB(GKB_flag=True)\n", | |||
| "abducer = AbducerBase(kb)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Machine Learning Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize necessary component for machine learning part\n", | |||
| "cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))\n", | |||
| "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||
| "criterion = nn.CrossEntropyLoss()\n", | |||
| "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize BasicModel\n", | |||
| "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n", | |||
| "base_model = BasicModel(\n", | |||
| " cls,\n", | |||
| " criterion,\n", | |||
| " optimizer,\n", | |||
| " device,\n", | |||
| " save_interval=1,\n", | |||
| " save_dir=recorder.save_dir,\n", | |||
| " batch_size=32,\n", | |||
| " num_epochs=1,\n", | |||
| " recorder=recorder,\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Use WABL model to join two parts" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize WABL model\n", | |||
| "# The main function of the WABL model is to serialize data and \n", | |||
| "# provide a unified interface for different machine learning models\n", | |||
| "model = WABLBasicModel(base_model, kb.pseudo_label_list)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Dataset" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Get training and testing data\n", | |||
| "train_data = get_hwf(train=True, get_pseudo_label=True)\n", | |||
| "test_data = get_hwf(train=False, get_pseudo_label=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Train and save" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Train model\n", | |||
| "framework_hed.train(\n", | |||
| " model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n", | |||
| ")\n", | |||
| "\n", | |||
| "# Save results\n", | |||
| "recorder.dump()" | |||
| ] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "ABL", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.13" | |||
| }, | |||
| "orig_nbformat": 4 | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||
| @@ -0,0 +1,190 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import sys\n", | |||
| "\n", | |||
| "sys.path.append(\"../\")\n", | |||
| "\n", | |||
| "import torch.nn as nn\n", | |||
| "import torch\n", | |||
| "\n", | |||
| "from abl.abducer.abducer_base import AbducerBase\n", | |||
| "from abl.abducer.kb import add_KB\n", | |||
| "\n", | |||
| "from abl.utils.plog import logger\n", | |||
| "from abl.models.nn import LeNet5\n", | |||
| "from abl.models.basic_model import BasicModel\n", | |||
| "from abl.models.wabl_models import WABLBasicModel\n", | |||
| "\n", | |||
| "from datasets.mnist_add.get_mnist_add import get_mnist_add\n", | |||
| "from abl import framework_hed" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize logger\n", | |||
| "recorder = logger()" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Logic Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "kb = add_KB(GKB_flag=True)\n", | |||
| "abducer = AbducerBase(kb, dist_func=\"confidence\")" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Machine Learning Part" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize necessary component for machine learning part\n", | |||
| "cls = LeNet5(num_classes=len(kb.pseudo_label_list))\n", | |||
| "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||
| "criterion = nn.CrossEntropyLoss()\n", | |||
| "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize BasicModel\n", | |||
| "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n", | |||
| "base_model = BasicModel(\n", | |||
| " cls,\n", | |||
| " criterion,\n", | |||
| " optimizer,\n", | |||
| " device,\n", | |||
| " save_interval=1,\n", | |||
| " save_dir=recorder.save_dir,\n", | |||
| " batch_size=32,\n", | |||
| " num_epochs=1,\n", | |||
| " recorder=recorder,\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Use WABL model to join two parts" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize WABL model\n", | |||
| "# The main function of the WABL model is to serialize data and \n", | |||
| "# provide a unified interface for different machine learning models\n", | |||
| "model = WABLBasicModel(base_model, kb.pseudo_label_list)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Dataset" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Get training and testing data\n", | |||
| "train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)\n", | |||
| "test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Train and save" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Train model\n", | |||
| "framework_hed.train(\n", | |||
| " model,\n", | |||
| " abducer,\n", | |||
| " (train_X, train_Z, train_Y),\n", | |||
| " (test_X, test_Z, test_Y),\n", | |||
| " loop_num=15,\n", | |||
| " sample_num=5000,\n", | |||
| " verbose=1,\n", | |||
| ")\n", | |||
| "\n", | |||
| "# Save results\n", | |||
| "recorder.dump()" | |||
| ] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "ABL", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| }, | |||
| "orig_nbformat": 4 | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||