Browse Source

Update several files

pull/3/head
troyyyyy 2 years ago
parent
commit
24ce8ff87e
3 changed files with 5 additions and 4 deletions
  1. +2
    -2
      abl/abducer/abducer_base.py
  2. +1
    -1
      abl/framework_hed.py
  3. +2
    -1
      abl/utils/utils.py

+ 2
- 2
abl/abducer/abducer_base.py View File

@@ -68,8 +68,8 @@ class AbducerBase(abc.ABC):
if len(candidates) > 0:
score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates))
else:
score += len(pred_res)
return -self._zoopt_score_multiple(pred_res, key, sol.get_x())
score += len(pred_res[idx])
return score
def _constrain_address_num(self, solution, max_address_num):
x = solution.get_x()


+ 1
- 1
abl/framework_hed.py View File

@@ -291,7 +291,7 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len
INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc))
# decide next course or restart
if true_consist_rule_acc > 0.9 and false_consist_rule_acc < 0.1:
if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1:
torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
break
else:


+ 2
- 1
abl/utils/utils.py View File

@@ -19,7 +19,8 @@ def reform_idx(flatten_pred_res, save_pred_res):
return re

def hamming_dist(A, B):
B = np.array(B)
A = np.array(A, dtype='<U')
B = np.array(B, dtype='<U')
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)



Loading…
Cancel
Save