|
|
|
@@ -52,9 +52,9 @@ def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): |
|
|
|
def filter_data(X, abduced_Z): |
|
|
|
finetune_Z = [] |
|
|
|
finetune_X = [] |
|
|
|
for abduced_x, abduced_z in zip(X, abduced_Z): |
|
|
|
if abduced_z is not []: |
|
|
|
finetune_X.append(abduced_x) |
|
|
|
for x, abduced_z in zip(X, abduced_Z): |
|
|
|
if len(abduced_z) > 0: |
|
|
|
finetune_X.append(x) |
|
|
|
finetune_Z.append(abduced_z) |
|
|
|
return finetune_X, finetune_Z |
|
|
|
|
|
|
|
@@ -84,7 +84,6 @@ def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbo |
|
|
|
for epoch_idx in range(epochs): |
|
|
|
X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx) |
|
|
|
preds_res = predict_func(X) |
|
|
|
# input() |
|
|
|
abduced_Z = abduce_func(preds_res, Y) |
|
|
|
|
|
|
|
if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1): |
|
|
|
|