Browse Source

Update framework.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
35bdb8d01f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 48 additions and 35 deletions
  1. +48
    -35
      framework.py

+ 48
- 35
framework.py View File

@@ -17,33 +17,41 @@ import numpy as np
from utils.plog import INFO, DEBUG, clocker

@clocker
def block_sample(X, Z, Y, sample_num, epoch_idx):
    """Return this epoch's block (slice) of the parallel lists X, Z, Y.

    The data is viewed as `part_num = len(X) // sample_num` consecutive
    blocks of `sample_num` items each; block `epoch_idx % part_num` is
    selected, so successive epochs cycle through the whole data set.
    When the data is smaller than one block, the whole data is returned.

    :param X: list of example inputs
    :param Z: list of per-example sub-label sequences (parallel to X)
    :param Y: list of per-example final labels (parallel to X)
    :param sample_num: number of examples per block (must be > 0)
    :param epoch_idx: current epoch index; drives which block is chosen
    :return: (X_block, Z_block, Y_block) slices of equal length
    """
    part_num = len(X) // sample_num
    if part_num == 0:
        # Fewer than sample_num examples: treat everything as one block.
        part_num = 1
    seg_idx = epoch_idx % part_num
    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
    lo = sample_num * seg_idx
    hi = sample_num * (seg_idx + 1)
    return X[lo:hi], Z[lo:hi], Y[lo:hi]

def get_taglist(self, Z):
    """Return the sorted list of distinct sub-label sequences seen in Z.

    Each sequence is converted to a tuple of strings so it is hashable.
    BUGFIX: the previous version built a list of *lists* and passed it
    to set(), which raises ``TypeError: unhashable type: 'list'``.

    :param Z: iterable of label sequences (each an iterable of sub-labels)
    :return: sorted list of unique tuples of stringified sub-labels
    """
    distinct = {tuple(str(x) for x in label) for label in Z}
    return sorted(distinct)

@clocker
def get_abl_acc(Y, pseudo_Z, logic_forward):
    """Abductive accuracy: the fraction of examples whose predicted
    sub-label sequence maps to the correct final label.

    :param Y: list of ground-truth final labels
    :param pseudo_Z: list of predicted sub-label sequences (parallel to Y)
    :param logic_forward: callable mapping a sub-label sequence to a
        final label (the knowledge base's forward inference)
    :return: accuracy in [0, 1]
    :raises ZeroDivisionError: if Y is empty
    """
    abl_acc = 0
    for y, pseudo_z in zip(Y, pseudo_Z):
        if logic_forward(pseudo_z) == y:
            abl_acc += 1
    return abl_acc / len(Y)
def get_char_acc(Z, pseudo_Z):
    """Per-position ("character"-level) accuracy of pseudo_Z against Z.

    Counts, over all examples, the positions where the predicted
    sub-label equals the ground-truth sub-label, and divides by the
    total number of ground-truth positions.

    NOTE(review): assumes each pseudo_Z[i] is at least as long as Z[i];
    a shorter prediction would raise IndexError — confirm with callers.

    :param Z: list of ground-truth sub-label sequences
    :param pseudo_Z: list of predicted sub-label sequences (parallel to Z)
    :return: accuracy in [0, 1]
    :raises ZeroDivisionError: if Z is empty or every sequence is empty
    """
    char_acc = 0
    char_num = 0
    for z, pseudo_z in zip(Z, pseudo_Z):
        char_num += len(z)
        for zidx in range(len(z)):
            if z[zidx] == pseudo_z[zidx]:
                char_acc += 1
    return char_acc / char_num

# def result_statistics(pseudo_Y, Y, abduced_Y):

# abd_err_num = 0
@@ -89,14 +97,14 @@ def result_statistics(C, pseudo_Y, logic_forward):
# return result

