Browse Source

Update framework_hed.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
54492a5a37
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      framework_hed.py

+ 6
- 6
framework_hed.py View File

@@ -122,8 +122,8 @@ def hed_pretrain(kb, cls, recorder):
def get_char_acc(model, X, consistent_pred_res, mapping):
original_pred_res = model.predict(X)['cls']
pred_res = flatten(mapping_res(original_pred_res, mapping))
INFO('Current model\'s output:', pred_res)
INFO('Abduced labels: ', flatten(consistent_pred_res))
INFO('Current model\'s output: ', pred_res)
INFO('Abduced labels: ', flatten(consistent_pred_res))
assert len(pred_res) == len(flatten(consistent_pred_res))
return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res)

@@ -218,7 +218,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule,

def get_mlp_vector(model, abducer, mapping, X, rules):
original_pred_res = model.predict([X])['cls']
pred_res = mapping_res(original_pred_res, mapping)
pred_res = flatten(mapping_res(original_pred_res, mapping))
vector = []
for rule in rules:
if abducer.kb.consist_rule(pred_res, rule):
@@ -267,7 +267,7 @@ def validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, t
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=60)
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100)
mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True)
loss = mlp_model.fit(mlp_train_data_loader)
@@ -329,7 +329,7 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len
best_accuracy = validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false)
INFO('best_accuracy is %f\n' %(best_accuracy))
# decide next course or restart
if best_accuracy > 0.85:
if best_accuracy > 0.88:
torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
break
else:
@@ -378,7 +378,7 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=60)
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100)
mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True)
loss = mlp_model.fit(mlp_train_data_loader)


Loading…
Cancel
Save