From 2951e5fe5ad3ec8f4e414d4ff26ba46f9bb624fe Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 12 Nov 2023 20:43:58 +0800 Subject: [PATCH] [ENH] add search engine --- abl/reasoning/__init__.py | 1 + abl/reasoning/reasoner.py | 546 ++---------------- abl/reasoning/search_based_kb.py | 22 +- abl/reasoning/search_engine/__init__.py | 3 + .../search_engine/base_search_engine.py | 13 + abl/reasoning/search_engine/bfs.py | 28 + abl/reasoning/search_engine/zoopt.py | 42 ++ tests/test_reasoning.py | 403 +++++++++++++ 8 files changed, 535 insertions(+), 523 deletions(-) create mode 100644 abl/reasoning/search_engine/__init__.py create mode 100644 abl/reasoning/search_engine/base_search_engine.py create mode 100644 abl/reasoning/search_engine/bfs.py create mode 100644 abl/reasoning/search_engine/zoopt.py create mode 100644 tests/test_reasoning.py diff --git a/abl/reasoning/__init__.py b/abl/reasoning/__init__.py index 64d3ba4..a37c0d3 100644 --- a/abl/reasoning/__init__.py +++ b/abl/reasoning/__init__.py @@ -3,3 +3,4 @@ from .ground_kb import GroundKB from .prolog_based_kb import PrologBasedKB from .reasoner import ReasonerBase from .search_based_kb import SearchBasedKB +from .search_engine import BaseSearchEngine, BFS, Zoopt diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 7bd4944..a550581 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,11 +1,11 @@ -from typing import Any, List, Mapping, Tuple, Union +from typing import Any, List, Mapping import numpy as np -from zoopt import Dimension, Objective, Opt, Parameter, Solution from ..structures import ListData -from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx +from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist from .base_kb import BaseKB +from .search_engine import BaseSearchEngine, BFS class ReasonerBase: @@ -14,7 +14,7 @@ class ReasonerBase: kb: BaseKB, dist_func: str = "confidence", mapping: Mapping = None, - use_zoopt: bool = False, + search_engine: BaseSearchEngine = None, ): """ Base class for all reasoner in the ABL system. @@ -36,12 +36,14 @@ class ReasonerBase: If the specified distance function is neither "hamming" nor "confidence". """ + if not isinstance(kb, BaseKB): + raise ValueError("The kb should be of type BaseKB.") + self.kb = kb + if dist_func not in ["hamming", "confidence"]: raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.") - - self.kb = kb self.dist_func = dist_func - self.use_zoopt = use_zoopt + if mapping is None: self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} else: @@ -56,10 +58,17 @@ class ReasonerBase: raise ValueError("All values in the mapping must be in the pseudo_label_list") self.mapping = mapping - self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) - def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]): + if search_engine is None: + self.search_engine = BFS() + else: + if not isinstance(search_engine, BaseSearchEngine): + raise ValueError("The search_engine should be of type BaseSearchEngine.") + else: + self.search_engine = search_engine + + def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): """ Get the list of costs between each pseudo label and candidate. @@ -84,7 +93,7 @@ class ReasonerBase: candidates = [[self.remapping[x] for x in c] for c in candidates] return confidence_dist(data_sample["pred_prob"][0], candidates) - def _get_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]): + def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]): """ Get one candidate. If multiple candidates exist, return the one with minimum cost. @@ -108,91 +117,10 @@ class ReasonerBase: elif len(candidates) == 1: return candidates[0] else: - cost_array = self._get_cost_list(data_sample, candidates) + cost_array = self._get_dist_list(data_sample, candidates) candidate = candidates[np.argmin(cost_array)] return candidate - def zoopt_revision_score(self, data_sample: ListData, solution: Solution): - """ - Get the revision score for a single solution. - - Parameters - ---------- - pred_pseudo_label : list - List of predicted pseudo labels. - pred_prob : list - List of probabilities for predicted results. - y : any - Ground truth for the predicted results. - solution : array-like - Solution to evaluate. - - Returns - ------- - float - The revision score for the given solution. - """ - revision_idx = np.where(solution.get_x() != 0)[0] - candidates = self.revise_at_idx(data_sample, revision_idx) - if len(candidates) > 0: - return np.min(self._get_cost_list(data_sample, candidates)) - else: - return data_sample["symbol_num"] - - def _constrain_revision_num(self, solution: Solution, max_revision_num: int): - x = solution.get_x() - return max_revision_num - x.sum() - - def zoopt_get_solution(self, data_sample: ListData, max_revision_num: int): - """Get the optimal solution using the Zoopt library. - - Parameters - ---------- - pred_pseudo_label : list - List of predicted pseudo labels. - pred_prob : list - List of probabilities for predicted results. - y : any - Ground truth for the predicted results. - max_revision_num : int - Maximum number of revisions to use. - - Returns - ------- - array-like - The optimal solution, i.e., where to revise predict pseudo label. - """ - symbol_num = data_sample["symbol_num"] - dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) - objective = Objective( - lambda solution: self.zoopt_revision_score(data_sample, solution), - dim=dimension, - constraint=lambda solution: self._constrain_revision_num(solution, max_revision_num), - ) - parameter = Parameter(budget=100, intermediate_result=False, autoset=True) - solution = Opt.min(objective, parameter).get_x() - return solution - - def revise_at_idx(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]): - """ - Revise the pseudo label according to the given indices. - - Parameters - ---------- - pred_pseudo_label : list - List of predicted pseudo labels. - y : any - Ground truth for the predicted results. - revision_idx : array-like - Indices of the revisions to retrieve. - - Returns - ------- - list - The revisions according to the given indices. - """ - return self.kb.revise_at_idx(data_sample, revision_idx) - def abduce( self, data_sample: ListData, @@ -223,19 +151,35 @@ class ReasonerBase: """ symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = calculate_revision_num(max_revision, symbol_num) - data_sample.set_metainfo(dict(symbol_num=symbol_num)) - if self.use_zoopt: - solution = self.zoopt_get_solution(data_sample, max_revision_num) - revision_idx = np.where(solution != 0)[0] - candidates = self.revise_at_idx(data_sample, revision_idx) - else: + if hasattr(self.kb, "abduce_candidates"): candidates = self.kb.abduce_candidates( data_sample, max_revision_num, require_more_revision ) + elif hasattr(self.kb, "revise_at_idx"): + candidates = [] + gen = self.search_engine.generator( + data_sample, + max_revision_num=max_revision_num, + require_more_revision=require_more_revision, + ) + send_signal = True + for revision_idx in gen: + candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx)) + if len(candidates) > 0 and send_signal: + try: + revision_idx = gen.send("success") + candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx)) + send_signal = False + except StopIteration: + break + else: + raise NotImplementedError( + "The kb should either implement abduce_candidates or revise_at_idx." + ) - candidate = self._get_one_candidate(data_sample, candidates) + candidate = self.select_one_candidate(data_sample, candidates) return candidate def batch_abduce( @@ -285,407 +229,3 @@ class ReasonerBase: # with Pool(processes=os.cpu_count()) as pool: # results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) # return results - - -if __name__ == "__main__": - from abl.reasoning.base_kb import BaseKB, GroundKB, PrologBasedKB - - prob1 = [ - [ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ] - ] - - prob2 = [ - [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ] - ] - - class add_KB(BaseKB): - def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): - super().__init__(pseudo_label_list, use_cache=use_cache) - - def logic_forward(self, nums): - return sum(nums) - - class add_GroundKB(GroundKB): - def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): - super().__init__(pseudo_label_list, GKB_len_list) - - def logic_forward(self, nums): - return sum(nums) - - def test_add(reasoner): - res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) - print(res) - print() - - print("add_KB with GKB:") - kb = add_GroundKB() - reasoner = ReasonerBase(kb, "confidence") - test_add(reasoner) - - print("add_KB without GKB:") - kb = add_KB() - reasoner = ReasonerBase(kb, "confidence") - test_add(reasoner) - - print("add_KB without GKB, no cache") - kb = add_KB(use_cache=False) - reasoner = ReasonerBase(kb, "confidence") - test_add(reasoner) - - print("PrologBasedKB with add.pl:") - kb = PrologBasedKB( - pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" - ) - reasoner = ReasonerBase(kb, "confidence") - test_add(reasoner) - - print("PrologBasedKB with add.pl using zoopt:") - kb = PrologBasedKB( - pseudo_label_list=list(range(10)), - pl_file="examples/mnist_add/datasets/add.pl", - ) - reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) - test_add(reasoner) - - print("add_KB with multiple inputs at once:") - multiple_prob = [ - [ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ], - [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ], - ] - - kb = add_KB() - reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=1, - ) - print(res) - print() - - class HWF_KB(BaseKB): - def __init__( - self, - pseudo_label_list=[ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "+", - "-", - "times", - "div", - ], - max_err=1e-3, - ): - super().__init__(pseudo_label_list, max_err) - - def _valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in [ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - ]: - return False - if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: - return False - return True - - def logic_forward(self, formula): - if not self._valid_candidate(formula): - return np.inf - mapping = {str(i): str(i) for i in range(1, 10)} - mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) - formula = [mapping[f] for f in formula] - return eval("".join(formula)) - - class HWF_GroundKB(GroundKB): - def __init__( - self, - pseudo_label_list=[ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "+", - "-", - "times", - "div", - ], - GKB_len_list=[1, 3, 5, 7], - max_err=1e-3, - ): - super().__init__(pseudo_label_list, GKB_len_list, max_err) - - def _valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in [ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - ]: - return False - if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: - return False - return True - - def logic_forward(self, formula): - if not self._valid_candidate(formula): - return np.inf - mapping = {str(i): str(i) for i in range(1, 10)} - mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) - formula = [mapping[f] for f in formula] - return eval("".join(formula)) - - def test_hwf(reasoner): - res = reasoner.batch_abduce( - [None], - [["5", "+", "2"]], - [3], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None], - [["5", "+", "9"]], - [65], - max_revision=3, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None], - [["5", "8", "8", "8", "8"]], - [3.17], - max_revision=5, - require_more_revision=3, - ) - print(res) - print() - - def test_hwf_multiple(reasoner, max_revisions): - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[0], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[1], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 65], - max_revision=max_revisions[2], - require_more_revision=0, - ) - print(res) - print() - - print("HWF_KB with GKB, max_err=0.1") - kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf(reasoner) - - print("HWF_KB without GKB, max_err=0.1") - kb = HWF_KB(max_err=0.1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf(reasoner) - - print("HWF_KB with GKB, max_err=1") - kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf(reasoner) - - print("HWF_KB without GKB, max_err=1") - kb = HWF_KB(max_err=1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf(reasoner) - - print("HWF_KB with multiple inputs at once:") - kb = HWF_KB(max_err=0.1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) - - print("max_revision is float") - test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) - - class HED_prolog_KB(PrologBasedKB): - def __init__(self, pseudo_label_list, pl_file): - super().__init__(pseudo_label_list, pl_file) - - def consist_rule(self, exs, rules): - rules = str(rules).replace("'", "") - pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) - return len(list(self.prolog.query(pl_query))) != 0 - - def abduce_rules(self, pred_res): - pl_query = "consistent_inst_feature(%s, X)." % pred_res - prolog_result = list(self.prolog.query(pl_query)) - if len(prolog_result) == 0: - return None - prolog_rules = prolog_result[0]["X"] - rules = [rule.value for rule in prolog_rules] - return rules - - class HED_Reasoner(ReasonerBase): - def __init__(self, kb, dist_func="hamming"): - super().__init__(kb, dist_func, use_zoopt=True) - - def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs): - pred = [] - k = [] - revision_flag = [] - for idx in idxs: - pred.append(pred_res[idx]) - k.append(y[idx]) - revision_flag += list(all_revision_flag[idx]) - revision_idx = np.where(np.array(revision_flag) != 0)[0] - candidate = self.revise_at_idx(pred, k, revision_idx) - return candidate - - def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): - all_revision_flag = reform_idx(sol.get_x(), pred_res) - lefted_idxs = [i for i in range(len(pred_res))] - candidate_size = [] - while lefted_idxs: - idxs = [] - idxs.append(lefted_idxs.pop(0)) - max_candidate_idxs = [] - found = False - for idx in range(-1, len(pred_res)): - if (not idx in idxs) and (idx >= 0): - idxs.append(idx) - candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs) - if len(candidate) == 0: - if len(idxs) > 1: - idxs.pop() - else: - if len(idxs) > len(max_candidate_idxs): - found = True - max_candidate_idxs = idxs.copy() - removed = [i for i in lefted_idxs if i in max_candidate_idxs] - if found: - candidate_size.append(len(removed) + 1) - lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] - candidate_size.sort() - score = 0 - import math - - for i in range(0, len(candidate_size)): - score -= math.exp(-i) * candidate_size[i] - return score - - def abduce_rules(self, pred_res): - return self.kb.abduce_rules(pred_res) - - kb = HED_prolog_KB( - pseudo_label_list=[1, 0, "+", "="], - pl_file="examples/hed/datasets/learn_add.pl", - ) - reasoner = HED_Reasoner(kb) - consist_exs = [ - [1, 1, "+", 0, "=", 1, 1], - [1, "+", 1, "=", 1, 0], - [0, "+", 0, "=", 0], - ] - inconsist_exs1 = [ - [1, 1, "+", 0, "=", 1, 1], - [1, "+", 1, "=", 1, 0], - [0, "+", 0, "=", 0], - [0, "+", 0, "=", 1], - ] - inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] - rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] - - print("HED_kb logic forward") - print(kb.logic_forward(consist_exs)) - print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) - print() - print("HED_kb consist rule") - print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) - print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) - print() - - print("HED_Reasoner abduce") - res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) - ) - print(res) - print() - - print("HED_Reasoner abduce rules") - abduced_rules = reasoner.abduce_rules(consist_exs) - print(abduced_rules) diff --git a/abl/reasoning/search_based_kb.py b/abl/reasoning/search_based_kb.py index cd5db40..3f527eb 100644 --- a/abl/reasoning/search_based_kb.py +++ b/abl/reasoning/search_based_kb.py @@ -4,28 +4,10 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union import numpy -from abl.structures import ListData - from ..structures import ListData from ..utils import Cache from .base_kb import BaseKB - - -def incremental_search_strategy( - data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 -): - symbol_num = data_sample["symbol_num"] - max_revision_num = min(max_revision_num, symbol_num) - real_end = max_revision_num - for revision_num in range(max_revision_num + 1): - if revision_num > real_end: - break - - revision_idx_tuple = combinations(range(symbol_num), revision_num) - for revision_idx in revision_idx_tuple: - received = yield revision_idx - if received == "success": - real_end = min(symbol_num, revision_num + require_more_revision) +from .search_engine import incremental_search_strategy class SearchBasedKB(BaseKB, ABC): @@ -35,7 +17,7 @@ class SearchBasedKB(BaseKB, ABC): search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, use_cache: bool = True, cache_file: Optional[str] = None, - cache_size: int = 4096 + cache_size: int = 4096, ) -> None: super().__init__(pseudo_label_list) self.search_strategy = search_strategy diff --git a/abl/reasoning/search_engine/__init__.py b/abl/reasoning/search_engine/__init__.py new file mode 100644 index 0000000..45f5442 --- /dev/null +++ b/abl/reasoning/search_engine/__init__.py @@ -0,0 +1,3 @@ +from .base_search_engine import BaseSearchEngine +from .bfs import BFS +from .zoopt import Zoopt diff --git a/abl/reasoning/search_engine/base_search_engine.py b/abl/reasoning/search_engine/base_search_engine.py new file mode 100644 index 0000000..09a6dff --- /dev/null +++ b/abl/reasoning/search_engine/base_search_engine.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple, Union + +import numpy + +from ...structures import ListData + + +class BaseSearchEngine(ABC): + @abstractmethod + def generator(data_sample: ListData) -> Union[List, Tuple, numpy.ndarray]: + """Placeholder for the generator of revision_idx.""" + pass diff --git a/abl/reasoning/search_engine/bfs.py b/abl/reasoning/search_engine/bfs.py new file mode 100644 index 0000000..104470a --- /dev/null +++ b/abl/reasoning/search_engine/bfs.py @@ -0,0 +1,28 @@ +from itertools import combinations +from typing import List, Tuple, Union + +import numpy + +from ...structures import ListData +from .base_search_engine import BaseSearchEngine + + +class BFS(BaseSearchEngine): + def __init__(self) -> None: + pass + + def generator( + data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + ) -> Union[List, Tuple, numpy.ndarray]: + symbol_num = data_sample["symbol_num"] + max_revision_num = min(max_revision_num, symbol_num) + real_end = max_revision_num + for revision_num in range(max_revision_num + 1): + if revision_num > real_end: + break + + revision_idx_tuple = combinations(range(symbol_num), revision_num) + for revision_idx in revision_idx_tuple: + received = yield revision_idx + if received == "success": + real_end = min(symbol_num, revision_num + require_more_revision) diff --git a/abl/reasoning/search_engine/zoopt.py b/abl/reasoning/search_engine/zoopt.py new file mode 100644 index 0000000..d653f46 --- /dev/null +++ b/abl/reasoning/search_engine/zoopt.py @@ -0,0 +1,42 @@ +from typing import List, Tuple, Union + +import numpy as np +from zoopt import Dimension, Objective, Opt, Parameter, Solution + +from ...structures import ListData +from ..reasoner import ReasonerBase +from ..search_based_kb import SearchBasedKB +from .base_search_engine import BaseSearchEngine + + +class Zoopt(BaseSearchEngine): + def __init__(self, reasoner: ReasonerBase, kb: SearchBasedKB) -> None: + self.reasoner = reasoner + self.kb = kb + + def score_func(self, data_sample: ListData, solution: Solution): + revision_idx = np.where(solution.get_x() != 0)[0] + candidates = self.kb.revise_at_idx(data_sample, revision_idx) + if len(candidates) > 0: + return np.min(self.reasoner._get_dist_list(data_sample, candidates)) + else: + return data_sample["symbol_num"] + + @staticmethod + def constraint(solution: Solution, max_revision_num: int): + x = solution.get_x() + return max_revision_num - x.sum() + + def generator( + self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + ) -> Union[List, Tuple, np.ndarray]: + symbol_num = data_sample["symbol_num"] + dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) + objective = Objective( + lambda solution: self.score_func(self, data_sample, solution), + dim=dimension, + constraint=lambda solution: self.constraint(solution, max_revision_num), + ) + parameter = Parameter(budget=100, intermediate_result=False, autoset=True) + solution = Opt.min(objective, parameter).get_x() + yield solution diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py new file mode 100644 index 0000000..8bb7955 --- /dev/null +++ b/tests/test_reasoning.py @@ -0,0 +1,403 @@ + +from abl.reasoning import ReasonerBase, BaseKB, GroundKB, PrologBasedKB + +if __name__ == "__main__": + prob1 = [ + [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + ] + + prob2 = [ + [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + ] + + class add_KB(BaseKB): + def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): + super().__init__(pseudo_label_list, use_cache=use_cache) + + def logic_forward(self, nums): + return sum(nums) + + class add_GroundKB(GroundKB): + def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): + super().__init__(pseudo_label_list, GKB_len_list) + + def logic_forward(self, nums): + return sum(nums) + + def test_add(reasoner): + res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + print(res) + print() + + print("add_KB with GKB:") + kb = add_GroundKB() + reasoner = ReasonerBase(kb, "confidence") + test_add(reasoner) + + print("add_KB without GKB:") + kb = add_KB() + reasoner = ReasonerBase(kb, "confidence") + test_add(reasoner) + + print("add_KB without GKB, no cache") + kb = add_KB(use_cache=False) + reasoner = ReasonerBase(kb, "confidence") + test_add(reasoner) + + print("PrologBasedKB with add.pl:") + kb = PrologBasedKB( + pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" + ) + reasoner = ReasonerBase(kb, "confidence") + test_add(reasoner) + + print("PrologBasedKB with add.pl using zoopt:") + kb = PrologBasedKB( + pseudo_label_list=list(range(10)), + pl_file="examples/mnist_add/datasets/add.pl", + ) + reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) + test_add(reasoner) + + print("add_KB with multiple inputs at once:") + multiple_prob = [ + [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ], + [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ], + ] + + kb = add_KB() + reasoner = ReasonerBase(kb, "confidence") + res = reasoner.batch_abduce( + multiple_prob, + [[1, 1], [1, 2]], + [4, 8], + max_revision=2, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + multiple_prob, + [[1, 1], [1, 2]], + [4, 8], + max_revision=2, + require_more_revision=1, + ) + print(res) + print() + + class HWF_KB(BaseKB): + def __init__( + self, + pseudo_label_list=[ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "+", + "-", + "times", + "div", + ], + max_err=1e-3, + ): + super().__init__(pseudo_label_list, max_err) + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ]: + return False + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: + return False + return True + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return np.inf + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) + formula = [mapping[f] for f in formula] + return eval("".join(formula)) + + class HWF_GroundKB(GroundKB): + def __init__( + self, + pseudo_label_list=[ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "+", + "-", + "times", + "div", + ], + GKB_len_list=[1, 3, 5, 7], + max_err=1e-3, + ): + super().__init__(pseudo_label_list, GKB_len_list, max_err) + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ]: + return False + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: + return False + return True + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return np.inf + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) + formula = [mapping[f] for f in formula] + return eval("".join(formula)) + + def test_hwf(reasoner): + res = reasoner.batch_abduce( + [None], + [["5", "+", "2"]], + [3], + max_revision=2, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None], + [["5", "+", "9"]], + [65], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None], + [["5", "8", "8", "8", "8"]], + [3.17], + max_revision=5, + require_more_revision=3, + ) + print(res) + print() + + def test_hwf_multiple(reasoner, max_revisions): + res = reasoner.batch_abduce( + [None, None], + [["5", "+", "2"], ["5", "+", "9"]], + [3, 64], + max_revision=max_revisions[0], + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None, None], + [["5", "+", "2"], ["5", "+", "9"]], + [3, 64], + max_revision=max_revisions[1], + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None, None], + [["5", "+", "2"], ["5", "+", "9"]], + [3, 65], + max_revision=max_revisions[2], + require_more_revision=0, + ) + print(res) + print() + + print("HWF_KB with GKB, max_err=0.1") + kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) + reasoner = ReasonerBase(kb, "hamming") + test_hwf(reasoner) + + print("HWF_KB without GKB, max_err=0.1") + kb = HWF_KB(max_err=0.1) + reasoner = ReasonerBase(kb, "hamming") + test_hwf(reasoner) + + print("HWF_KB with GKB, max_err=1") + kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1) + reasoner = ReasonerBase(kb, "hamming") + test_hwf(reasoner) + + print("HWF_KB without GKB, max_err=1") + kb = HWF_KB(max_err=1) + reasoner = ReasonerBase(kb, "hamming") + test_hwf(reasoner) + + print("HWF_KB with multiple inputs at once:") + kb = HWF_KB(max_err=0.1) + reasoner = ReasonerBase(kb, "hamming") + test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) + + print("max_revision is float") + test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) + + class HED_prolog_KB(PrologBasedKB): + def __init__(self, pseudo_label_list, pl_file): + super().__init__(pseudo_label_list, pl_file) + + def consist_rule(self, exs, rules): + rules = str(rules).replace("'", "") + pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) + return len(list(self.prolog.query(pl_query))) != 0 + + def abduce_rules(self, pred_res): + pl_query = "consistent_inst_feature(%s, X)." % pred_res + prolog_result = list(self.prolog.query(pl_query)) + if len(prolog_result) == 0: + return None + prolog_rules = prolog_result[0]["X"] + rules = [rule.value for rule in prolog_rules] + return rules + + class HED_Reasoner(ReasonerBase): + def __init__(self, kb, dist_func="hamming"): + super().__init__(kb, dist_func, use_zoopt=True) + + def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs): + pred = [] + k = [] + revision_flag = [] + for idx in idxs: + pred.append(pred_res[idx]) + k.append(y[idx]) + revision_flag += list(all_revision_flag[idx]) + revision_idx = np.where(np.array(revision_flag) != 0)[0] + candidate = self.revise_at_idx(pred, k, revision_idx) + return candidate + + def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): + all_revision_flag = reform_idx(sol.get_x(), pred_res) + lefted_idxs = [i for i in range(len(pred_res))] + candidate_size = [] + while lefted_idxs: + idxs = [] + idxs.append(lefted_idxs.pop(0)) + max_candidate_idxs = [] + found = False + for idx in range(-1, len(pred_res)): + if (not idx in idxs) and (idx >= 0): + idxs.append(idx) + candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs) + if len(candidate) == 0: + if len(idxs) > 1: + idxs.pop() + else: + if len(idxs) > len(max_candidate_idxs): + found = True + max_candidate_idxs = idxs.copy() + removed = [i for i in lefted_idxs if i in max_candidate_idxs] + if found: + candidate_size.append(len(removed) + 1) + lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] + candidate_size.sort() + score = 0 + import math + + for i in range(0, len(candidate_size)): + score -= math.exp(-i) * candidate_size[i] + return score + + def abduce_rules(self, pred_res): + return self.kb.abduce_rules(pred_res) + + kb = HED_prolog_KB( + pseudo_label_list=[1, 0, "+", "="], + pl_file="examples/hed/datasets/learn_add.pl", + ) + reasoner = HED_Reasoner(kb) + consist_exs = [ + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], + ] + inconsist_exs1 = [ + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], + [0, "+", 0, "=", 1], + ] + inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] + rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] + + print("HED_kb logic forward") + print(kb.logic_forward(consist_exs)) + print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) + print() + print("HED_kb consist rule") + print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) + print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) + print() + + print("HED_Reasoner abduce") + res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)) + print(res) + res = reasoner.abduce( + [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) + ) + print(res) + res = reasoner.abduce( + [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) + ) + print(res) + print() + + print("HED_Reasoner abduce rules") + abduced_rules = reasoner.abduce_rules(consist_exs) + print(abduced_rules)