@clocker
def filter_data(X, abduced_Z):
    """Keep only the examples whose abduction succeeded.

    An entry of ``None`` in abduced_Z marks a failed abduction; the
    corresponding example is dropped from both lists.

    :param X: list of example inputs
    :param abduced_Z: list of abduced sub-label sequences or None
        (parallel to X)
    :return: (finetune_X, finetune_Z) — equal-length filtered lists
    """
    finetune_X = []
    finetune_Z = []
    for abduced_x, abduced_z in zip(X, abduced_Z):
        if abduced_z is not None:
            finetune_X.append(abduced_x)
            finetune_Z.append(abduced_z)
    return finetune_X, finetune_Z

@clocker
def is_all_sublabel_exist(labels, std_label_list):
@@ -112,42 +120,47 @@ def is_all_sublabel_exist(labels, std_label_list):
return False
return True

def pretrain(model, X, Z):
    """Placeholder for supervised pre-training of `model` on (X, Z).

    Currently a no-op; returns None. Kept so callers can invoke it
    unconditionally before the abductive training loop.
    """
    pass

def train(model, abducer, X, Z, Y, epochs=10, sample_num=-1, verbose=-1):
    """Abductive-learning training loop.

    Each epoch: sample a block of data, predict sub-label sequences with
    `model`, abduce corrected sequences from the final labels Y via
    `abducer`, then fine-tune the model on the successfully abduced
    examples.

    :param model: object exposing ``predict(X) -> dict`` (with key
        ``'cls'`` holding predicted sub-label sequences) and
        ``train(X, Z)``
    :param abducer: object exposing ``batch_abduce(preds, Y)`` and
        ``kb.logic_forward``
    :param X: list of example inputs
    :param Z: list of ground-truth sub-label sequences, or None when
        unavailable (character accuracy is then not reported)
    :param Y: list of final labels (targets of ``kb.logic_forward``)
    :param epochs: number of training epochs
    :param sample_num: block size sampled per epoch (-1 = all data)
    :param verbose: kept for interface compatibility; defaulted only
    :return: abductive accuracy of the last epoch
    """
    # Set default parameters
    if sample_num == -1:
        sample_num = len(X)
    if verbose < 1:
        verbose = epochs

    # Character-level accuracy can only be reported when ground-truth
    # sub-labels Z are supplied.  (Was `if Z == None`; use `is None`.)
    char_acc_flag = Z is not None
    if not char_acc_flag:
        Z = [None] * len(X)

    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    abl_acc = 0
    # Abductive learning train process
    for epoch_idx in range(epochs):
        # BUGFIX: sample from the FULL data set each epoch.  The previous
        # revision did `X, Z, Y = block_sample(X, Z, Y, ...)`, overwriting
        # the data with the sampled block so later epochs drew from an
        # ever-shrinking slice.
        X_seg, Z_seg, Y_seg = block_sample(X, Z, Y, sample_num, epoch_idx)
        preds_res = predict_func(X_seg)
        abduced_Z = abduce_func(preds_res, Y_seg)

        abl_acc = get_abl_acc(Y_seg, preds_res['cls'], abducer.kb.logic_forward)
        if char_acc_flag:
            # BUGFIX: was `if(not char_acc_flag)`, which computed character
            # accuracy exactly when Z was unavailable and crashed on the
            # [None] placeholders.
            # NOTE(review): assumes batch_abduce returns a full sequence for
            # every example here; None entries would crash get_char_acc —
            # confirm against the abducer implementation.
            ori_char_acc = get_char_acc(Z_seg, preds_res['cls'])
            abd_char_acc = get_char_acc(abduced_Z, preds_res['cls'])
            print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc)
        else:
            print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc)

        finetune_X, finetune_Z = filter_data(X_seg, abduced_Z)
        if len(finetune_X) > 0:
            train_func(finetune_X, finetune_Z)
        else:
            INFO("lack of data, all abduced failed", len(finetune_X))
    return abl_acc

# def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True):


Loading…
Cancel
Save