Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
5efe40fc63
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 8 deletions
  1. +10
    -8
      abducer/abducer_base.py

+ 10
- 8
abducer/abducer_base.py View File

@@ -55,6 +55,11 @@ class AbducerBase(abc.ABC):
self.cache_min_address_num = {} self.cache_min_address_num = {}
self.cache_candidates = {} self.cache_candidates = {}


def get_min_cost_candidate(self, pred_res, candidates):
cost_list = self.dist_func(pred_res, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]


def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1):
pred_res, ans = data pred_res, ans = data
@@ -78,18 +83,13 @@ class AbducerBase(abc.ABC):
else: else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
cost_list = self.dist_func(pred_res, candidates)
if(self.cache): if(self.cache):
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates
cost_list = self.dist_func(pred_res, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
candidates = [candidates[idx] for idx in idxs]
return candidates[0]

candidates = self.get_min_cost_candidate(pred_res, candidates)
return candidates
def address(self, address_num, pred_res, key): def address(self, address_num, pred_res, key):
new_candidates = [] new_candidates = []
@@ -168,4 +168,6 @@ if __name__ == "__main__":
print(res) print(res)
res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0)
print(res) print(res)
res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3)
print(res)
print() print()

Loading…
Cancel
Save