diff --git a/abl/reasoning/base_kb.py b/abl/reasoning/base_kb.py new file mode 100644 index 0000000..560f3d4 --- /dev/null +++ b/abl/reasoning/base_kb.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple, Union + +import numpy as np + +from ..structures import ListData + + +class BaseKB(ABC): + @abstractmethod + def logic_forward(self, data_sample: ListData): + """Placeholder for the forward reasoning of the knowledge base.""" + pass + + @abstractmethod + def abduce_candidates(self, data_sample: ListData): + """Placeholder for abduction of the knowledge base.""" + pass diff --git a/abl/reasoning/ground_kb.py b/abl/reasoning/ground_kb.py new file mode 100644 index 0000000..95ee093 --- /dev/null +++ b/abl/reasoning/ground_kb.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Any, Hashable, List + +from abl.structures import ListData + +from .base_kb import BaseKB + + +class GroundKB(BaseKB, ABC): + def __init__(self, pseudo_label_list): + self.pseudo_label_list = pseudo_label_list + self.base = self.construct_base() + + @abstractmethod + def construct_base(self) -> dict: + pass + + @abstractmethod + def get_key(self, data_sample: ListData) -> Hashable: + pass + + @abstractmethod + def key2candidates(self, key: Hashable) -> List[List[Any]]: + return self.base[key] + + def filter_candidates( + self, + data_sample: ListData, + candidates: List[List[Any]], + max_revision_num: int, + require_more_revision: int = 0, + ) -> List[List[Any]]: + return candidates + + def abduce_candidates( + self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + ): + return self._abduce_by_GKB( + data_sample=data_sample, + max_revision_num=max_revision_num, + require_more_revision=require_more_revision, + ) + + def _abduce_by_GKB( + self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + ): + candidates = self.key2candidates(self.get_key(data_sample)) + return self.filter_candidates( + data_sample=data_sample, + max_revision_num=max_revision_num, + require_more_revision=require_more_revision, + candidates=candidates, + ) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py deleted file mode 100644 index a839ae8..0000000 --- a/abl/reasoning/kb.py +++ /dev/null @@ -1,248 +0,0 @@ -import bisect -from abc import ABC, abstractmethod -from collections import defaultdict -from functools import lru_cache -from itertools import combinations, product -from multiprocessing import Pool - -import numpy as np -import pyswip - -from ..utils.utils import (check_equal, flatten, hamming_dist, - hashable_to_list, reform_idx, to_hashable) - - -class KBBase(ABC): - def __init__(self, pseudo_label_list, max_err=0, use_cache=True): - # TODO:添加一下类型检查,比如 - # if not isinstance(X, (np.ndarray, spmatrix)): - # raise TypeError("X should be numpy array or sparse matrix") - - self.pseudo_label_list = pseudo_label_list - self.max_err = max_err - self.use_cache = use_cache - - @abstractmethod - def logic_forward(self, pseudo_labels): - pass - - def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): - if not self.use_cache: - return self._abduce_by_search( - pred_res, y, max_revision_num, require_more_revision - ) - else: - return self._abduce_by_search_cache( - to_hashable(pred_res), - to_hashable(y), - max_revision_num, - require_more_revision, - ) - - def revise_by_idx(self, pred_res, y, revision_idx): - candidates = [] - abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) - for c in abduce_c: - candidate = pred_res.copy() - for i, idx in enumerate(revision_idx): - candidate[idx] = c[i] - if check_equal(self.logic_forward(candidate), y, self.max_err): - candidates.append(candidate) - return candidates - - def _revision(self, revision_num, pred_res, y): - new_candidates = [] - revision_idx_list = combinations(range(len(pred_res)), revision_num) - - for revision_idx in revision_idx_list: - candidates = self.revise_by_idx(pred_res, y, revision_idx) - new_candidates.extend(candidates) - return new_candidates - - def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): - candidates = [] - for revision_num in range(len(pred_res) + 1): - if revision_num == 0 and check_equal( - self.logic_forward(pred_res), y, self.max_err - ): - candidates.append(pred_res) - elif revision_num > 0: - candidates.extend(self._revision(revision_num, pred_res, y)) - if len(candidates) > 0: - min_revision_num = revision_num - break - if revision_num >= max_revision_num: - return [] - - for revision_num in range( - min_revision_num + 1, min_revision_num + require_more_revision + 1 - ): - if revision_num > max_revision_num: - return candidates - candidates.extend(self._revision(revision_num, pred_res, y)) - return candidates - - @lru_cache(maxsize=None) - def _abduce_by_search_cache( - self, pred_res, y, max_revision_num, require_more_revision - ): - pred_res = hashable_to_list(pred_res) - y = hashable_to_list(y) - return self._abduce_by_search( - pred_res, y, max_revision_num, require_more_revision - ) - - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - - -class ground_KB(KBBase): - def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): - super().__init__(pseudo_label_list, max_err) - - self.GKB_len_list = GKB_len_list - self.base = {} - X, Y = self._get_GKB() - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(x) - - # For parallel version of _get_GKB - def _get_XY_list(self, args): - pre_x, post_x_it = args[0], args[1] - XY_list = [] - for post_x in post_x_it: - x = (pre_x,) + post_x - y = self.logic_forward(x) - if y is not None: - XY_list.append((x, y)) - return XY_list - - # Parallel _get_GKB - def _get_GKB(self): - X, Y = [], [] - for length in self.GKB_len_list: - print("Generating GKB of length %d" % length) - arg_list = [] - for pre_x in self.pseudo_label_list: - post_x_it = product(self.pseudo_label_list, repeat=length - 1) - arg_list.append((pre_x, post_x_it)) - with Pool(processes=len(arg_list)) as pool: - ret_list = pool.map(self._get_XY_list, arg_list) - for XY_list in ret_list: - if len(XY_list) == 0: - continue - part_X, part_Y = zip(*XY_list) - X.extend(part_X) - Y.extend(part_Y) - if Y and isinstance(Y[0], (int, float)): - X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) - return X, Y - - def abduce_candidates(self, data_sample, max_revision_num, require_more_revision=0): - return self._abduce_by_GKB( - data_sample, max_revision_num, require_more_revision=require_more_revision - ) - - def _find_candidate_GKB(self, cache_key, data_sample): - y = data_sample["Y"][0] - if self.max_err == 0: - return self.base[cache_key][y] - else: - potential_candidates = self.base[cache_key] - key_list = list(potential_candidates.keys()) - key_idx = bisect.bisect_left(key_list, y) - - all_candidates = [] - for idx in range(key_idx - 1, -1, -1): - k = key_list[idx] - if abs(k - y) <= self.max_err: - all_candidates.extend(potential_candidates[k]) - else: - break - - for idx in range(key_idx, len(key_list)): - k = key_list[idx] - if abs(k - y) <= self.max_err: - all_candidates.extend(potential_candidates[k]) - else: - break - return all_candidates - - def _abduce_by_GKB(self, data_sample, max_revision_num, require_more_revision=0): - cache_key = len(data_sample["pred_pseudo_label"][0]) - if self.base == {} or cache_key not in self.GKB_len_list: - return [] - - all_candidates = self._find_candidate_GKB(cache_key, data_sample) - if len(all_candidates) == 0: - return [] - - cost_array = hamming_dist(data_sample["pred_pseudo_label"][0], all_candidates) - min_revision_num = np.min(cost_array) - revision_num = min(max_revision_num, min_revision_num + require_more_revision) - idxs = np.where(cost_array <= revision_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - return candidates - - -class prolog_KB(KBBase): - def __init__(self, pseudo_label_list, pl_file, max_err=0): - super().__init__(pseudo_label_list, max_err) - self.prolog = pyswip.Prolog() - self.prolog.consult(pl_file) - - def logic_forward(self, pseudo_labels): - result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0][ - "Res" - ] - if result == "true": - return True - elif result == "false": - return False - return result - - def _revision_pred_res(self, pred_res, revision_idx): - import re - - revision_pred_res = pred_res.copy() - revision_pred_res = flatten(revision_pred_res) - - for idx in revision_idx: - revision_pred_res[idx] = "P" + str(idx) - revision_pred_res = reform_idx(revision_pred_res, pred_res) - - # TODO:不知道有没有更简洁的方法 - regex = r"'P\d+'" - return re.sub( - regex, lambda x: x.group().replace("'", ""), str(revision_pred_res) - ) - - def get_query_string(self, pred_res, y, revision_idx): - query_string = "logic_forward(" - query_string += self._revision_pred_res(pred_res, revision_idx) - key_is_none_flag = y is None or (type(y) == list and y[0] is None) - query_string += ",%s)." % y if not key_is_none_flag else ")." - return query_string - - def revise_by_idx(self, pred_res, y, revision_idx): - candidates = [] - query_string = self.get_query_string(pred_res, y, revision_idx) - save_pred_res = pred_res - pred_res = flatten(pred_res) - abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] - for c in abduce_c: - candidate = pred_res.copy() - for i, idx in enumerate(revision_idx): - candidate[idx] = c[i] - candidate = reform_idx(candidate, save_pred_res) - candidates.append(candidate) - return candidates diff --git a/abl/reasoning/search_based_kb.py b/abl/reasoning/search_based_kb.py new file mode 100644 index 0000000..0a37beb --- /dev/null +++ b/abl/reasoning/search_based_kb.py @@ -0,0 +1,105 @@ +from abc import ABC, abstractmethod +from itertools import combinations, product +from typing import Any, Callable, Generator, List, Optional, Tuple, Union + +import numpy + +from abl.structures import ListData + +from ..structures import ListData +from ..utils import Cache +from .base_kb import BaseKB + + +def incremental_search_strategy( + data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 +): + symbol_num = data_sample["symbol_num"] + max_revision_num = min(max_revision_num, symbol_num) + real_end = max_revision_num + for revision_num in range(max_revision_num + 1): + if revision_num > real_end: + break + + revision_idx_tuple = combinations(range(symbol_num), revision_num) + for revision_idx in revision_idx_tuple: + received = yield revision_idx + if received == "success": + real_end = min(symbol_num, revision_num + require_more_revision) + + +class SearchBasedKB(BaseKB, ABC): + def __init__( + self, + pseudo_label_list: List, + search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, + use_cache: bool = True, + cache_root: Optional[str] = None, + ) -> None: + self.pseudo_label_list = pseudo_label_list + self.search_strategy = search_strategy + self.use_cache = use_cache + if self.use_cache and getattr(self, "get_key") is getattr(SearchBasedKB, "get_key", None): + raise NotImplementedError("If use_cache is True, get_key should be implemented.") + self.cache = Cache[ListData, List[List[Any]]]( + func=self._abduce_by_search, + cache=use_cache, + cache_root=cache_root, + key_func=lambda x: self.get_key(x), + ) + + @abstractmethod + def get_key(self, data_sample: ListData): + """ + If 'use_cache' is set to 'True', this method should be implemented. + """ + pass + + @abstractmethod + def entail(self, data_sample: ListData, y: Any): + """Placeholder for entail.""" + pass + + def abduce_candidates( + self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + ): + return self.cache.get(data_sample, max_revision_num, require_more_revision) + + def revise_at_idx( + self, + data_sample: ListData, + revision_idx: Union[List, Tuple, numpy.ndarray], + ): + candidates = [] + abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) + for c in abduce_c: + new_data_sample = data_sample.clone() + candidate = new_data_sample["pred_pseudo_label"][0].copy() + for i, idx in enumerate(revision_idx): + candidate[idx] = c[i] + new_data_sample["pred_pseudo_label"][0] = candidate + if self.entail(new_data_sample, new_data_sample["Y"][0]): + candidates.append(candidate) + return candidates + + def _abduce_by_search( + self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + ): + candidates = [] + gen = self.search_strategy( + data_sample, + max_revision_num=max_revision_num, + require_more_revision=require_more_revision, + ) + send_signal = True + for revision_idx in gen: + candidates.extend(self.revise_at_idx(data_sample, revision_idx)) + if len(candidates) > 0 and send_signal: + try: + revision_idx = gen.send("success") + candidates.extend(self.revise_at_idx(data_sample, revision_idx)) + send_signal = False + except StopIteration: + break + + return candidates