From 54492a5a37effaa56554cd8ec95aab32b3093fb3 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 22 Feb 2023 14:58:27 +0800 Subject: [PATCH] Update framework_hed.py --- framework_hed.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/framework_hed.py b/framework_hed.py index e464ed0..3476462 100644 --- a/framework_hed.py +++ b/framework_hed.py @@ -122,8 +122,8 @@ def hed_pretrain(kb, cls, recorder): def get_char_acc(model, X, consistent_pred_res, mapping): original_pred_res = model.predict(X)['cls'] pred_res = flatten(mapping_res(original_pred_res, mapping)) - INFO('Current model\'s output:', pred_res) - INFO('Abduced labels: ', flatten(consistent_pred_res)) + INFO('Current model\'s output: ', pred_res) + INFO('Abduced labels: ', flatten(consistent_pred_res)) assert len(pred_res) == len(flatten(consistent_pred_res)) return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res) @@ -218,7 +218,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, def get_mlp_vector(model, abducer, mapping, X, rules): original_pred_res = model.predict([X])['cls'] - pred_res = mapping_res(original_pred_res, mapping) + pred_res = flatten(mapping_res(original_pred_res, mapping)) vector = [] for rule in rules: if abducer.kb.consist_rule(pred_res, rule): @@ -267,7 +267,7 @@ def validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, t optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=60) + mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100) mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True) loss = mlp_model.fit(mlp_train_data_loader) @@ -329,7 +329,7 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len best_accuracy = validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false) INFO('best_accuracy is %f\n' %(best_accuracy)) # decide next course or restart - if best_accuracy > 0.85: + if best_accuracy > 0.88: torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) break else: @@ -378,7 +378,7 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len= optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=60) + mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100) mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True) loss = mlp_model.fit(mlp_train_data_loader)