Browse Source

zoopt_score in ABL-HED

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
7bccd9ca6e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 4 deletions
  1. +43
    -4
      abl/abducer/abducer_base.py

+ 43
- 4
abl/abducer/abducer_base.py View File

@@ -52,13 +52,52 @@ class AbducerBase(abc.ABC):
return len(pred_res)
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
all_address_flag = reform_idx(sol.get_x(), pred_res)
if nested_length(pred_res) == 1:
return self._zoopt_address_score_single(all_address_flag, pred_res, pred_res_prob, key)
return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key)
else:
all_address_flag = reform_idx(sol.get_x(), pred_res)
lefted_idx = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idx:
temp_idx = []
temp_idx.append(lefted_idx.pop(0))
max_candidate_idx = []
found = False
for idx in range(-1, len(pred_res)):
if (not idx in temp_idx) and (idx >= 0):
temp_idx.append(idx)
pred = []
k = []
address_flag = []
for idx in temp_idx:
pred.append(pred_res[idx])
k.append(key[idx])
address_flag += list(all_address_flag[idx])
address_idx = np.where(np.array(address_flag) != 0)[0]
candidate = self.address_by_idx(pred, k, address_idx)
if len(candidate) == 0:
if len(temp_idx) > 1:
temp_idx.pop()
else:
if len(temp_idx) > len(max_candidate_idx):
found = True
max_candidate_idx = temp_idx.copy()
removed = [i for i in lefted_idx if i in max_candidate_idx]
if found:
candidate_size.append(len(removed) + 1)
lefted_idx = [i for i in lefted_idx if i not in max_candidate_idx]
candidate_size.sort()
score = 0
for idx in range(nested_length(pred_res)):
score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]])
import math
for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]

# score = 0
# for idx in range(nested_length(pred_res)):
# score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]])
return score
def _constrain_address_num(self, solution, max_address_num):


Loading…
Cancel
Save