From 24ce8ff87ef318d57338cbcdb8da5da58b29a11a Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 8 Mar 2023 10:18:08 +0000 Subject: [PATCH] Update several files --- abl/abducer/abducer_base.py | 4 ++-- abl/framework_hed.py | 2 +- abl/utils/utils.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index a70db14..61cab25 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -68,8 +68,8 @@ class AbducerBase(abc.ABC): if len(candidates) > 0: score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates)) else: - score += len(pred_res) - return -self._zoopt_score_multiple(pred_res, key, sol.get_x()) + score += len(pred_res[idx]) + return score def _constrain_address_num(self, solution, max_address_num): x = solution.get_x() diff --git a/abl/framework_hed.py b/abl/framework_hed.py index d339909..ab88318 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -291,7 +291,7 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc)) # decide next course or restart - if true_consist_rule_acc > 0.9 and false_consist_rule_acc < 0.1: + if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1: torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) break else: diff --git a/abl/utils/utils.py b/abl/utils/utils.py index d986065..d5209a6 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -19,7 +19,8 @@ def reform_idx(flatten_pred_res, save_pred_res): return re def hamming_dist(A, B): - B = np.array(B) + A = np.array(A, dtype='