Browse Source

Rearrange address_by_idx to abducer_base

pull/3/head
troyyyyy 3 years ago
parent
commit
b75b2897b9
2 changed files with 7 additions and 5 deletions
  1. +6
    -4
      abl/abducer/abducer_base.py
  2. +1
    -1
      abl/framework_hed.py

+ 6
- 4
abl/abducer/abducer_base.py View File

@@ -64,7 +64,7 @@ class AbducerBase(abc.ABC):
score = 0
for idx in range(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True)
candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
if len(candidate) > 0:
score += 1
return score
@@ -72,7 +72,7 @@ class AbducerBase(abc.ABC):
def _zoopt_address_score(self, pred_res, key, sol):
if not self.multiple_predictions:
address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0]
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
candidates = self.address_by_idx(pred_res, key, address_idx)
return 1 if len(candidates) > 0 else 0
else:
return self._zoopt_score_multiple(pred_res, key, sol.get_x())
@@ -115,6 +115,9 @@ class AbducerBase(abc.ABC):
key = tuple(key)
self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
def address_by_idx(self, pred_res, key, address_idx):
return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)

def abduce(self, data, max_address_num=-1, require_more_address=0):
pred_res, pred_res_prob, key = data
@@ -129,7 +132,7 @@ class AbducerBase(abc.ABC):
if self.zoopt:
solution = self.zoopt_get_solution(pred_res, key, max_address_num)
address_idx = [idx for idx, i in enumerate(solution) if i != 0]
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
candidates = self.address_by_idx(pred_res, key, address_idx)
address_num = int(solution.sum())
min_address_num = address_num
else:
@@ -156,7 +159,6 @@ class AbducerBase(abc.ABC):
def __call__(self, Z, Y, max_address_num=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address)


if __name__ == '__main__':
from kb import add_KB, prolog_KB, HWF_KB


+ 1
- 1
abl/framework_hed.py View File

@@ -158,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for idx in range(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = abducer.kb.address_by_idx([pred_res[idx]], None, address_idx, True)
candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx)
if len(candidate) > 0:
consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(candidate[0][0])


Loading…
Cancel
Save