|
|
|
@@ -64,7 +64,7 @@ class AbducerBase(abc.ABC): |
|
|
|
score = 0 |
|
|
|
for idx in range(len(pred_res)): |
|
|
|
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] |
|
|
|
candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) |
|
|
|
candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) |
|
|
|
if len(candidate) > 0: |
|
|
|
score += 1 |
|
|
|
return score |
|
|
|
@@ -72,7 +72,7 @@ class AbducerBase(abc.ABC): |
|
|
|
def _zoopt_address_score(self, pred_res, key, sol): |
|
|
|
if not self.multiple_predictions: |
|
|
|
address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] |
|
|
|
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) |
|
|
|
candidates = self.address_by_idx(pred_res, key, address_idx) |
|
|
|
return 1 if len(candidates) > 0 else 0 |
|
|
|
else: |
|
|
|
return self._zoopt_score_multiple(pred_res, key, sol.get_x()) |
|
|
|
@@ -115,6 +115,9 @@ class AbducerBase(abc.ABC): |
|
|
|
key = tuple(key) |
|
|
|
self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num |
|
|
|
self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates |
|
|
|
|
|
|
|
def address_by_idx(self, pred_res, key, address_idx): |
|
|
|
return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) |
|
|
|
|
|
|
|
def abduce(self, data, max_address_num=-1, require_more_address=0): |
|
|
|
pred_res, pred_res_prob, key = data |
|
|
|
@@ -129,7 +132,7 @@ class AbducerBase(abc.ABC): |
|
|
|
if self.zoopt: |
|
|
|
solution = self.zoopt_get_solution(pred_res, key, max_address_num) |
|
|
|
address_idx = [idx for idx, i in enumerate(solution) if i != 0] |
|
|
|
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) |
|
|
|
candidates = self.address_by_idx(pred_res, key, address_idx) |
|
|
|
address_num = int(solution.sum()) |
|
|
|
min_address_num = address_num |
|
|
|
else: |
|
|
|
@@ -156,7 +159,6 @@ class AbducerBase(abc.ABC): |
|
|
|
def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): |
|
|
|
return self.batch_abduce(Z, Y, max_address_num, require_more_address) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
from kb import add_KB, prolog_KB, HWF_KB |
|
|
|
|
|
|
|
|