Reviewer notes (free text before the first "diff --git" line is skipped by
git apply / patch):

  * Purpose of the patch: hoist abduce_candidates, _abduce_by_GKB and
    address_by_idx into KBBase, and let subclasses supply only
    _find_candidate_GKB (ClsKB: exact-key lookup, RegKB: lookup within
    max_err, prolog_KB: stub).
  * Fix 1: the multi-prediction branch of _abduce_by_GKB now calls
    self._find_candidate_GKB. The submitted patch called
    self._regression_find_candidate_GKB, a name this same patch removes.
  * Fix 2: the submitted hunk "@@ -10,20 +10,6 @@", which deleted every
    module-level import, has been dropped. The code below still needs ABC,
    abstractmethod, bisect, np, product, combinations, flatten, reform_idx,
    hamming_dist and check_equal. New-side hunk offsets are shifted by +14
    to match. NOTE(review): assumes lines 1-9 of kb.py are header comments
    only - confirm.
  * The post-image blob id on the "index" line comes from the submitted
    patch and no longer matches the result of applying this one.
  * Leading whitespace was reconstructed from a whitespace-collapsed copy;
    if a hunk fails to apply, retry with "git apply --ignore-whitespace".

diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py
index 7ac26c8..7c494c5 100644
--- a/abl/abducer/kb.py
+++ b/abl/abducer/kb.py
@@ -79,78 +79,16 @@ class KBBase(ABC):
                 res.append(self.logic_forward(x))
             return res
 
-    @abstractmethod
-    def abduce_candidates(self):
-        pass
-
-    @abstractmethod
-    def address_by_idx(self):
-        pass
-
-    def _address(self, address_num, pred_res, key, multiple_predictions):
-        new_candidates = []
-        if not multiple_predictions:
-            address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
-        else:
-            address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num))
-
-        for address_idx in address_idx_list:
-            candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
-            new_candidates += candidates
-        return new_candidates
-
-    def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
-        candidates = []
-
-        for address_num in range(len(flatten(pred_res)) + 1):
-            if address_num == 0:
-                if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
-                    candidates.append(pred_res)
-            else:
-                new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
-                candidates += new_candidates
-
-            if len(candidates) > 0:
-                min_address_num = address_num
-                break
-
-            if address_num >= max_address_num:
-                return [], 0, 0
-
-        for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
-            if address_num > max_address_num:
-                return candidates, min_address_num, address_num - 1
-            new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
-            candidates += new_candidates
-
-        return candidates, min_address_num, address_num
-
-    def _dict_len(self, dic):
-        if not self.GKB_flag:
-            return 0
-        else:
-            return sum(len(c) for c in dic.values())
-
-    def __len__(self):
-        if not self.GKB_flag:
-            return 0
-        else:
-            return sum(self._dict_len(v) for v in self.base.values())
-
-
-class ClsKB(KBBase):
-    def __init__(self, pseudo_label_list, len_list, GKB_flag):
-        super().__init__(pseudo_label_list, len_list, GKB_flag)
-
-    def logic_forward(self):
-        pass
-
     def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
         if self.GKB_flag:
             return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
         else:
             return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
-    
+
+    @abstractmethod
+    def _find_candidate_GKB(self):
+        pass
+
     def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
         if self.base == {}:
             return [], 0, 0
@@ -158,7 +96,7 @@ class ClsKB(KBBase):
         if not multiple_predictions:
             if len(pred_res) not in self.len_list:
                 return [], 0, 0
-            all_candidates = self.base[len(pred_res)][key]
+            all_candidates = self._find_candidate_GKB(pred_res, key)
             if len(all_candidates) == 0:
                 return [], 0, 0
             else:
@@ -168,7 +106,7 @@ class ClsKB(KBBase):
                 idxs = np.where(cost_list <= address_num)[0]
                 candidates = [all_candidates[idx] for idx in idxs]
                 return candidates, min_address_num, address_num
-        
+
         else:
             min_address_num = 0
             all_candidates_save = []
@@ -177,7 +115,7 @@ class ClsKB(KBBase):
             for p_res, k in zip(pred_res, key):
                 if len(p_res) not in self.len_list:
                     return [], 0, 0
-                all_candidates = self.base[len(p_res)][k]
+                all_candidates = self._find_candidate_GKB(p_res, k)
                 if len(all_candidates) == 0:
                     return [], 0, 0
                 else:
@@ -194,7 +132,7 @@ class ClsKB(KBBase):
             idxs = np.where(multiple_cost_list <= address_num)[0]
             candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
             return candidates, min_address_num, address_num
-    
+
     def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
         candidates = []
         abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))
@@ -211,10 +149,71 @@ class ClsKB(KBBase):
             if multiple_predictions:
                 candidate = reform_idx(candidate, save_pred_res)
 
-            if check_equal(self._logic_forward(candidate, multiple_predictions), key):
+            if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
                 candidates.append(candidate)
         return candidates
 
