| @@ -13,7 +13,7 @@ | |||
| import abc | |||
| import numpy as np | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist | |||
| from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, nested_length | |||
| class AbducerBase(abc.ABC): | |||
| def __init__(self, kb, dist_func='hamming', zoopt=False): | |||
| @@ -52,14 +52,14 @@ class AbducerBase(abc.ABC): | |||
| return len(pred_res) | |||
| def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): | |||
| # if not self.multiple_predictions: | |||
| return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key) | |||
| # else: | |||
| # all_address_flag = reform_idx(sol.get_x(), pred_res) | |||
| # score = 0 | |||
| # for idx in range(len(pred_res)): | |||
| # score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) | |||
| # return score | |||
| all_address_flag = reform_idx(sol.get_x(), pred_res) | |||
| if nested_length(pred_res) == 1: | |||
| return self._zoopt_address_score_single(all_address_flag[idx], pred_res, pred_res_prob, key) | |||
| else: | |||
| score = 0 | |||
| for idx in range(nested_length(pred_res)): | |||
| score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]]) | |||
| return score | |||
| def _constrain_address_num(self, solution, max_address_num): | |||
| x = solution.get_x() | |||
| @@ -78,19 +78,18 @@ class AbducerBase(abc.ABC): | |||
| return solution | |||
| def address_by_idx(self, pred_res, key, address_idx): | |||
| # print(pred_res, address_idx) | |||
| return self.kb.address_by_idx(pred_res, key, address_idx) | |||
| def abduce(self, data, max_address=-1, require_more_address=0): | |||
| pred_res, pred_res_prob, key = data | |||
| # if max_address_num == -1: | |||
| # max_address_num = len(flatten(pred_res)) | |||
| assert(type(max_address) in (int, float)) | |||
| if max_address == -1: | |||
| max_address_num = len(pred_res) | |||
| max_address_num = len(flatten(pred_res)) | |||
| elif type(max_address) == float: | |||
| assert(max_address >= 0 and max_address <= 1) | |||
| max_address_num = round(len(pred_res) * max_address) | |||
| max_address_num = round(len(flatten(pred_res)) * max_address) | |||
| else: | |||
| assert(max_address >= 0) | |||
| max_address_num = max_address | |||
| @@ -267,11 +266,11 @@ if __name__ == '__main__': | |||
| print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) | |||
| print() | |||
| # res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs))) | |||
| # print(res) | |||
| # res = abd.batch_abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs))) | |||
| # print(res) | |||
| # print() | |||
| res = abd.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) | |||
| print(res) | |||
| res = abd.abduce((inconsist_exs, [[[None]]] * len(consist_exs), [None] * len(inconsist_exs))) | |||
| print(res) | |||
| print() | |||
| # abduced_rules = abd.batch_abduce_rules(consist_exs) | |||
| # print(abduced_rules) | |||
| abduced_rules = abd.abduce_rules(consist_exs) | |||
| print(abduced_rules) | |||
| @@ -109,16 +109,10 @@ class KBBase(ABC): | |||
| def address_by_idx(self, pred_res, key, address_idx): | |||
| candidates = [] | |||
| abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) | |||
| # if multiple_predictions: | |||
| # save_pred_res = pred_res | |||
| # pred_res = flatten(pred_res) | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| # if multiple_predictions: | |||
| # candidate = reform_idx(candidate, save_pred_res) | |||
| if check_equal(self.logic_forward(candidate), key, self.max_err): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -139,7 +133,7 @@ class KBBase(ABC): | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| @@ -202,24 +196,22 @@ class prolog_KB(KBBase): | |||
| return False | |||
| return result | |||
| def _address_pred_res(self, pred_res, address_idx, multiple_predictions): | |||
| def _address_pred_res(self, pred_res, address_idx): | |||
| import re | |||
| address_pred_res = pred_res.copy() | |||
| if multiple_predictions: | |||
| address_pred_res = flatten(address_pred_res) | |||
| address_pred_res = flatten(address_pred_res) | |||
| for idx in address_idx: | |||
| address_pred_res[idx] = 'P' + str(idx) | |||
| if multiple_predictions: | |||
| address_pred_res = reform_idx(address_pred_res, pred_res) | |||
| address_pred_res = reform_idx(address_pred_res, pred_res) | |||
| # TODO:不知道有没有更简洁的方法 | |||
| regex = r"'P\d+'" | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) | |||
| def get_query_string(self, pred_res, key, address_idx, multiple_predictions): | |||
| def get_query_string(self, pred_res, key, address_idx): | |||
| query_string = "logic_forward(" | |||
| query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions) | |||
| query_string += self._address_pred_res(pred_res, address_idx) | |||
| key_is_none_flag = key is None or (type(key) == list and key[0] is None) | |||
| query_string += ",%s)." % key if not key_is_none_flag else ")." | |||
| return query_string | |||
| @@ -227,19 +219,17 @@ class prolog_KB(KBBase): | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| pass | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| def address_by_idx(self, pred_res, key, address_idx): | |||
| candidates = [] | |||
| query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -3,14 +3,20 @@ from .plog import INFO | |||
| from collections import OrderedDict | |||
| from itertools import chain | |||
| # for multiple predictions | |||
| def nested_length(l): | |||
| if not isinstance(l[0], (list, tuple)): | |||
| return 1 | |||
| return len(l) | |||
| def flatten(l): | |||
| if not isinstance(l[0], (list, tuple)): | |||
| return l | |||
| return list(chain.from_iterable(l)) | |||
| # for multiple predictions | |||
| def reform_idx(flatten_pred_res, save_pred_res): | |||
| if not isinstance(save_pred_res[0], (list, tuple)): | |||
| return flatten_pred_res | |||
| re = [] | |||
| i = 0 | |||
| for e in save_pred_res: | |||