diff --git a/abl/framework_hed.py b/abl/framework_hed.py index 76c57fa..c8d1e2a 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -52,9 +52,9 @@ def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): def filter_data(X, abduced_Z): finetune_Z = [] finetune_X = [] - for abduced_x, abduced_z in zip(X, abduced_Z): - if abduced_z is not []: - finetune_X.append(abduced_x) + for x, abduced_z in zip(X, abduced_Z): + if len(abduced_z) > 0: + finetune_X.append(x) finetune_Z.append(abduced_z) return finetune_X, finetune_Z @@ -84,7 +84,6 @@ def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbo for epoch_idx in range(epochs): X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx) preds_res = predict_func(X) - # input() abduced_Z = abduce_func(preds_res, Y) if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):