| @@ -14,7 +14,7 @@ import sys | |||
| sys.path.append("..") | |||
| import abc | |||
| from abducer.kb import add_KB, hwf_KB | |||
| from abducer.kb import add_KB, hwf_KB, add_prolog_KB | |||
| import numpy as np | |||
| from itertools import product, combinations | |||
| @@ -67,35 +67,31 @@ class AbducerBase(abc.ABC): | |||
| min_address_num = np.min(cost_list) | |||
| idxs = np.where(cost_list == min_address_num)[0] | |||
| return [candidates[idx] for idx in idxs][0] | |||
| def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address): | |||
| if len(all_candidates) == 0: | |||
| candidates = [] | |||
| min_address_num = 0 | |||
| address_num = 0 | |||
| else: | |||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| def abduce(self, data, max_address_num = -1, require_more_address = 0): | |||
| pred_res, pred_res_prob, ans = data | |||
| if max_address_num == -1: | |||
| max_address_num = len(pred_res) | |||
| if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num: | |||
| address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address) | |||
| if (tuple(pred_res), ans, address_num) in self.cache_candidates: | |||
| # print('cached') | |||
| candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] | |||
| candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||
| return candidate | |||
| if self.kb.GKB_flag: | |||
| all_candidates = self.kb.get_candidates(ans, len(pred_res)) | |||
| if len(all_candidates) == 0: | |||
| candidates = [] | |||
| min_address_num = 0 | |||
| address_num = 0 | |||
| else: | |||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| else: | |||
| candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | |||
| candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, ans, max_address_num, require_more_address) | |||
| if self.cache: | |||
| self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | |||
| @@ -104,47 +100,6 @@ class AbducerBase(abc.ABC): | |||
| candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||
| return candidate | |||
| def address(self, address_num, pred_res, key): | |||
| new_candidates = [] | |||
| all_address_candidate = list(product(self.kb.pseudo_label_list, repeat = address_num)) | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_address_candidate: | |||
| address_list = [pred_res[i] for i in address_idx] | |||
| if(sum([address_list[i] == c[i] for i in range(address_num)]) == 0): | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if self.kb.logic_forward(candidate) == key: | |||
| new_candidates.append(candidate) | |||
| return new_candidates | |||
| def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): | |||
| candidates = [] | |||
| print(pred_res) | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3: | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0): | |||
| return [ | |||
| @@ -154,21 +109,30 @@ class AbducerBase(abc.ABC): | |||
| def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0): | |||
| return self.batch_abduce(Z, Y, max_address_num, require_more_address) | |||
| if __name__ == '__main__': | |||
| kb = add_KB(GKB_flag = True) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0) | |||
| res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 1) | |||
| res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], None, 5), max_address_num = 2, require_more_address = 1) | |||
| res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) | |||
| print(res) | |||
| print() | |||
| kb = hwf_KB() | |||
| kb = add_prolog_KB() | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) | |||
| print(res) | |||
| print() | |||
| kb = hwf_KB(len_list = [1, 3, 5]) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) | |||
| print(res) | |||
| @@ -176,6 +140,8 @@ if __name__ == '__main__': | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0) | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3) | |||
| res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num = 5, require_more_address = 3) | |||
| print(res) | |||
| print() | |||