| @@ -41,7 +41,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||||
| X = [] | X = [] | ||||
| for label in labels: | for label in labels: | ||||
| label_path = os.path.join( | label_path = os.path.join( | ||||
| "./datasets/hed/mnist_images", label | |||||
| "./datasets/mnist_images", label | |||||
| ) | ) | ||||
| img_path_list = os.listdir(label_path) | img_path_list = os.listdir(label_path) | ||||
| for img_path in img_path_list: | for img_path in img_path_list: | ||||
| @@ -107,13 +107,13 @@ def get_hed(dataset="mnist", train=True): | |||||
| if dataset == "mnist": | if dataset == "mnist": | ||||
| with open( | with open( | ||||
| "./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk", | |||||
| "./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk", | |||||
| "rb", | "rb", | ||||
| ) as f: | ) as f: | ||||
| img_dataset = pickle.load(f) | img_dataset = pickle.load(f) | ||||
| elif dataset == "random": | elif dataset == "random": | ||||
| with open( | with open( | ||||
| "./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk", | |||||
| "./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk", | |||||
| "rb", | "rb", | ||||
| ) as f: | ) as f: | ||||
| img_dataset = pickle.load(f) | img_dataset = pickle.load(f) | ||||
| @@ -10,93 +10,18 @@ | |||||
| # | # | ||||
| # ================================================================# | # ================================================================# | ||||
| import pickle as pk | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import numpy as np | import numpy as np | ||||
| import os | import os | ||||
| from .utils.plog import INFO, DEBUG, clocker | |||||
| from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res | |||||
| from abl.utils.plog import INFO | |||||
| from abl.utils.utils import flatten, reform_idx | |||||
| from abl.models.basic_model import BasicModel, BasicDataset | |||||
| from .models.nn import SymbolNetAutoencoder | |||||
| from .models.basic_model import BasicModel, BasicDataset | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| from examples.datasets.hed.get_hed import get_pretrain_data | |||||
def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
    """Compute accuracy statistics for a batch of predictions.

    Args:
        pred_Z: Sequence of predicted symbol sequences, one per sample.
        Z: Sequence of ground-truth symbol sequences, paired with ``pred_Z``.
            Only read when ``char_acc_flag`` is truthy.
        Y: Sequence of expected reasoning results, paired with ``pred_Z``.
        logic_forward: Callable mapping one predicted symbol sequence to a
            reasoning result that is compared against the matching ``Y``.
        char_acc_flag: When truthy, also report character level accuracy.

    Returns:
        A dict with key ``"ABL accuracy"`` and, when ``char_acc_flag`` is
        truthy, ``"Character level accuracy"``. An accuracy whose
        denominator is zero (no characters / no samples) is reported as
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    result = {}

    if char_acc_flag:
        char_acc_num = 0
        char_num = 0
        for pred_z, z in zip(pred_Z, Z):
            char_num += len(z)
            # Pairwise comparison; positions missing from a too-short
            # prediction are simply never counted as correct.
            char_acc_num += sum(1 for p, t in zip(pred_z, z) if p == t)
        result["Character level accuracy"] = (
            char_acc_num / char_num if char_num else 0.0
        )

    abl_acc_num = sum(
        1 for pred_z, y in zip(pred_Z, Y) if logic_forward(pred_z) == y
    )
    sample_num = len(Y)
    result["ABL accuracy"] = abl_acc_num / sample_num if sample_num else 0.0

    return result
def filter_data(X, abduced_Z):
    """Drop samples whose abduced label sequence is empty.

    Args:
        X: Sequence of samples.
        abduced_Z: Sequence of abduced label sequences, paired with ``X``.

    Returns:
        A ``(finetune_X, finetune_Z)`` tuple of two lists holding, in the
        original order, only the pairs whose abduced sequence has a
        length greater than zero.
    """
    kept_pairs = [
        (sample, labels)
        for sample, labels in zip(X, abduced_Z)
        if len(labels) > 0
    ]
    finetune_X = [sample for sample, _ in kept_pairs]
    finetune_Z = [labels for _, labels in kept_pairs]
    return finetune_X, finetune_Z
def train(model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1):
    """Run the predict -> abduce -> fine-tune loop and return the last statistics.

    Args:
        model: Object exposing ``predict(X)`` and ``train(X, Z)``. The value
            returned by ``predict`` is indexed with ``'cls'`` to obtain the
            predicted symbol sequences.
        abducer: Object exposing ``batch_abduce(preds_res, Y)`` and
            ``kb.logic_forward``.
        train_data: ``(train_X, train_Z, train_Y)`` triple. ``train_Z`` may be
            ``None``, in which case character level accuracy is not reported.
        test_data: Accepted for interface compatibility; not used here.
        loop_num: Number of loop iterations.
        sample_num: Sample count handed to ``block_sample`` each iteration;
            ``-1`` means ``len(train_X)``.
        verbose: Statistics are logged every ``verbose`` iterations; values
            below 1 mean "only on the last iteration".

    Returns:
        The dict produced by ``result_statistics`` on the most recent
        reporting iteration, or an empty dict when no iteration ran
        (``loop_num <= 0``).
    """
    train_X, train_Z, train_Y = train_data

    # Set default parameters
    if sample_num == -1:
        sample_num = len(train_X)
    if verbose < 1:
        verbose = loop_num

    # Identity check: ``== None`` would be evaluated element-wise on
    # array-like labels and could not be used as a plain condition.
    char_acc_flag = 1
    if train_Z is None:
        char_acc_flag = 0
        train_Z = [None] * len(train_X)

    # NOTE(review): clocker presumably wraps the callable with timing/logging
    # -- confirm against utils.plog.
    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    # Defined up front so that ``return res`` is valid even if the loop
    # body never executes.
    res = {}
    for loop_idx in range(loop_num):
        X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx)
        preds_res = predict_func(X)
        abduced_Z = abduce_func(preds_res, Y)

        # Report on every ``verbose``-th iteration and always on the last one.
        if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1):
            res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
            INFO('loop: ', loop_idx + 1, ' ', res)

        # Only samples with a non-empty abduced sequence are used for training.
        finetune_X, finetune_Z = filter_data(X, abduced_Z)
        if len(finetune_X) > 0:
            train_func(finetune_X, finetune_Z)
        else:
            INFO("lack of data, all abduced failed", len(finetune_X))

    return res
| from utils import gen_mappings, mapping_res, remapping_res | |||||
| from models.nn import SymbolNetAutoencoder | |||||
| from datasets.get_hed import get_pretrain_data | |||||
| def hed_pretrain(kb, cls, recorder): | def hed_pretrain(kb, cls, recorder): | ||||
| @@ -2,13 +2,13 @@ | |||||
| "cells": [ | "cells": [ | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 4, | |||||
| "execution_count": null, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| "import sys\n", | "import sys\n", | ||||
| "\n", | "\n", | ||||
| "sys.path.append(\"../\")\n", | |||||
| "sys.path.append(\"../../\")\n", | |||||
| "\n", | "\n", | ||||
| "import torch.nn as nn\n", | "import torch.nn as nn\n", | ||||
| "import torch\n", | "import torch\n", | ||||
| @@ -21,13 +21,13 @@ | |||||
| "from abl.models.wabl_models import WABLBasicModel\n", | "from abl.models.wabl_models import WABLBasicModel\n", | ||||
| "\n", | "\n", | ||||
| "from models.nn import SymbolNet\n", | "from models.nn import SymbolNet\n", | ||||
| "from datasets.hed.get_hed import get_hed, split_equation\n", | |||||
| "from abl import framework_hed" | |||||
| "from datasets.get_hed import get_hed, split_equation\n", | |||||
| "import framework_hed" | |||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 5, | |||||
| "execution_count": null, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| @@ -45,20 +45,12 @@ | |||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 6, | |||||
| "execution_count": null, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [ | |||||
| { | |||||
| "name": "stderr", | |||||
| "output_type": "stream", | |||||
| "text": [ | |||||
| "ERROR: /home/gaoeh/ABL-Package/examples/datasets/hed/learn_add.pl:67:9: Syntax error: Operator expected\n" | |||||
| ] | |||||
| } | |||||
| ], | |||||
| "outputs": [], | |||||
| "source": [ | "source": [ | ||||
| "# Initialize knowledge base and abducer\n", | "# Initialize knowledge base and abducer\n", | ||||
| "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n", | |||||
| "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n", | |||||
| "abducer = HED_Abducer(kb)" | "abducer = HED_Abducer(kb)" | ||||
| ] | ] | ||||
| }, | }, | ||||