| @@ -55,6 +55,11 @@ class AbducerBase(abc.ABC): | |||||
| self.cache_min_address_num = {} | self.cache_min_address_num = {} | ||||
| self.cache_candidates = {} | self.cache_candidates = {} | ||||
| def get_min_cost_candidate(self, pred_res, candidates): | |||||
| cost_list = self.dist_func(pred_res, candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| idxs = np.where(cost_list == min_address_num)[0] | |||||
| return [candidates[idx] for idx in idxs][0] | |||||
| def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): | def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): | ||||
| pred_res, ans = data | pred_res, ans = data | ||||
| @@ -78,18 +83,13 @@ class AbducerBase(abc.ABC): | |||||
| else: | else: | ||||
| candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | ||||
| cost_list = self.dist_func(pred_res, candidates) | |||||
| if(self.cache): | if(self.cache): | ||||
| self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | ||||
| self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates | self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates | ||||
| cost_list = self.dist_func(pred_res, candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| idxs = np.where(cost_list == min_address_num)[0] | |||||
| candidates = [candidates[idx] for idx in idxs] | |||||
| return candidates[0] | |||||
| candidates = self.get_min_cost_candidate(pred_res, candidates) | |||||
| return candidates | |||||
| def address(self, address_num, pred_res, key): | def address(self, address_num, pred_res, key): | ||||
| new_candidates = [] | new_candidates = [] | ||||
| @@ -168,4 +168,6 @@ if __name__ == "__main__": | |||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) | res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3) | |||||
| print(res) | |||||
| print() | print() | ||||