|
- # coding: utf-8
- #================================================================#
- # Copyright (C) 2021 Freecss All rights reserved.
- #
- # File Name :framework.py
- # Author :freecss
- # Email :karlfreecss@gmail.com
- # Created Date :2021/06/07
- # Description :
- #
- #================================================================#
-
- import pickle as pk
-
- import numpy as np
-
- from utils.plog import INFO, DEBUG, clocker
-
@clocker
def block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx):
    """Select the contiguous segment of the data that belongs to this epoch.

    The data is treated as ``len(X_bak) // sample_num`` consecutive segments
    of ``sample_num`` items (at least one segment), and ``epoch_idx`` cycles
    through them. A trailing remainder shorter than ``sample_num`` is not
    covered by any segment when more than one segment exists.

    Args:
        X_bak: Full input collection; must support ``len`` and slicing.
        Y_bak: Full label collection, sliced with the same bounds.
        C_bak: Full collection of per-sample extras, sliced the same way.
        sample_num: Number of items per segment.
        epoch_idx: Index of the current epoch.

    Returns:
        Tuple ``(X, Y, C)`` holding the selected slices.
    """
    total = len(X_bak)
    # Fewer items than one full segment still counts as a single segment.
    part_num = (total // sample_num) or 1
    seg_idx = epoch_idx % part_num
    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", total)

    window = slice(sample_num * seg_idx, sample_num * (seg_idx + 1))
    return X_bak[window], Y_bak[window], C_bak[window]
-
def get_taglist(self, Y):
    """Return the distinct labels of ``Y`` as a sorted list of string lists.

    Every label in ``Y`` is converted element by element with ``str``.
    Duplicates are removed and the result is sorted lexicographically on the
    string form (so ``"10"`` sorts before ``"2"``).

    Args:
        self: Unused; kept so existing callers keep working.
        Y: Iterable of labels, each itself an iterable of values.

    Returns:
        Sorted list of unique labels, each a ``list`` of ``str``.
    """
    # Lists are unhashable, so de-duplicate on tuples and convert back after.
    unique_labels = {tuple(str(x) for x in label) for label in Y}
    return [list(label) for label in sorted(unique_labels)]
-
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


@clocker
def result_statistics(pseudo_Y, Y, abduced_Y):
    """Log and return accuracy counters for predicted and abduced labels.

    Args:
        pseudo_Y: Predicted label sequence for each sample.
        Y: Ground-truth label sequence for each sample.
        abduced_Y: Abduced label sequence for each sample, or ``None`` where
            abduction failed.

    Returns:
        dict with the keys ``total_word``, ``accuracy_word`` (samples whose
        prediction matches the ground truth everywhere), ``total_abd_char``,
        ``accuracy_abd_char``, ``total_ori_char``, ``accuracy_ori_char`` and
        ``total_abd_failed``.

    The comparisons such as ``pseudo_y == y`` are summed element-wise, so the
    per-sample sequences are presumably NumPy arrays -- TODO confirm against
    the callers.
    """
    abd_char_num = 0
    abd_char_acc = 0
    abd_failed = 0
    word_err_num = 0

    ori_char_num = 0
    ori_char_acc = 0

    for pseudo_y, y, abduced_y in zip(pseudo_Y, Y, abduced_Y):
        if abduced_y is not None:
            abd_char_num += len(y)
            abd_char_acc += sum(abduced_y == y)
        else:
            abd_failed += 1

        ori_char_num += len(pseudo_y)
        ori_char_acc += sum(pseudo_y == y)

        # A correct prediction that abduction changed anyway is unexpected:
        # log it and keep the latest such case on disk for inspection.
        if abduced_y is not None and sum(y != pseudo_y) == 0 and sum(pseudo_y != abduced_y) > 0:
            INFO(pseudo_y, y, abduced_y)
            with open("bug.pk", "wb") as dump_file:
                pk.dump((pseudo_y, y, abduced_y), dump_file)

        if sum(pseudo_y != y) != 0:
            word_err_num += 1

    # Ratios are NaN rather than a ZeroDivisionError when there is nothing to
    # count (empty input, or every abduction failed).
    INFO("")
    # NOTE(review): this figure compares pseudo_Y with Y even though the label
    # says "Abd" -- confirm which one is intended.
    INFO("Abd word level accuracy:", 1 - _safe_ratio(word_err_num, len(pseudo_Y)))
    INFO("Abd char level accuracy:", _safe_ratio(abd_char_acc, abd_char_num))
    INFO("Ori char level accuracy:", _safe_ratio(ori_char_acc, ori_char_num))
    INFO("")

    result = {
        "total_word": len(pseudo_Y),
        "accuracy_word": len(pseudo_Y) - word_err_num,
        "total_abd_char": abd_char_num,
        "accuracy_abd_char": abd_char_acc,
        "total_ori_char": ori_char_num,
        "accuracy_ori_char": ori_char_acc,
        "total_abd_failed": abd_failed,
    }

    return result
-
@clocker
def filter_data(X, abduced_Y):
    """Keep only the samples whose abduced label is not ``None``.

    Args:
        X: Iterable of samples.
        abduced_Y: Iterable of abduced labels paired with ``X``; ``None``
            marks a sample to drop.

    Returns:
        Tuple ``(finetune_X, finetune_Y)`` of two lists containing the kept
        samples and their abduced labels, in the original order.
    """
    kept_pairs = [(sample, label) for sample, label in zip(X, abduced_Y) if label is not None]
    finetune_X = [sample for sample, _ in kept_pairs]
    finetune_Y = [label for _, label in kept_pairs]
    return finetune_X, finetune_Y
-
@clocker
def is_all_sublabel_exist(labels, std_label_list):
    """Check that every sub-label position shows as many classes as expected.

    Args:
        labels: Collection of labels; assumed to be one sequence per sample
            with one value per sub-label position -- TODO confirm.
        std_label_list: One collection of reference classes per sub-label
            position.

    Returns:
        ``False`` when ``labels`` is empty or when some position has a
        different number of distinct values than its reference collection
        (the first such position is logged); ``True`` otherwise.
    """
    if not labels:
        return False

    # Transpose so that iteration yields one sub-label position at a time.
    by_position = np.array(labels).T
    for idx, pair in enumerate(zip(std_label_list, by_position)):
        std_num = len(set(pair[0]))
        sublabel_num = len(set(pair[1]))
        if std_num == sublabel_num:
            continue
        INFO(f"sublabel {idx} should have {std_num} class, but data only have {sublabel_num} class", screen=True)
        return False
    return True
-
def pretrain(model, X, Y):
    """Placeholder for a pre-training step; currently does nothing.

    Args:
        model: Unused.
        X: Unused.
        Y: Unused.

    Returns:
        None.
    """
    return None
-
def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True):
    """Run the predict / abduce / fine-tune loop for a number of epochs.

    Each epoch takes one segment of the data, predicts on it, abduces labels
    from the predictions, validates the model on the segment, and fine-tunes
    the model on the samples whose abduction succeeded.

    Args:
        model: Object providing ``valid``, ``predict``, ``train`` and
            ``label_lists``.
        abducer: Object providing ``batch_abduce``.
        X: Training inputs.
        Y: Ground-truth labels for ``X``.
        C: Optional per-sample extras passed to the abducer; defaults to
            ``None`` for every sample.
        epochs: Number of epochs to run.
        sample_num: Segment size per epoch; ``-1`` means ``len(X)``.
        verbose: Statistics are computed every ``verbose`` epochs (and always
            on the last one); values below 1 mean "last epoch only".
        check_sublabel: When true, stop as soon as the abduced labels do not
            cover every class of every sub-label.

    Returns:
        The most recent statistics dict, or ``{}`` if none was computed.
    """
    # Resolve the "unset" sentinel values.
    if sample_num == -1:
        sample_num = len(X)
    if verbose < 1:
        verbose = epochs
    if C is None:
        C = [None] * len(X)

    # Wrap the model and abducer entry points with the running-time recorder.
    valid_func = clocker(model.valid)
    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    res = {}
    final_epoch = epochs - 1
    for epoch_idx in range(epochs):
        batch_X, batch_Y, batch_C = block_sample(X, Y, C, sample_num, epoch_idx)
        preds_res = predict_func(batch_X)
        abduced_Y = abduce_func(preds_res, batch_C)
        finetune_X, finetune_Y = filter_data(batch_X, abduced_Y)
        # The scores are not used here; the call is kept for its own effects.
        _score, _score_list = valid_func(batch_X, batch_Y)

        report_now = (epoch_idx + 1) % verbose == 0 or epoch_idx == final_epoch
        if report_now:
            res = result_statistics(preds_res["cls"], batch_Y, abduced_Y)
            INFO(res)

        if check_sublabel and not is_all_sublabel_exist(finetune_Y, model.label_lists):
            INFO("There is some sub label missing", len(finetune_Y))
            break

        if len(finetune_X) == 0:
            INFO("lack of data, all abduced failed", len(finetune_X))
            continue
        train_func(finetune_X, finetune_Y)
    return res
-
# Script entry point: nothing is executed when the module is run directly.
if __name__ == "__main__":
    pass
|