From 5efe40fc637780c7d96c58dcbdcfc19e52097bc7 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:59:13 +0800 Subject: [PATCH] Update abducer_base.py --- abducer/abducer_base.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 37c4bf8..576cce4 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -55,6 +55,11 @@ class AbducerBase(abc.ABC): self.cache_min_address_num = {} self.cache_candidates = {} + def get_min_cost_candidate(self, pred_res, candidates): + cost_list = self.dist_func(pred_res, candidates) + min_address_num = np.min(cost_list) + idxs = np.where(cost_list == min_address_num)[0] + return [candidates[idx] for idx in idxs][0] def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): pred_res, ans = data @@ -78,18 +83,13 @@ class AbducerBase(abc.ABC): else: candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) - cost_list = self.dist_func(pred_res, candidates) if(self.cache): self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates - - cost_list = self.dist_func(pred_res, candidates) - min_address_num = np.min(cost_list) - idxs = np.where(cost_list == min_address_num)[0] - candidates = [candidates[idx] for idx in idxs] - - return candidates[0] + + candidates = self.get_min_cost_candidate(pred_res, candidates) + return candidates def address(self, address_num, pred_res, key): new_candidates = [] @@ -168,4 +168,6 @@ if __name__ == "__main__": print(res) res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) print(res) + res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3) + print(res) print()