From 6bd0ff66d67c8d53ea4dc0ca1ce8bb279d491446 Mon Sep 17 00:00:00 2001
From: Gao Enhao
Date: Fri, 31 Mar 2023 20:56:19 +0800
Subject: [PATCH] [ENH] Run hed_example.ipynb after reformatting examples

---
 examples/hed/datasets/get_hed.py |  6 +--
 examples/hed/framework_hed.py    | 87 +++-----------------------------
 examples/hed/hed_example.ipynb   | 24 +++------
 3 files changed, 17 insertions(+), 100 deletions(-)

diff --git a/examples/hed/datasets/get_hed.py b/examples/hed/datasets/get_hed.py
index b317d94..5bac060 100644
--- a/examples/hed/datasets/get_hed.py
+++ b/examples/hed/datasets/get_hed.py
@@ -41,7 +41,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)):
     X = []
     for label in labels:
         label_path = os.path.join(
-            "./datasets/hed/mnist_images", label
+            "./datasets/mnist_images", label
         )
         img_path_list = os.listdir(label_path)
         for img_path in img_path_list:
@@ -107,13 +107,13 @@
 
     if dataset == "mnist":
         with open(
-            "./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
+            "./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
             "rb",
         ) as f:
             img_dataset = pickle.load(f)
     elif dataset == "random":
         with open(
-            "./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
+            "./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
             "rb",
         ) as f:
             img_dataset = pickle.load(f)
diff --git a/examples/hed/framework_hed.py b/examples/hed/framework_hed.py
index 2196d24..8aa1ccd 100644
--- a/examples/hed/framework_hed.py
+++ b/examples/hed/framework_hed.py
@@ -10,93 +10,18 @@
 #
 # ================================================================#
 
-import pickle as pk
 import torch
 import torch.nn as nn
 import numpy as np
 import os
 
-from .utils.plog import INFO, DEBUG, clocker
-from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
+from abl.utils.plog import INFO
+from abl.utils.utils import flatten, reform_idx
+from abl.models.basic_model import BasicModel, BasicDataset
 
-from .models.nn import SymbolNetAutoencoder
-from .models.basic_model import BasicModel, BasicDataset
-
-import sys
-sys.path.append("..")
-from examples.datasets.hed.get_hed import get_pretrain_data
-
-def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
-    result = {}
-    if char_acc_flag:
-        char_acc_num = 0
-        char_num = 0
-        for pred_z, z in zip(pred_Z, Z):
-            char_num += len(z)
-            for zidx in range(len(z)):
-                if pred_z[zidx] == z[zidx]:
-                    char_acc_num += 1
-        char_acc = char_acc_num / char_num
-        result["Character level accuracy"] = char_acc
-
-    abl_acc_num = 0
-    for pred_z, y in zip(pred_Z, Y):
-        if logic_forward(pred_z) == y:
-            abl_acc_num += 1
-    abl_acc = abl_acc_num / len(Y)
-    result["ABL accuracy"] = abl_acc
-
-    return result
-
-
-def filter_data(X, abduced_Z):
-    finetune_Z = []
-    finetune_X = []
-    for x, abduced_z in zip(X, abduced_Z):
-        if len(abduced_z) > 0:
-            finetune_X.append(x)
-            finetune_Z.append(abduced_z)
-    return finetune_X, finetune_Z
-
-
-
-def train(model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1):
-    train_X, train_Z, train_Y = train_data
-    test_X, test_Z, test_Y = test_data
-
-    # Set default parameters
-    if sample_num == -1:
-        sample_num = len(train_X)
-
-    if verbose < 1:
-        verbose = loop_num
-
-    char_acc_flag = 1
-    if train_Z == None:
-        char_acc_flag = 0
-        train_Z = [None] * len(train_X)
-
-    predict_func = clocker(model.predict)
-    train_func = clocker(model.train)
-    abduce_func = clocker(abducer.batch_abduce)
-
-    for loop_idx in range(loop_num):
-        X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx)
-        preds_res = predict_func(X)
-        abduced_Z = abduce_func(preds_res, Y)
-
-        if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1):
-            res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
-            INFO('loop: ', loop_idx + 1, ' ', res)
-
-        finetune_X, finetune_Z = filter_data(X, abduced_Z)
-        if len(finetune_X) > 0:
-            # model.valid(finetune_X, finetune_Z)
-            train_func(finetune_X, finetune_Z)
-        else:
-            INFO("lack of data, all abduced failed", len(finetune_X))
-
-    return res
+from utils import gen_mappings, mapping_res, remapping_res
+from models.nn import SymbolNetAutoencoder
+from datasets.get_hed import get_pretrain_data
 
 
 def hed_pretrain(kb, cls, recorder):
diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb
index 0479752..539b201 100644
--- a/examples/hed/hed_example.ipynb
+++ b/examples/hed/hed_example.ipynb
@@ -2,13 +2,13 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import sys\n",
     "\n",
-    "sys.path.append(\"../\")\n",
+    "sys.path.append(\"../../\")\n",
     "\n",
     "import torch.nn as nn\n",
     "import torch\n",
@@ -21,13 +21,13 @@
     "from abl.models.wabl_models import WABLBasicModel\n",
     "\n",
     "from models.nn import SymbolNet\n",
-    "from datasets.hed.get_hed import get_hed, split_equation\n",
-    "from abl import framework_hed"
+    "from datasets.get_hed import get_hed, split_equation\n",
+    "import framework_hed"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -45,20 +45,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "ERROR: /home/gaoeh/ABL-Package/examples/datasets/hed/learn_add.pl:67:9: Syntax error: Operator expected\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# Initialize knowledge base and abducer\n",
-    "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n",
+    "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n",
     "abducer = HED_Abducer(kb)"
    ]
   },