From f1b964df5889a6bcedc5b0bf837558502b7fecec Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Mon, 11 Dec 2023 15:58:40 +0800 Subject: [PATCH] [MNT] resolve several comments --- abl/reasoning/kb.py | 7 +- abl/reasoning/reasoner.py | 133 ++++++++++++++++++++++++-------------- tests/conftest.py | 10 +++ tests/test_reasoning.py | 57 +++++++++++++++- 4 files changed, 151 insertions(+), 56 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 5c5fe65..f1754c0 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -243,10 +243,7 @@ class KBBase(ABC): """ candidates = [] for revision_num in range(len(pseudo_label) + 1): - if revision_num == 0 and self._check_equal(self.logic_forward(pseudo_label, *(x,) if self._num_args == 2 else ()), y): - candidates.append(pseudo_label) - elif revision_num > 0: - candidates.extend(self._revision(revision_num, pseudo_label, y, x)) + candidates.extend(self._revision(revision_num, pseudo_label, y, x)) if len(candidates) > 0: min_revision_num = revision_num break @@ -559,7 +556,7 @@ class PrologKB(KBBase): knowledge base. """ candidates = [] - query_string = self.get_query_string(pseudo_label, y, revision_idx) + query_string = self.get_query_string(pseudo_label, y, x, revision_idx) save_pseudo_label = pseudo_label pseudo_label = flatten(pseudo_label) abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index bcc721d..1ec9458 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,8 +1,10 @@ +import inspect +from typing import Callable, Any, List, Optional + import numpy as np from zoopt import Dimension, Objective, Opt, Parameter -from typing import Callable, Any, List, Optional -from kb import KBBase +from ..reasoning import KBBase from ..structures import ListData from ..utils.utils import confidence_dist, hamming_dist @@ -16,8 +18,18 @@ class Reasoner: kb : class KBBase The knowledge base to be used for reasoning. dist_func : str or Callable, optional - The distance function to be used when determining the cost list between each - candidate and the given prediction. Defaults to "confidence". + The distance function used to determine the cost list between each + candidate and the given prediction. It can be either a string representing a + predefined distance function or a callable function. The available predefined + distance functions: 'hamming' | 'confidence'. 'hamming': directly calculates + the Hamming distance between the predicted pseudo label in the data sample + and each candidate, 'confidence': calculates the distance between the prediction + and each candidate based on confidence derived from the predicted probability + in the data sample. The callable function should have the signature + dist_func(data_sample, candidates) and must return a cost list. Each element + in this cost list should be a numerical value representing the cost for each + candidate, and the list should have the same length as candidates. + Defaults to 'confidence'. mapping : Optional[dict], optional A mapping from index in the base model to label. If not provided, a default order-based mapping is created. Defaults to None. @@ -43,6 +55,7 @@ class Reasoner: use_zoopt: bool = False, ): self.kb = kb + self._check_valid_dist(dist_func) self.dist_func = dist_func self.use_zoopt = use_zoopt self.max_revision = max_revision @@ -55,18 +68,48 @@ class Reasoner: self.mapping = mapping self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) + def _check_valid_dist(self, dist_func): + if isinstance(dist_func, str): + if dist_func not in ["hamming", "confidence"]: + raise NotImplementedError( + f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.' + ) + return + elif callable(dist_func): + params = inspect.signature(dist_func).parameters.values() + if len(params) != 2: + raise ValueError(f"User-defined dist_func must have exactly two parameters, but got {len(params)}.") + return + else: + raise TypeError( + f"dist_func must be a string or a callable function, but got {type(dist_func)}." + ) + + def _check_valid_dist_output(self, cost_list, candidate_num): + if not isinstance(cost_list, np.ndarray): + raise TypeError(f"Expected dist_func to return a numpy.ndarray, but got {type(cost_list)}.") + if not cost_list.dtype.kind in "biufc": + raise ValueError(f"Expected dist_func to return a numpy.ndarray with a numerical type, but got dtype {cost_list.dtype}.") + if len(cost_list) != candidate_num: + raise ValueError( + f"The length of the array returned by dist_func must be equal to the number of candidates. " + f"Expected length {candidate_num}, but got {len(cost_list)}." + ) + def _check_valid_mapping(self, mapping): if not isinstance(mapping, dict): - raise TypeError(f"mapping should be dict, got {type(mapping)}") + raise TypeError(f"mapping should be dict, but got {type(mapping)}.") for key, value in mapping.items(): if not isinstance(key, int): - raise ValueError(f"All keys in the mapping must be integers, got {key}") + raise ValueError(f"All keys in the mapping must be integers, but got {key}.") if value not in self.kb.pseudo_label_list: - raise ValueError(f"All values in the mapping must be in the pseudo_label_list, got {value}") - + raise ValueError( + f"All values in the mapping must be in the pseudo_label_list, but got {value}." + ) + def _get_one_candidate( - self, - data_sample: ListData, + self, + data_sample: ListData, candidates: List[List[Any]], ) -> List[Any]: """ @@ -91,25 +134,17 @@ class Reasoner: elif len(candidates) == 1: return candidates[0] else: - cost_array = self.get_cost_list(data_sample, candidates) + cost_array = self._get_cost_list(data_sample, candidates) candidate = candidates[np.argmin(cost_array)] return candidate - def get_cost_list( - self, - data_sample: ListData, + def _get_cost_list( + self, + data_sample: ListData, candidates: List[List[Any]], ) -> np.ndarray: """ - Get the list of costs between each candidate and the given data sample. - - The list is - calculated based on one of the following distance functions: - - "hamming": Directly calculates the Hamming distance between the predicted pseudo - label in the data sample and candidate. - - "confidence": Calculates the distance between the prediction and candidate based - on confidence derived from the predicted probability in the data - sample. + Get the list of costs between each candidate and the given data sample. Parameters ---------- @@ -117,7 +152,7 @@ class Reasoner: Data sample. candidates : List[List[Any]] Multiple compatible candidates. - + Returns ------- np.ndarray @@ -129,18 +164,16 @@ class Reasoner: elif self.dist_func == "confidence": candidates = [[self.remapping[x] for x in c] for c in candidates] return confidence_dist(data_sample.pred_prob, candidates) - - elif callable(self.dist_func): - return self.dist_func(data_sample, candidates) - - else: - raise ValueError("dist_func must be either a string or a callable function") + elif callable(self.dist_func): + cost_list = self.dist_func(data_sample, candidates) + self._check_valid_dist_output(cost_list, len(candidates)) + return cost_list def _zoopt_get_solution( - self, - symbol_num: int, - data_sample: ListData, + self, + symbol_num: int, + data_sample: ListData, max_revision_num: int, ) -> List[bool]: """ @@ -155,7 +188,7 @@ class Reasoner: Data sample. max_revision_num : int Specifies the maximum number of revisions allowed. - + Returns ------- List[bool] @@ -172,15 +205,15 @@ class Reasoner: return solution def zoopt_revision_score( - self, - symbol_num: int, - data_sample: ListData, + self, + symbol_num: int, + data_sample: ListData, sol: List[bool], ) -> int: """ Get the revision score for a solution. A lower score suggests that ZOOpt library has a higher preference for this solution. - + Parameters ---------- symbol_num : int @@ -189,7 +222,7 @@ class Reasoner: Data sample. sol: List[bool] The solution for ZOOpt library. - + Returns ------- int @@ -200,7 +233,7 @@ class Reasoner: data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx ) if len(candidates) > 0: - return np.min(self.get_cost_list(data_sample, candidates)) + return np.min(self._get_cost_list(data_sample, candidates)) else: return symbol_num @@ -217,17 +250,21 @@ class Reasoner: Get the maximum revision number according to input `max_revision`. """ if not isinstance(max_revision, (int, float)): - raise TypeError(f"Parameter must be of type int or float, got {type(max_revision)}") + raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}") if max_revision == -1: return symbol_num elif isinstance(max_revision, float): if not (0 <= max_revision <= 1): - raise ValueError(f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}") + raise ValueError( + f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}" + ) return round(symbol_num * max_revision) else: if max_revision < 0: - raise ValueError(f"If max_revision is an int, it must be non-negative, but got {max_revision}") + raise ValueError( + f"If max_revision is an int, it must be non-negative, but got {max_revision}" + ) return max_revision def abduce(self, data_sample: ListData) -> List[Any]: @@ -256,11 +293,11 @@ class Reasoner: ) else: candidates = self.kb.abduce_candidates( - pseudo_label = data_sample.pred_pseudo_label, - y = data_sample.Y, - x = data_sample.X, - max_revision_num = max_revision_num, - require_more_revision = self.require_more_revision, + data_sample.pred_pseudo_label, + data_sample.Y, + data_sample.X, + max_revision_num, + self.require_more_revision, ) candidate = self._get_one_candidate(data_sample, candidates) diff --git a/tests/conftest.py b/tests/conftest.py index f03c016..2590308 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,6 +82,7 @@ def data_samples_add(): ] data_samples_add = ListData() + data_samples_add.X = None data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] data_samples_add.Y = [8, 8, 17, 10] @@ -91,6 +92,7 @@ def data_samples_add(): @pytest.fixture def data_samples_hwf(): data_samples_hwf = ListData() + data_samples_hwf.X = None data_samples_hwf.pred_pseudo_label = [ ["5", "+", "2"], ["5", "+", "9"], @@ -200,6 +202,14 @@ def kb_add_prolog(): kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl") return kb +@pytest.fixture +def kb_hwf1(): + return HwfKB(max_err=0.1) + +@pytest.fixture +def kb_hwf2(): + return HwfKB(max_err=1) + @pytest.fixture def kb_hed(): diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index 4a3137f..c3c2f04 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -1,4 +1,5 @@ import pytest +import numpy as np from abl.reasoning import PrologKB, Reasoner @@ -93,15 +94,65 @@ class TestReaonser(object): def test_reasoner_init(self, reasoner_instance): assert reasoner_instance.dist_func == "confidence" - def test_invalid_dist_funce(kb_add): +class TestDistFunc(object): + def test_invalid_predefined_dist_func(self, kb_add): with pytest.raises(NotImplementedError) as excinfo: Reasoner(kb_add, "invalid_dist_func") - assert 'Valid options for dist_func include "hamming" and "confidence"' in str( + assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in str( + excinfo.value + ) + + def random_dist(self, data_sample, candidates): + cost_list = np.array([np.random.rand() for _ in candidates]) + return cost_list + + def test_user_defined_dist_func(self, kb_add): + reasoner = Reasoner(kb_add, self.random_dist) + assert reasoner.dist_func == self.random_dist + + def invalid_dist1(self, candidates): + cost_list = np.array([np.random.rand() for _ in candidates]) + return cost_list + + def invalid_dist2(self, data_sample, candidates): + cost_list = np.array([np.random.rand() for _ in candidates]) + return np.append(cost_list, np.random.rand()) + + def invalid_dist3(self, data_sample, candidates): + cost_list = [np.random.rand() for _ in candidates] + return cost_list + + def invalid_dist4(self, data_sample, candidates): + cost_list = np.array(["invalid" for _ in candidates]) + return cost_list + + def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add): + with pytest.raises(ValueError) as excinfo: + Reasoner(kb_add, self.invalid_dist1) + assert 'User-defined dist_func must have exactly two parameters' in str( + excinfo.value + ) + with pytest.raises(ValueError) as excinfo: + reasoner = Reasoner(kb_add, self.invalid_dist2) + reasoner.batch_abduce(data_samples_add) + assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str( + excinfo.value + ) + with pytest.raises(TypeError) as excinfo: + reasoner = Reasoner(kb_add, self.invalid_dist3) + reasoner.batch_abduce(data_samples_add) + assert 'Expected dist_func to return a numpy.ndarray' in str( + excinfo.value + ) + with pytest.raises(ValueError) as excinfo: + reasoner = Reasoner(kb_add, self.invalid_dist4) + reasoner.batch_abduce(data_samples_add) + assert 'Expected dist_func to return a numpy.ndarray with a numerical type' in str( excinfo.value ) -class test_batch_abduce(object): +class TestBatchAbduce(object): def test_batch_abduce_add(self, kb_add, data_samples_add): reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0) reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1)