diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index b7e6526..620a4ce 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -13,7 +13,7 @@ import abc import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist +from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, nested_length class AbducerBase(abc.ABC): def __init__(self, kb, dist_func='hamming', zoopt=False): @@ -52,14 +52,14 @@ class AbducerBase(abc.ABC): return len(pred_res) def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): - # if not self.multiple_predictions: - return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key) - # else: - # all_address_flag = reform_idx(sol.get_x(), pred_res) - # score = 0 - # for idx in range(len(pred_res)): - # score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) - # return score + all_address_flag = reform_idx(sol.get_x(), pred_res) + if nested_length(pred_res) == 1: + return self._zoopt_address_score_single(all_address_flag[idx], pred_res, pred_res_prob, key) + else: + score = 0 + for idx in range(nested_length(pred_res)): + score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]]) + return score def _constrain_address_num(self, solution, max_address_num): x = solution.get_x() @@ -78,19 +78,18 @@ class AbducerBase(abc.ABC): return solution def address_by_idx(self, pred_res, key, address_idx): + # print(pred_res, address_idx) return self.kb.address_by_idx(pred_res, key, address_idx) def abduce(self, data, max_address=-1, require_more_address=0): pred_res, pred_res_prob, key = data - # if max_address_num == -1: - # max_address_num = len(flatten(pred_res)) assert(type(max_address) in (int, float)) if max_address == -1: - max_address_num = len(pred_res) + max_address_num = len(flatten(pred_res)) elif type(max_address) == float: assert(max_address >= 0 and max_address <= 1) - max_address_num = round(len(pred_res) * max_address) + max_address_num = round(len(flatten(pred_res)) * max_address) else: assert(max_address >= 0) max_address_num = max_address @@ -267,11 +266,11 @@ if __name__ == '__main__': print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) print() - # res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs))) - # print(res) - # res = abd.batch_abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs))) - # print(res) - # print() + res = abd.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) + print(res) + res = abd.abduce((inconsist_exs, [[[None]]] * len(consist_exs), [None] * len(inconsist_exs))) + print(res) + print() - # abduced_rules = abd.batch_abduce_rules(consist_exs) - # print(abduced_rules) \ No newline at end of file + abduced_rules = abd.abduce_rules(consist_exs) + print(abduced_rules) \ No newline at end of file diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index f7d8091..1421a97 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -109,16 +109,10 @@ class KBBase(ABC): def address_by_idx(self, pred_res, key, address_idx): candidates = [] abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) - # if multiple_predictions: - # save_pred_res = pred_res - # pred_res = flatten(pred_res) - for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] - # if multiple_predictions: - # candidate = reform_idx(candidate, save_pred_res) if check_equal(self.logic_forward(candidate), key, self.max_err): candidates.append(candidate) return candidates @@ -139,7 +133,7 @@ class KBBase(ABC): key = hashable_to_list(key) candidates = [] - for address_num in range(len(flatten(pred_res)) + 1): + for address_num in range(len(pred_res) + 1): if address_num == 0: if check_equal(self.logic_forward(pred_res), key, self.max_err): candidates.append(pred_res) @@ -202,24 +196,22 @@ class prolog_KB(KBBase): return False return result - def _address_pred_res(self, pred_res, address_idx, multiple_predictions): + def _address_pred_res(self, pred_res, address_idx): import re address_pred_res = pred_res.copy() - if multiple_predictions: - address_pred_res = flatten(address_pred_res) + address_pred_res = flatten(address_pred_res) for idx in address_idx: address_pred_res[idx] = 'P' + str(idx) - if multiple_predictions: - address_pred_res = reform_idx(address_pred_res, pred_res) + address_pred_res = reform_idx(address_pred_res, pred_res) # TODO:不知道有没有更简洁的方法 regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) - def get_query_string(self, pred_res, key, address_idx, multiple_predictions): + def get_query_string(self, pred_res, key, address_idx): query_string = "logic_forward(" - query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions) + query_string += self._address_pred_res(pred_res, address_idx) key_is_none_flag = key is None or (type(key) == list and key[0] is None) query_string += ",%s)." % key if not key_is_none_flag else ")." return query_string @@ -227,19 +219,17 @@ class prolog_KB(KBBase): def _find_candidate_GKB(self, pred_res, key): pass - def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): + def address_by_idx(self, pred_res, key, address_idx): candidates = [] - query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions) - if multiple_predictions: - save_pred_res = pred_res - pred_res = flatten(pred_res) + query_string = self.get_query_string(pred_res, key, address_idx) + save_pred_res = pred_res + pred_res = flatten(pred_res) abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] - if multiple_predictions: - candidate = reform_idx(candidate, save_pred_res) + candidate = reform_idx(candidate, save_pred_res) candidates.append(candidate) return candidates diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 90376aa..a0d5cd1 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -3,14 +3,20 @@ from .plog import INFO from collections import OrderedDict from itertools import chain -# for multiple predictions +def nested_length(l): + if not isinstance(l[0], (list, tuple)): + return 1 + return len(l) + def flatten(l): if not isinstance(l[0], (list, tuple)): return l return list(chain.from_iterable(l)) -# for multiple predictions def reform_idx(flatten_pred_res, save_pred_res): + if not isinstance(save_pred_res[0], (list, tuple)): + return flatten_pred_res + re = [] i = 0 for e in save_pred_res: