|
- # coding: utf-8
- #================================================================#
- # Copyright (C) 2021 Freecss All rights reserved.
- #
- # File Name :framework.py
- # Author :freecss
- # Email :karlfreecss@gmail.com
- # Created Date :2021/06/07
- # Description :
- #
- #================================================================#
-
- import pickle as pk
-
- import numpy as np
-
- from utils.plog import INFO, DEBUG, clocker
-
- def block_sample(X, Z, Y, sample_num, epoch_idx):
- part_num = (len(X) // sample_num)
- if part_num == 0:
- part_num = 1
- seg_idx = epoch_idx % part_num
- INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
- X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)]
- Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)]
- Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)]
-
- return X, Z, Y
-
- def get_taglist(self, Z):
- tmp = [[str(x) for x in label] for label in Z]
- tmp = sorted(list(set(tmp)))
- return tmp
-
- def get_abl_acc(pred_Z, Y, logic_forward):
- abl_acc = 0
- for pred_z, y in zip(pred_Z, Y):
- if(logic_forward(pred_z) == y):
- abl_acc += 1
- return abl_acc / len(Y)
-
- def get_char_acc(Z, pred_Z):
- char_acc = 0
- char_num = 0
- for pred_z, z in zip(pred_Z, Z):
- char_num += len(z)
- for zidx in range(len(z)):
- if(pred_z[zidx] == z[zidx]):
- char_acc += 1
- return char_acc / char_num
-
-
- def filter_data(X, abduced_Z):
- finetune_Z = []
- finetune_X = []
- for abduced_x, abduced_z in zip(X, abduced_Z):
- if abduced_z is not []:
- finetune_X.append(abduced_x)
- finetune_Z.append(abduced_z)
- return finetune_X, finetune_Z
-
- def pretrain(model, X, Z):
- pass
-
- def train(model, abducer, train_data, test_data, epochs = 5, sample_num = -1, verbose = -1):
- X, Z, Y = train_data
- test_X, test_Z, test_Y = test_data
-
- # Set default parameters
- if sample_num == -1:
- sample_num = len(X)
-
- if verbose < 1:
- verbose = epochs
-
- char_acc_flag = 1
- if Z == None:
- char_acc_flag = 0
- Z = [None] * len(X)
-
- predict_func = clocker(model.predict)
- train_func = clocker(model.train)
- abduce_func = clocker(abducer.batch_abduce)
-
- # Abductive learning train process
- for epoch_idx in range(epochs):
- X, Z, Y = block_sample(X, Z, Y, sample_num, epoch_idx)
- preds_res = predict_func(X)
- abduced_Z = abduce_func(preds_res, Y)
-
- abl_acc = get_abl_acc(preds_res['cls'], Y, abducer.kb.logic_forward)
- if(char_acc_flag):
- ori_char_acc = get_char_acc(preds_res['cls'], Z)
- abd_char_acc = get_char_acc(preds_res['cls'], abduced_Z)
- INFO('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc)
- else:
- INFO('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc)
-
- finetune_X, finetune_Z = filter_data(X, abduced_Z)
- if len(finetune_X) > 0:
- # model.valid(finetune_X, finetune_Z)
- train_func(finetune_X, finetune_Z)
- else:
- INFO("lack of data, all abduced failed", len(finetune_X))
-
- return abl_acc
-
- if __name__ == "__main__":
- pass
|