Browse Source

[ENH]run hed_example.ipynb after reformat examples

pull/3/head
Gao Enhao 3 years ago
parent
commit
6bd0ff66d6
3 changed files with 17 additions and 100 deletions
  1. +3
    -3
      examples/hed/datasets/get_hed.py
  2. +6
    -81
      examples/hed/framework_hed.py
  3. +8
    -16
      examples/hed/hed_example.ipynb

+ 3
- 3
examples/hed/datasets/get_hed.py View File

@@ -41,7 +41,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)):
X = [] X = []
for label in labels: for label in labels:
label_path = os.path.join( label_path = os.path.join(
"./datasets/hed/mnist_images", label
"./datasets/mnist_images", label
) )
img_path_list = os.listdir(label_path) img_path_list = os.listdir(label_path)
for img_path in img_path_list: for img_path in img_path_list:
@@ -107,13 +107,13 @@ def get_hed(dataset="mnist", train=True):


if dataset == "mnist": if dataset == "mnist":
with open( with open(
"./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
"./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
"rb", "rb",
) as f: ) as f:
img_dataset = pickle.load(f) img_dataset = pickle.load(f)
elif dataset == "random": elif dataset == "random":
with open( with open(
"./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
"./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
"rb", "rb",
) as f: ) as f:
img_dataset = pickle.load(f) img_dataset = pickle.load(f)


+ 6
- 81
examples/hed/framework_hed.py View File

@@ -10,93 +10,18 @@
# #
# ================================================================# # ================================================================#


import pickle as pk
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import os import os


from .utils.plog import INFO, DEBUG, clocker
from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
from abl.utils.plog import INFO
from abl.utils.utils import flatten, reform_idx
from abl.models.basic_model import BasicModel, BasicDataset


from .models.nn import SymbolNetAutoencoder
from .models.basic_model import BasicModel, BasicDataset

import sys
sys.path.append("..")
from examples.datasets.hed.get_hed import get_pretrain_data

def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
    """Compute accuracy statistics for a batch of predictions.

    Args:
        pred_Z: Sequence of predicted symbol sequences, one per example.
        Z: Sequence of ground-truth symbol sequences aligned with ``pred_Z``.
            Only read when ``char_acc_flag`` is truthy.
        Y: Sequence of target results, one per example.
        logic_forward: Callable applied to each predicted sequence; its
            return value is compared with the matching entry of ``Y``.
        char_acc_flag: When truthy, character-level accuracy is reported in
            addition to the ABL accuracy.

    Returns:
        A dict with the key ``"ABL accuracy"`` and, when requested,
        ``"Character level accuracy"``. A ratio is reported as ``0.0`` when
        there is nothing to count, instead of raising ``ZeroDivisionError``.
    """
    result = {}

    if char_acc_flag:
        char_num = 0
        char_acc_num = 0
        for pred_z, z in zip(pred_Z, Z):
            # The denominator is the ground-truth length, so positions missing
            # from a too-short prediction count as errors.
            char_num += len(z)
            char_acc_num += sum(1 for p, t in zip(pred_z, z) if p == t)
        result["Character level accuracy"] = (
            char_acc_num / char_num if char_num else 0.0
        )

    abl_acc_num = sum(
        1 for pred_z, y in zip(pred_Z, Y) if logic_forward(pred_z) == y
    )
    total = len(Y)
    result["ABL accuracy"] = abl_acc_num / total if total else 0.0

    return result


def filter_data(X, abduced_Z):
    """Keep only the examples whose abduced label sequence is non-empty.

    Args:
        X: Sequence of inputs.
        abduced_Z: Sequence of abduced label sequences aligned with ``X``.
            Each entry must support ``len()``.

    Returns:
        A ``(finetune_X, finetune_Z)`` tuple of two lists holding the inputs
        and labels of the retained examples, in their original order.
    """
    # ``len(...) > 0`` rather than truthiness, so array-like label sequences
    # are handled the same way as plain lists.
    kept = [(x, z) for x, z in zip(X, abduced_Z) if len(z) > 0]
    finetune_X = [x for x, _ in kept]
    finetune_Z = [z for _, z in kept]
    return finetune_X, finetune_Z



