|
|
|
@@ -291,7 +291,7 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len |
|
|
|
|
|
|
|
INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc)) |
|
|
|
# decide next course or restart |
|
|
|
if true_consist_rule_acc > 0.9 and false_consist_rule_acc < 0.1: |
|
|
|
if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1: |
|
|
|
torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) |
|
|
|
break |
|
|
|
else: |
|
|
|
|