diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index b1c661b..38cbf64 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -45,7 +45,10 @@ class AbducerBase(abc.ABC): dist_func = confidence_dist self.dist_func = dist_func if pred_res_parse is None: - pred_res_parse = lambda x : x["cls"] + if(dist_func == "hamming"): + pred_res_parse = lambda x : x["cls"] + elif dist_func == "confidence": + pred_res_parse = lambda x : x[" "] self.pred_res_parse = pred_res_parse self.cache = cache @@ -75,20 +78,18 @@ class AbducerBase(abc.ABC): else: candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) + cost_list = self.dist_func(pred_res, candidates) if(self.cache): self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates - - return candidates - + cost_list = self.dist_func(pred_res, candidates) + min_address_num = np.min(cost_list) + idxs = np.where(cost_list == min_address_num)[0] + candidates = [candidates[idx] for idx in idxs] - - - # if len(idxs) > 1: - # return None - # return [candidates[idx] for idx in idxs] + return candidates[0] def address(self, address_num, pred_res, key): new_candidates = [] @@ -99,7 +100,7 @@ class AbducerBase(abc.ABC): pred_res_array = np.array(pred_res) if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): pred_res_array[np.array(address_idx)] = c - if(self.kb.logic_forward(pred_res_array) == key): + if(abs(self.kb.logic_forward(pred_res_array) - key) <= 1e-3): new_candidates.append(pred_res_array) return new_candidates, address_num @@ -113,7 +114,7 @@ class AbducerBase(abc.ABC): return None, None, None if(address_num == 0): - if(self.kb.logic_forward(pred_res) == key): + if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3): candidates.append(pred_res) else: new_candidates, address_num = self.address(address_num, pred_res, key) @@ -148,7 +149,7 @@ if __name__ == "__main__": pseudo_label_list = list(range(10)) kb = add_KB(pseudo_label_list) abd = AbducerBase(kb) - res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) + res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) print(res) print() res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)