| @@ -26,9 +26,16 @@ import time | |||||
| class AbducerBase(abc.ABC): | class AbducerBase(abc.ABC): | ||||
| def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): | |||||
| def __init__( | |||||
| self, | |||||
| kb, | |||||
| dist_func="confidence", | |||||
| zoopt=False, | |||||
| multiple_predictions=False, | |||||
| cache=True, | |||||
| ): | |||||
| self.kb = kb | self.kb = kb | ||||
| assert dist_func == 'hamming' or dist_func == 'confidence' | |||||
| assert dist_func == "hamming" or dist_func == "confidence" | |||||
| self.dist_func = dist_func | self.dist_func = dist_func | ||||
| self.zoopt = zoopt | self.zoopt = zoopt | ||||
| self.multiple_predictions = multiple_predictions | self.multiple_predictions = multiple_predictions | ||||
| @@ -39,11 +46,18 @@ class AbducerBase(abc.ABC): | |||||
| self.cache_candidates = {} | self.cache_candidates = {} | ||||
| def _get_cost_list(self, pred_res, pred_res_prob, candidates): | def _get_cost_list(self, pred_res, pred_res_prob, candidates): | ||||
| if self.dist_func == 'hamming': | |||||
| if self.dist_func == "hamming": | |||||
| return hamming_dist(pred_res, candidates) | return hamming_dist(pred_res, candidates) | ||||
| elif self.dist_func == 'confidence': | |||||
| mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) | |||||
| return confidence_dist(pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]) | |||||
| elif self.dist_func == "confidence": | |||||
| mapping = dict( | |||||
| zip( | |||||
| self.kb.pseudo_label_list, | |||||
| list(range(len(self.kb.pseudo_label_list))), | |||||
| ) | |||||
| ) | |||||
| return confidence_dist( | |||||
| pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates] | |||||
| ) | |||||
| def _get_one_candidate(self, pred_res, pred_res_prob, candidates): | def _get_one_candidate(self, pred_res, pred_res_prob, candidates): | ||||
| if len(candidates) == 0: | if len(candidates) == 0: | ||||
| @@ -61,8 +75,12 @@ class AbducerBase(abc.ABC): | |||||
| all_address_flag = reform_idx(solution, pred_res) | all_address_flag = reform_idx(solution, pred_res) | ||||
| score = 0 | score = 0 | ||||
| for idx in enumerate(len(pred_res)): | for idx in enumerate(len(pred_res)): | ||||
| address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | |||||
| candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) | |||||
| address_idx = [ | |||||
| i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 | |||||
| ] | |||||
| candidate = self.kb.address_by_idx( | |||||
| [pred_res[idx]], key[idx], address_idx, True | |||||
| ) | |||||
| if len(candidate) > 0: | if len(candidate) > 0: | ||||
| score += 1 | score += 1 | ||||
| return score | return score | ||||
| @@ -70,7 +88,9 @@ class AbducerBase(abc.ABC): | |||||
| def _zoopt_address_score(self, pred_res, key, sol): | def _zoopt_address_score(self, pred_res, key, sol): | ||||
| if not self.multiple_predictions: | if not self.multiple_predictions: | ||||
| address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] | address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] | ||||
| candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) | |||||
| candidates = self.kb.address_by_idx( | |||||
| pred_res, key, address_idx, self.multiple_predictions | |||||
| ) | |||||
| return 1 if len(candidates) > 0 else 0 | return 1 if len(candidates) > 0 else 0 | ||||
| else: | else: | ||||
| return self._zoopt_score_multiple(pred_res, key, sol.get_x()) | return self._zoopt_score_multiple(pred_res, key, sol.get_x()) | ||||
| @@ -98,7 +118,11 @@ class AbducerBase(abc.ABC): | |||||
| pred_res = flatten(pred_res) | pred_res = flatten(pred_res) | ||||
| key = tuple(key) | key = tuple(key) | ||||
| if (tuple(pred_res), key) in self.cache_min_address_num: | if (tuple(pred_res), key) in self.cache_min_address_num: | ||||
| address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) | |||||
| address_num = min( | |||||
| max_address_num, | |||||
| self.cache_min_address_num[(tuple(pred_res), key)] | |||||
| + require_more_address, | |||||
| ) | |||||
| if (tuple(pred_res), key, address_num) in self.cache_candidates: | if (tuple(pred_res), key, address_num) in self.cache_candidates: | ||||
| candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] | candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] | ||||
| if self.zoopt: | if self.zoopt: | ||||
| @@ -127,12 +151,18 @@ class AbducerBase(abc.ABC): | |||||
| if self.zoopt: | if self.zoopt: | ||||
| solution = self.zoopt_get_solution(pred_res, key, max_address_num) | solution = self.zoopt_get_solution(pred_res, key, max_address_num) | ||||
| address_idx = [idx for idx, i in enumerate(solution) if i != 0] | address_idx = [idx for idx, i in enumerate(solution) if i != 0] | ||||
| candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) | |||||
| candidates = self.kb.address_by_idx( | |||||
| pred_res, key, address_idx, self.multiple_predictions | |||||
| ) | |||||
| address_num = int(solution.sum()) | address_num = int(solution.sum()) | ||||
| min_address_num = address_num | min_address_num = address_num | ||||
| else: | else: | ||||
| candidates, min_address_num, address_num = self.kb.abduce_candidates( | candidates, min_address_num, address_num = self.kb.abduce_candidates( | ||||
| pred_res, key, max_address_num, require_more_address, self.multiple_predictions | |||||
| pred_res, | |||||
| key, | |||||
| max_address_num, | |||||
| require_more_address, | |||||
| self.multiple_predictions, | |||||
| ) | ) | ||||
| candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) | candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) | ||||
| @@ -147,23 +177,31 @@ class AbducerBase(abc.ABC): | |||||
| def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): | def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): | ||||
| if self.multiple_predictions: | if self.multiple_predictions: | ||||
| return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) | |||||
| return self.abduce( | |||||
| (Z["cls"], Z["prob"], Y), max_address_num, require_more_address | |||||
| ) | |||||
| else: | else: | ||||
| return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] | |||||
| return [ | |||||
| self.abduce((z, prob, y), max_address_num, require_more_address) | |||||
| for z, prob, y in zip(Z["cls"], Z["prob"], Y) | |||||
| ] | |||||
| def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): | def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): | ||||
| return self.batch_abduce(Z, Y, max_address_num, require_more_address) | return self.batch_abduce(Z, Y, max_address_num, require_more_address) | ||||
| if __name__ == '__main__': | |||||
| prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||||
| prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||||
| if __name__ == "__main__": | |||||
| prob1 = [ | |||||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ] | |||||
| prob2 = [ | |||||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ] | |||||
| kb = add_KB() | kb = add_KB() | ||||
| abd = AbducerBase(kb, 'confidence') | |||||
| abd = AbducerBase(kb, "confidence") | |||||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | ||||
| @@ -177,7 +215,7 @@ if __name__ == '__main__': | |||||
| print() | print() | ||||
| kb = add_prolog_KB() | kb = add_prolog_KB() | ||||
| abd = AbducerBase(kb, 'confidence') | |||||
| abd = AbducerBase(kb, "confidence") | |||||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | ||||
| @@ -191,7 +229,7 @@ if __name__ == '__main__': | |||||
| print() | print() | ||||
| kb = add_prolog_KB() | kb = add_prolog_KB() | ||||
| abd = AbducerBase(kb, 'confidence', zoopt=True) | |||||
| abd = AbducerBase(kb, "confidence", zoopt=True) | |||||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | ||||
| @@ -205,24 +243,38 @@ if __name__ == '__main__': | |||||
| print() | print() | ||||
| kb = HWF_KB(len_list=[1, 3, 5]) | kb = HWF_KB(len_list=[1, 3, 5]) | ||||
| abd = AbducerBase(kb, 'hamming') | |||||
| res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) | |||||
| abd = AbducerBase(kb, "hamming") | |||||
| res = abd.abduce( | |||||
| (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '2'], None, 64), max_address_num=3, require_more_address=0) | |||||
| res = abd.abduce( | |||||
| (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) | |||||
| res = abd.abduce( | |||||
| (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0 | |||||
| ) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) | |||||
| res = abd.abduce( | |||||
| (["5", "8", "8", "8", "8"], None, 3.17), | |||||
| max_address_num=5, | |||||
| require_more_address=3, | |||||
| ) | |||||
| print(res) | print(res) | ||||
| print() | print() | ||||
| kb = HED_prolog_KB() | kb = HED_prolog_KB() | ||||
| abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) | abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) | ||||
| consist_exs = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]] | |||||
| consist_exs2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules | |||||
| inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] | |||||
| consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]] | |||||
| consist_exs2 = [ | |||||
| [1, "+", 0, "=", 0], | |||||
| [1, "+", 1, "=", 0], | |||||
| [0, "+", 1, "=", 1, 1], | |||||
| ] # not consistent with rules | |||||
| inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||||
| # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] | # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] | ||||
| rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])'] | |||||
| rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"] | |||||
| print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) | print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) | ||||
| print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) | print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) | ||||
| @@ -236,5 +288,3 @@ if __name__ == '__main__': | |||||
| abduced_rules = abd.abduce_rules(consist_exs) | abduced_rules = abd.abduce_rules(consist_exs) | ||||
| print(abduced_rules) | print(abduced_rules) | ||||