Browse Source

Modify filter_data

pull/3/head
troyyyyy 3 years ago
parent
commit
ee25a2c74e
1 changed files with 3 additions and 4 deletions
  1. +3
    -4
      abl/framework_hed.py

+ 3
- 4
abl/framework_hed.py View File

@@ -52,9 +52,9 @@ def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
def filter_data(X, abduced_Z):
finetune_Z = []
finetune_X = []
for abduced_x, abduced_z in zip(X, abduced_Z):
if abduced_z is not []:
finetune_X.append(abduced_x)
for x, abduced_z in zip(X, abduced_Z):
if len(abduced_z) > 0:
finetune_X.append(x)
finetune_Z.append(abduced_z)
return finetune_X, finetune_Z

@@ -84,7 +84,6 @@ def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbo
for epoch_idx in range(epochs):
X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx)
preds_res = predict_func(X)
# input()
abduced_Z = abduce_func(preds_res, Y)

if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):


Loading…
Cancel
Save