+    def _address(self, address_num, pred_res, key, multiple_predictions):
+        new_candidates = []
+        if not multiple_predictions:
+            address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
+        else:
+            address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num))
+
+        for address_idx in address_idx_list:
+            candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
+            new_candidates += candidates
+        return new_candidates
+
+    def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
+        candidates = []
+
+        for address_num in range(len(flatten(pred_res)) + 1):
+            if address_num == 0:
+                if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
+                    candidates.append(pred_res)
+            else:
+                new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
+                candidates += new_candidates
+
+            if len(candidates) > 0:
+                min_address_num = address_num
+                break
+
+            if address_num >= max_address_num:
+                return [], 0, 0
+
+        for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
+            if address_num > max_address_num:
+                return candidates, min_address_num, address_num - 1
+            new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
+            candidates += new_candidates
+
+        return candidates, min_address_num, address_num
+
+    def _dict_len(self, dic):
+        if not self.GKB_flag:
+            return 0
+        else:
+            return sum(len(c) for c in dic.values())
+
+    def __len__(self):
+        if not self.GKB_flag:
+            return 0
+        else:
+            return sum(self._dict_len(v) for v in self.base.values())
+
+
+class ClsKB(KBBase):
+    def __init__(self, pseudo_label_list, len_list, GKB_flag):
+        super().__init__(pseudo_label_list, len_list, GKB_flag)
+
+    def logic_forward(self):
+        pass
+
+    def _find_candidate_GKB(self, pred_res, key):
+        return self.base[len(pred_res)][key]
+
 
 class add_KB(ClsKB):
     def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
@@ -232,16 +231,13 @@ class prolog_KB(KBBase):
     def logic_forward(self):
         pass
 
-    def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
-        return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
-    
+    def _find_candidate_GKB(self):
+        pass
+
     def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
         candidates = []
         # print(address_idx)
-        if not multiple_predictions:
-            query_string = self.get_query_string(pred_res, key, address_idx)
-        else:
-            query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
+        query_string = self.get_query_string(pred_res, key, address_idx)
 
         if multiple_predictions:
             save_pred_res = pred_res
@@ -284,21 +280,18 @@ class HED_prolog_KB(prolog_KB):
         super().__init__(pseudo_label_list)
         self.prolog.consult('./datasets/hed/learn_add.pl')
 
-    # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
     def logic_forward(self, exs):
         return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0
 
-    def get_query_string_need_flatten(self, pred_res, key, address_idx):
-        # flatten
+    def get_query_string(self, pred_res, key, address_idx):
         flatten_pred_res = flatten(pred_res)
         # add variables for prolog
         for idx in range(len(flatten_pred_res)):
             if idx in address_idx:
                 flatten_pred_res[idx] = 'X' + str(idx)
-        # unflatten
-        new_pred_res = reform_idx(flatten_pred_res, pred_res)
+        pred_res = reform_idx(flatten_pred_res, pred_res)
 
-        query_string = "abduce_consistent_insts(%s)." % new_pred_res
+        query_string = "abduce_consistent_insts(%s)." % pred_res
         return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")
 
     def consist_rule(self, exs, rules):
@@ -324,13 +317,7 @@ class RegKB(KBBase):
     def logic_forward(self):
         pass
 
-    def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
-        if self.GKB_flag:
-            return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
-        else:
-            return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
-
-    def _regression_find_candidate_GKB(self, pred_res, key):
+    def _find_candidate_GKB(self, pred_res, key):
         potential_candidates = self.base[len(pred_res)]
         key_list = sorted(potential_candidates)
         key_idx = bisect.bisect_left(key_list, key)
@@ -351,70 +338,6 @@ class RegKB(KBBase):
                 break
         return all_candidates
 
-    def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
-        if self.base == {}:
-            return [], 0, 0
-
-        if not multiple_predictions:
-            if len(pred_res) not in self.len_list:
-                return [], 0, 0
-            all_candidates = self._regression_find_candidate_GKB(pred_res, key)
-            if len(all_candidates) == 0:
-                return [], 0, 0
-            else:
-                cost_list = hamming_dist(pred_res, all_candidates)
-                min_address_num = np.min(cost_list)
-                address_num = min(max_address_num, min_address_num + require_more_address)
-                idxs = np.where(cost_list <= address_num)[0]
-                candidates = [all_candidates[idx] for idx in idxs]
-                return candidates, min_address_num, address_num
-
-        else:
-            min_address_num = 0
-            all_candidates_save = []
-            cost_list_save = []
-
-            for p_res, k in zip(pred_res, key):
-                if len(p_res) not in self.len_list:
-                    return [], 0, 0
-                all_candidates = self._regression_find_candidate_GKB(p_res, k)
-                if len(all_candidates) == 0:
-                    return [], 0, 0
-                else:
-                    all_candidates_save.append(all_candidates)
-                    cost_list = hamming_dist(p_res, all_candidates)
-                    min_address_num += np.min(cost_list)
-                    cost_list_save.append(cost_list)
-
-            multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
-            assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
-            multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
-            assert len(multiple_all_candidates) == len(multiple_cost_list)
-            address_num = min(max_address_num, min_address_num + require_more_address)
-            idxs = np.where(multiple_cost_list <= address_num)[0]
-            candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
-            return candidates, min_address_num, address_num
-
-    def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
-        candidates = []
-        abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))
-
-        if multiple_predictions:
-            save_pred_res = pred_res
-            pred_res = flatten(pred_res)
-
-        for c in abduce_c:
-            candidate = pred_res.copy()
-            for i, idx in enumerate(address_idx):
-                candidate[idx] = c[i]
-
-            if multiple_predictions:
-                candidate = reform_idx(candidate, save_pred_res)
-
-            if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
-                candidates.append(candidate)
-        return candidates
-
 
 class HWF_KB(RegKB):
     def __init__(
@@ -456,13 +379,6 @@ class HWF_KB(RegKB):
         formula = [mapping[f] for f in formula]
         return round(eval(''.join(formula)), 2)
 
-
-import time
-
 
 if __name__ == "__main__":
-    t1 = time.time()
-    kb = add_KB(GKB_flag=True)
-    t2 = time.time()
-    print(t2 - t1)
-
+    pass