|
|
|
@@ -55,6 +55,11 @@ class AbducerBase(abc.ABC): |
|
|
|
self.cache_min_address_num = {} |
|
|
|
self.cache_candidates = {} |
|
|
|
|
|
|
|
def get_min_cost_candidate(self, pred_res, candidates): |
|
|
|
cost_list = self.dist_func(pred_res, candidates) |
|
|
|
min_address_num = np.min(cost_list) |
|
|
|
idxs = np.where(cost_list == min_address_num)[0] |
|
|
|
return [candidates[idx] for idx in idxs][0] |
|
|
|
|
|
|
|
def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): |
|
|
|
pred_res, ans = data |
|
|
|
@@ -78,18 +83,13 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
else: |
|
|
|
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) |
|
|
|
cost_list = self.dist_func(pred_res, candidates) |
|
|
|
|
|
|
|
if(self.cache): |
|
|
|
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num |
|
|
|
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates |
|
|
|
|
|
|
|
cost_list = self.dist_func(pred_res, candidates) |
|
|
|
min_address_num = np.min(cost_list) |
|
|
|
idxs = np.where(cost_list == min_address_num)[0] |
|
|
|
candidates = [candidates[idx] for idx in idxs] |
|
|
|
|
|
|
|
return candidates[0] |
|
|
|
|
|
|
|
candidates = self.get_min_cost_candidate(pred_res, candidates) |
|
|
|
return candidates |
|
|
|
|
|
|
|
def address(self, address_num, pred_res, key): |
|
|
|
new_candidates = [] |
|
|
|
@@ -168,4 +168,6 @@ if __name__ == "__main__": |
|
|
|
print(res) |
|
|
|
res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) |
|
|
|
print(res) |
|
|
|
res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3) |
|
|
|
print(res) |
|
|
|
print() |