From 35bdb8d01fe02bbfc49f58df7ceffedd65744093 Mon Sep 17 00:00:00 2001
From: troyyyyy <49091847+troyyyyy@users.noreply.github.com>
Date: Fri, 18 Nov 2022 15:52:01 +0800
Subject: [PATCH] Update framework.py

---
 framework.py | 87 ++++++++++++++++++++++++++++++----------------------
 1 file changed, 51 insertions(+), 36 deletions(-)

diff --git a/framework.py b/framework.py
index 92a20b5..5479316 100644
--- a/framework.py
+++ b/framework.py
@@ -17,33 +17,41 @@ import numpy as np
 from utils.plog import INFO, DEBUG, clocker
 
 @clocker
-def block_sample(X_bak, C_bak, sample_num, epoch_idx):
-    part_num = (len(X_bak) // sample_num)
+def block_sample(X, Z, Y, sample_num, epoch_idx):
+    part_num = (len(X) // sample_num)
     if part_num == 0:
         part_num = 1
     seg_idx = epoch_idx % part_num
-    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak))
-    X = X_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)]
-    # Y = Y_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)]
-    C = C_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)]
+    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
+    X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)]
+    Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)]
+    Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)]
 
-    return X, C
+    return X, Z, Y
 
-def get_taglist(self, Y):
-    tmp = [[str(x) for x in label] for label in Y]
+def get_taglist(self, Z):
+    tmp = [[str(x) for x in label] for label in Z]
     tmp = sorted(list(set(tmp)))
     return tmp
 
-@clocker
-def result_statistics(C, pseudo_Y, logic_forward):
+def get_abl_acc(Y, pseudo_Z, logic_forward):
     abl_acc = 0
-    for tidx, (c, pseudo_y) in enumerate(zip(C, pseudo_Y)):
-        if(logic_forward(pseudo_y) == c):
-            abl_acc += 1
-
-    return abl_acc / len(C)
-
+    for y, pseudo_z in zip(Y, pseudo_Z):
+        if(logic_forward(pseudo_z) == y):
+            abl_acc += 1
+    return abl_acc / len(Y)
+
+def get_char_acc(Z, pseudo_Z):
+    char_acc = 0
+    char_num = 0
+    for z, pseudo_z in zip(Z, pseudo_Z):
+        char_num += len(z)
+        for zidx in range(len(z)):
+            if(z[zidx] == pseudo_z[zidx]):
+                char_acc += 1
+    return char_acc / char_num
+
 # def result_statistics(pseudo_Y, Y, abduced_Y):
 #     abd_err_num = 0
@@ -89,14 +97,14 @@ def result_statistics(C, pseudo_Y, logic_forward):
 #     return result
 
 @clocker
-def filter_data(X, abduced_Y):
-    finetune_Y = []
+def filter_data(X, abduced_Z):
+    finetune_Z = []
     finetune_X = []
-    for abduced_x, abduced_y in zip(X, abduced_Y):
-        if abduced_y is not None:
+    for abduced_x, abduced_z in zip(X, abduced_Z):
+        if abduced_z is not None:
             finetune_X.append(abduced_x)
-            finetune_Y.append(abduced_y)
-    return finetune_X, finetune_Y
+            finetune_Z.append(abduced_z)
+    return finetune_X, finetune_Z
 
 @clocker
 def is_all_sublabel_exist(labels, std_label_list):
@@ -112,42 +120,49 @@ def is_all_sublabel_exist(labels, std_label_list):
             return False
     return True
 
-def pretrain(model, X, Y):
+def pretrain(model, X, Z):
     pass
 
-def train(model, abducer, X, C, epochs = 10, sample_num = -1, verbose = -1):
+def train(model, abducer, X, Z, Y, epochs = 10, sample_num = -1, verbose = -1):
     # Set default parameters
     if sample_num == -1:
         sample_num = len(X)
     if verbose < 1:
         verbose = epochs
+
+    char_acc_flag = 1
+    if Z is None:
+        char_acc_flag = 0
+        Z = [None] * len(X)
 
     predict_func = clocker(model.predict)
     train_func = clocker(model.train)
     abduce_func = clocker(abducer.batch_abduce)
 
-    X_bak = X
-    C_bak = C
     epochs = 50
+    # Abductive learning train process
     for epoch_idx in range(epochs):
-        X, C = block_sample(X_bak, C_bak, sample_num, epoch_idx)
+        X, Z, Y = block_sample(X, Z, Y, sample_num, epoch_idx)
         preds_res = predict_func(X)
-
-        abl_acc = result_statistics(C, preds_res['cls'], abducer.kb.logic_forward)
-        print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc)
-
-        abduced_Y = abduce_func(preds_res, C)
-
-        finetune_X, finetune_Y = filter_data(X, abduced_Y)
+        abduced_Z = abduce_func(preds_res, Y)
+        abl_acc = get_abl_acc(Y, preds_res['cls'], abducer.kb.logic_forward)
+        if(char_acc_flag):
+            ori_char_acc = get_char_acc(Z, preds_res['cls'])
+            abd_char_acc = get_char_acc(abduced_Z, preds_res['cls'])
+            print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc)
+        else:
+            print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc)
+        finetune_X, finetune_Z = filter_data(X, abduced_Z)
         if len(finetune_X) > 0:
-            train_func(finetune_X, finetune_Y)#, n_epoch = 10)
+            train_func(finetune_X, finetune_Z)
         else:
            INFO("lack of data, all abduced failed", len(finetune_X))
+    return abl_acc
 
 # def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True):
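
For reference, below is a minimal, hypothetical sketch of how the revised train() signature could be driven. The ToyModel, ToyKB, and ToyAbducer stubs do not exist in the repository; they only illustrate the interfaces the patched train() relies on (model.predict returning a dict with a 'cls' key, model.train, abducer.batch_abduce, and abducer.kb.logic_forward) and the new (X, Z, Y) argument order.

    # Hypothetical stubs; names and data are for illustration only.
    class ToyModel:
        def predict(self, X):
            # Pretend every example is read as the single symbol 0;
            # a real model returns per-symbol predictions under the 'cls' key.
            return {'cls': [[0] for _ in X]}

        def train(self, X, Z):
            # A real model would fit on the abduced pseudo-labels Z here.
            pass

    class ToyKB:
        def logic_forward(self, z):
            # Toy knowledge base: the logic target Y is the sum of the symbols in z.
            return sum(z)

    class ToyAbducer:
        def __init__(self):
            self.kb = ToyKB()

        def batch_abduce(self, preds_res, Y):
            # Return the predictions unchanged; a real abducer would revise them
            # so that logic_forward(pseudo_z) == y.
            return preds_res['cls']

    X = [[0.1], [0.2], [0.3]]   # raw inputs
    Z = [[0], [0], [0]]         # symbol-level labels (pass None if unavailable)
    Y = [0, 0, 0]               # logic targets, Y[i] == logic_forward(Z[i])

    # from framework import train   # assumes the patched module above is importable
    # abl_acc = train(ToyModel(), ToyAbducer(), X, Z, Y, sample_num=-1)

Passing Z as None exercises the char_acc_flag branch, in which case only the abductive accuracy is reported per epoch.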