def train(model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1):
    """Run the predict / abduce / fine-tune loop.

    Each iteration samples a block of training data, predicts on it with
    ``model.predict``, abduces labels with ``abducer.batch_abduce`` and trains
    the model on the examples for which abduction returned a non-empty result.

    Args:
        model: Object exposing ``predict(X)`` and ``train(X, Z)``. The value
            returned by ``predict`` is indexed with ``'cls'``.
        abducer: Object exposing ``batch_abduce(preds_res, Y)`` and
            ``kb.logic_forward``.
        train_data: ``(train_X, train_Z, train_Y)``. ``train_Z`` may be
            ``None``, in which case character-level accuracy is not reported.
        test_data: ``(test_X, test_Z, test_Y)``. Unpacked but not otherwise
            used by this function.
        loop_num: Number of iterations.
        sample_num: Examples sampled per iteration; ``-1`` means
            ``len(train_X)``.
        verbose: Report statistics every ``verbose`` iterations; values below
            1 mean "only on the last iteration".

    Returns:
        The statistics dict from the last reported iteration, or an empty
        dict when ``loop_num`` is not positive and no iteration runs.
    """
    train_X, train_Z, train_Y = train_data
    test_X, test_Z, test_Y = test_data  # kept so a malformed tuple still fails early

    # Set default parameters
    if sample_num == -1:
        sample_num = len(train_X)

    if verbose < 1:
        verbose = loop_num

    # Identity check: ``== None`` is elementwise/ambiguous on array-like labels.
    char_acc_flag = train_Z is not None
    if train_Z is None:
        train_Z = [None] * len(train_X)

    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    # Defined up front so the final ``return`` is valid even if the loop
    # body never executes.
    res = {}

    for loop_idx in range(loop_num):
        X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx)
        preds_res = predict_func(X)
        abduced_Z = abduce_func(preds_res, Y)

        # The last iteration always reports, so ``res`` reflects the final state.
        if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1):
            res = result_statistics(
                preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag
            )
            INFO('loop: ', loop_idx + 1, ' ', res)

        finetune_X, finetune_Z = filter_data(X, abduced_Z)
        if len(finetune_X) > 0:
            train_func(finetune_X, finetune_Z)
        else:
            INFO("lack of data, all abduced failed", len(finetune_X))

    return res
from utils import gen_mappings, mapping_res, remapping_res
from models.nn import SymbolNetAutoencoder
from datasets.get_hed import get_pretrain_data




def hed_pretrain(kb, cls, recorder): def hed_pretrain(kb, cls, recorder):


+ 8
- 16
examples/hed/hed_example.ipynb View File

@@ -2,13 +2,13 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import sys\n", "import sys\n",
"\n", "\n",
"sys.path.append(\"../\")\n",
"sys.path.append(\"../../\")\n",
"\n", "\n",
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import torch\n", "import torch\n",
@@ -21,13 +21,13 @@
"from abl.models.wabl_models import WABLBasicModel\n", "from abl.models.wabl_models import WABLBasicModel\n",
"\n", "\n",
"from models.nn import SymbolNet\n", "from models.nn import SymbolNet\n",
"from datasets.hed.get_hed import get_hed, split_equation\n",
"from abl import framework_hed"
"from datasets.get_hed import get_hed, split_equation\n",
"import framework_hed"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -45,20 +45,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR: /home/gaoeh/ABL-Package/examples/datasets/hed/learn_add.pl:67:9: Syntax error: Operator expected\n"
]
}
],
"outputs": [],
"source": [ "source": [
"# Initialize knowledge base and abducer\n", "# Initialize knowledge base and abducer\n",
"kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n",
"kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n",
"abducer = HED_Abducer(kb)" "abducer = HED_Abducer(kb)"
] ]
}, },


Loading…
Cancel
Save