Browse Source

Add `_get_zoopt_score`

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
88a8f33f16
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 16 additions and 19 deletions
  1. +16
    -19
      abl/abducer/abducer_base.py

+ 16
- 19
abl/abducer/abducer_base.py View File

@@ -50,26 +50,23 @@ class AbducerBase(abc.ABC):
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates)
candidate = candidates[np.argmin(cost_list)]
return candidate

def _get_zoopt_score(self, sol_x, pred_res, pred_res_prob, key):
address_idx = np.where(sol_x != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
if not self.multiple_predictions:
address_idx = np.where(sol.get_x() != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
return self._get_address_score(sol.get_x(), pred_res, pred_res_prob, key)
else:
all_address_flag = reform_idx(sol.get_x(), pred_res)
score = 0
# TODO:这个循环里,和上面if not self.multiple_predictions部分逻辑完全一样吧,应该把上面封装一下,然后下面循环里调用封装方法即可
for idx in range(len(pred_res)):
address_idx = np.where(all_address_flag[idx] != 0)[0]
candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
if len(candidates) > 0:
score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates))
else:
score += len(pred_res[idx])
score += self._get_address_score(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key)
return score
def _constrain_address_num(self, solution, max_address_num):
@@ -112,10 +109,10 @@ class AbducerBase(abc.ABC):
return self.kb.abduce_rules(pred_res)

def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0):
# if self.multiple_predictions:
return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address)
# else:
# return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
if self.multiple_predictions:
return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address)
else:
return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]

def __call__(self, Z, Y, max_address_num=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address)
@@ -241,4 +238,4 @@ if __name__ == '__main__':
print()

abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules)
print(abduced_rules)

Loading…
Cancel
Save