From dc1b669ab225d40711b2c1988fa37297399f36eb Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Thu, 17 Nov 2022 10:37:01 +0800 Subject: [PATCH] Update framework.py --- framework.py | 161 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 102 insertions(+), 59 deletions(-) diff --git a/framework.py b/framework.py index b41a334..27c2dba 100644 --- a/framework.py +++ b/framework.py @@ -17,17 +17,17 @@ import numpy as np from utils.plog import INFO, DEBUG, clocker @clocker -def block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx): +def block_sample(X_bak, C_bak, sample_num, epoch_idx): part_num = (len(X_bak) // sample_num) if part_num == 0: part_num = 1 seg_idx = epoch_idx % part_num INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak)) X = X_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] - Y = Y_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] + # Y = Y_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] C = C_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] - return X, Y, C + return X, C def get_taglist(self, Y): tmp = [[str(x) for x in label] for label in Y] @@ -35,49 +35,58 @@ def get_taglist(self, Y): return tmp @clocker -def result_statistics(pseudo_Y, Y, abduced_Y): - - abd_err_num = 0 - abd_char_num = 0 - abd_char_acc = 0 - abd_failed = 0 - word_err_num = 0 +def result_statistics(C, pseudo_Y, logic_forward): + abl_acc = 0 + for tidx, (c, pseudo_y) in enumerate(zip(C, pseudo_Y)): + if(logic_forward(pseudo_y) == c): + abl_acc += 1 + - ori_char_num = 0 - ori_char_acc = 0 + return abl_acc / len(C) - for tidx, (pseudo_y, y, abduced_y) in enumerate(zip(pseudo_Y, Y, abduced_Y)): - pseudo_y = pseudo_y - if sum(abduced_y != y) != 0: - abd_err_num += 1 - if abduced_y is not None: - abd_char_num += len(y) - abd_char_acc += sum(abduced_y == y) - else: - abd_failed += 1 +# def result_statistics(pseudo_Y, Y, abduced_Y): - ori_char_num += len(pseudo_y) - ori_char_acc += sum(pseudo_y == y) +# abd_err_num = 0 +# abd_char_num = 0 +# abd_char_acc = 0 +# abd_failed = 0 +# word_err_num = 0 + +# ori_char_num = 0 +# ori_char_acc = 0 + +# for tidx, (pseudo_y, y, abduced_y) in enumerate(zip(pseudo_Y, Y, abduced_Y)): +# pseudo_y = pseudo_y +# if sum(abduced_y != y) != 0: +# abd_err_num += 1 +# if abduced_y is not None: +# abd_char_num += len(y) +# abd_char_acc += sum(abduced_y == y) +# else: +# abd_failed += 1 + +# ori_char_num += len(pseudo_y) +# ori_char_acc += sum(pseudo_y == y) - if abduced_y is not None and sum(y != pseudo_y) == 0 and sum(pseudo_y != abduced_y) > 0: - INFO(pseudo_y, y, abduced_y) - pk.dump((pseudo_y, y, abduced_y), open("bug.pk", "wb")) +# if abduced_y is not None and sum(y != pseudo_y) == 0 and sum(pseudo_y != abduced_y) > 0: +# INFO(pseudo_y, y, abduced_y) +# pk.dump((pseudo_y, y, abduced_y), open("bug.pk", "wb")) - if sum(pseudo_y != y) != 0: - word_err_num += 1 +# if sum(pseudo_y != y) != 0: +# word_err_num += 1 - INFO("") - INFO("Abd word level accuracy:", 1 - word_err_num / len(pseudo_Y)) - INFO("Abd char level accuracy:", abd_char_acc / abd_char_num) - INFO("Ori char level accuracy:", ori_char_acc / ori_char_num) - INFO("") +# INFO("") +# INFO("Abd word level accuracy:", 1 - word_err_num / len(pseudo_Y)) +# INFO("Abd char level accuracy:", abd_char_acc / abd_char_num) +# INFO("Ori char level accuracy:", ori_char_acc / ori_char_num) +# INFO("") - result = {"total_word" : len(pseudo_Y), "accuracy_word" : len(pseudo_Y) - word_err_num, - "total_abd_char": abd_char_num, "accuracy_abd_char" : abd_char_acc, - "total_ori_char": ori_char_num, "accuracy_ori_char" : ori_char_acc, - "total_abd_failed": abd_failed} +# result = {"total_word" : len(pseudo_Y), "accuracy_word" : len(pseudo_Y) - word_err_num, +# "total_abd_char": abd_char_num, "accuracy_abd_char" : abd_char_acc, +# "total_ori_char": ori_char_num, "accuracy_ori_char" : ori_char_acc, +# "total_abd_failed": abd_failed} - return result +# return result @clocker def filter_data(X, abduced_Y): @@ -106,7 +115,7 @@ def is_all_sublabel_exist(labels, std_label_list): def pretrain(model, X, Y): pass -def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True): +def train(model, abducer, X, C, logic_forward, epochs = 10, sample_num = -1, verbose = -1): # Set default parameters if sample_num == -1: sample_num = len(X) @@ -114,42 +123,76 @@ def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose if verbose < 1: verbose = epochs - if C is None: - C = [None] * len(X) - - # Set function running time recorder - valid_func = clocker(model.valid) predict_func = clocker(model.predict) train_func = clocker(model.train) abduce_func = clocker(abducer.batch_abduce) X_bak = X - Y_bak = Y C_bak = C - + epochs = 50 # Abductive learning train process - res = {} for epoch_idx in range(epochs): - X, Y, C = block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx) + X, C = block_sample(X_bak, C_bak, sample_num, epoch_idx) preds_res = predict_func(X) + + abl_acc = result_statistics(C, preds_res['cls'], logic_forward) + print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc) + abduced_Y = abduce_func(preds_res, C) - finetune_X, finetune_Y = filter_data(X, abduced_Y) - score, score_list = valid_func(X, Y) - if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1): - res = result_statistics(preds_res["cls"], Y, abduced_Y) - INFO(res) - - if check_sublabel and (not is_all_sublabel_exist(finetune_Y, model.label_lists)): - INFO("There is some sub label missing", len(finetune_Y)) - break + + finetune_X, finetune_Y = filter_data(X, abduced_Y) + if len(finetune_X) > 0: train_func(finetune_X, finetune_Y)#, n_epoch = 10) else: INFO("lack of data, all abduced failed", len(finetune_X)) - return res - #return ret + return abl_acc + +# def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True): +# # Set default parameters +# if sample_num == -1: +# sample_num = len(X) + +# if verbose < 1: +# verbose = epochs + +# if C is None: +# C = [None] * len(X) + +# # Set function running time recorder +# valid_func = clocker(model.valid) +# predict_func = clocker(model.predict) +# train_func = clocker(model.train) + +# abduce_func = clocker(abducer.batch_abduce) + +# X_bak = X +# Y_bak = Y +# C_bak = C + +# # Abductive learning train process +# res = {} +# for epoch_idx in range(epochs): +# X, Y, C = block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx) +# preds_res = predict_func(X) +# abduced_Y = abduce_func(preds_res, C) +# finetune_X, finetune_Y = filter_data(X, abduced_Y) +# score, score_list = valid_func(X, Y) +# if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1): +# res = result_statistics(preds_res["cls"], Y, abduced_Y) +# INFO(res) + +# if check_sublabel and (not is_all_sublabel_exist(finetune_Y, model.label_lists)): +# INFO("There is some sub label missing", len(finetune_Y)) +# break + +# if len(finetune_X) > 0: +# train_func(finetune_X, finetune_Y)#, n_epoch = 10) +# else: +# INFO("lack of data, all abduced failed", len(finetune_X)) +# return res if __name__ == "__main__": pass