Browse Source

Update framework.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
dc1b669ab2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 102 additions and 59 deletions
  1. +102
    -59
      framework.py

+ 102
- 59
framework.py View File

@@ -17,17 +17,17 @@ import numpy as np
from utils.plog import INFO, DEBUG, clocker

@clocker
def block_sample(X_bak, C_bak, sample_num, epoch_idx):
    """Return the block of samples (and the matching slice of C) for one epoch.

    The data is treated as ``len(X_bak) // sample_num`` consecutive blocks of
    ``sample_num`` items (always at least one block), and ``epoch_idx`` cycles
    through those blocks in order. A trailing remainder shorter than
    ``sample_num`` is never selected, unless the whole data set is shorter
    than one block, in which case everything is returned.

    Args:
        X_bak: Sliceable sequence holding every input sample.
        C_bak: Sliceable sequence that is sliced with the same bounds as
            ``X_bak``.
        sample_num: Number of items per block. A non-positive value means
            "one block containing everything" instead of dividing by zero.
        epoch_idx: Epoch counter used to choose the block.

    Returns:
        A tuple ``(X, C)`` with the selected slices of ``X_bak`` and ``C_bak``.
    """
    data_num = len(X_bak)
    if sample_num <= 0:
        # Avoid ZeroDivisionError / negative segment indices below.
        sample_num = data_num

    # sample_num can still be 0 here when the data set itself is empty.
    part_num = max(data_num // sample_num, 1) if sample_num else 1
    seg_idx = epoch_idx % part_num
    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", data_num)

    start = sample_num * seg_idx
    stop = start + sample_num
    return X_bak[start:stop], C_bak[start:stop]

def get_taglist(self, Y):
tmp = [[str(x) for x in label] for label in Y]
@@ -35,49 +35,58 @@ def get_taglist(self, Y):
return tmp

@clocker
def result_statistics(C, pseudo_Y, logic_forward):
    """Return the fraction of examples whose predicted labels yield the target.

    For each pair taken in lockstep from ``C`` and ``pseudo_Y``, the example
    counts as correct when ``logic_forward(pseudo_y) == c``.

    Args:
        C: Sequence of expected results, one per example.
        pseudo_Y: Sequence of predicted labels, paired positionally with ``C``.
        logic_forward: Callable applied to each element of ``pseudo_Y``; its
            return value is compared with the corresponding element of ``C``.

    Returns:
        The number of matching examples divided by ``len(C)``, or ``0.0``
        when ``C`` is empty (instead of raising ``ZeroDivisionError``).
    """
    total = len(C)
    if total == 0:
        return 0.0

    abl_acc = sum(
        1 for c, pseudo_y in zip(C, pseudo_Y) if logic_forward(pseudo_y) == c
    )
    return abl_acc / total

@clocker
def filter_data(X, abduced_Y):
@@ -106,7 +115,7 @@ def is_all_sublabel_exist(labels, std_label_list):
def pretrain(model, X, Y):
    """Placeholder for a pretraining step.

    The arguments are accepted but ignored; nothing is executed and the
    function returns ``None``.
    """
    pass

def train(model, abducer, X, C, logic_forward, epochs = 10, sample_num = -1, verbose = -1):
    """Run the abductive-learning loop: predict, score, abduce, fine-tune.

    Every epoch takes one block of data from ``block_sample``, predicts on
    it, measures with ``result_statistics`` how often ``logic_forward``
    applied to the predictions equals the corresponding entry of ``C``,
    asks the abducer for revised labels, and trains the model on whatever
    ``filter_data`` returns (when that is non-empty).

    Args:
        model: Object providing ``predict(X)`` and ``train(X, Y)``. The value
            returned by ``predict`` is indexed with ``'cls'``.
        abducer: Object providing ``batch_abduce(preds_res, C)``.
        X: All input samples.
        C: Expected results, sliced in parallel with ``X``.
        logic_forward: Callable forwarded to ``result_statistics``.
        epochs: Number of loop iterations. The value passed by the caller is
            honored (it used to be overwritten with a hard-coded 50).
        sample_num: Block size per epoch; ``-1`` means ``len(X)``.
        verbose: Kept for backward compatibility; currently unused.

    Returns:
        The accuracy computed in the last epoch (measured before that epoch's
        fine-tuning), or ``None`` when no epoch was run.
    """
    # Set default parameters
    if sample_num == -1:
        sample_num = len(X)

    # Set function running time recorder
    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    X_bak = X
    C_bak = C

    # Stays None when epochs < 1, so the final return cannot hit an unbound name.
    abl_acc = None

    # Abductive learning train process
    for epoch_idx in range(epochs):
        X, C = block_sample(X_bak, C_bak, sample_num, epoch_idx)
        preds_res = predict_func(X)
        abl_acc = result_statistics(C, preds_res['cls'], logic_forward)
        print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc)

        abduced_Y = abduce_func(preds_res, C)
        finetune_X, finetune_Y = filter_data(X, abduced_Y)
        if len(finetune_X) > 0:
            train_func(finetune_X, finetune_Y)
        else:
            INFO("lack of data, all abduced failed", len(finetune_X))

    return abl_acc

# def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True):
# # Set default parameters
# if sample_num == -1:
# sample_num = len(X)

# if verbose < 1:
# verbose = epochs

# if C is None:
# C = [None] * len(X)

# # Set function running time recorder
# valid_func = clocker(model.valid)
# predict_func = clocker(model.predict)
# train_func = clocker(model.train)

# abduce_func = clocker(abducer.batch_abduce)

# X_bak = X
# Y_bak = Y
# C_bak = C

# # Abductive learning train process
# res = {}
# for epoch_idx in range(epochs):
# X, Y, C = block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx)
# preds_res = predict_func(X)
# abduced_Y = abduce_func(preds_res, C)
# finetune_X, finetune_Y = filter_data(X, abduced_Y)
# score, score_list = valid_func(X, Y)
# if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
# res = result_statistics(preds_res["cls"], Y, abduced_Y)
# INFO(res)

# if check_sublabel and (not is_all_sublabel_exist(finetune_Y, model.label_lists)):
# INFO("There is some sub label missing", len(finetune_Y))
# break

# if len(finetune_X) > 0:
# train_func(finetune_X, finetune_Y)#, n_epoch = 10)
# else:
# INFO("lack of data, all abduced failed", len(finetune_X))
# return res

# Script entry point: running this file directly does nothing (placeholder guard).
if __name__ == "__main__":
pass

Loading…
Cancel
Save