From bbc9ef46bbf37f72f0aae5784c99ee8cb15fd2c2 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Sat, 25 Nov 2023 22:56:00 +0800 Subject: [PATCH] [ENH] add reasoning test --- tests/conftest.py | 109 ++++++++++++++++++++++++++++- tests/test_reasoning.py | 147 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 tests/test_reasoning.py diff --git a/tests/conftest.py b/tests/conftest.py index fac5523..fd427ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,10 @@ import torch.nn as nn import torch.optim as optim from abl.learning import BasicNN +from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase from abl.structures import ListData from examples.models.nn import LeNet5 - # Fixture for BasicNN instance @pytest.fixture def basic_nn_instance(): @@ -32,3 +32,110 @@ def list_data_instance(): data_samples.Y = [1, 2, 3] data_samples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] return data_samples + + +@pytest.fixture +def data_samples_add(): + # favor 1 in first one + prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + # favor 7 in first one + prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + data_samples_add = ListData() + data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] + data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] + data_samples_add.Y = [8, 8, 17, 10] + return data_samples_add + +@pytest.fixture +def data_samples_hwf(): + data_samples_hwf = ListData() + data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] + data_samples_hwf.pred_prob = [None, None, None, None] + data_samples_hwf.Y = [3, 64, 65, 3.17] + return data_samples_hwf + +class AddKB(KBBase): + def __init__(self, pseudo_label_list=list(range(10)), + use_cache=False): + super().__init__(pseudo_label_list, use_cache=use_cache) + + def logic_forward(self, nums): + return sum(nums) + +class AddGroundKB(GroundKB): + def __init__(self, pseudo_label_list=list(range(10)), + GKB_len_list=[2]): + super().__init__(pseudo_label_list, GKB_len_list) + + def logic_forward(self, nums): + return sum(nums) + +class HwfKB(KBBase): + def __init__( + self, + pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", + "+", "-", "times", "div"], + max_err=1e-3, + use_cache=False, + ): + super().__init__(pseudo_label_list, max_err, use_cache) + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: + return False + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: + return False + return True + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return None + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) + formula = [mapping[f] for f in formula] + return eval("".join(formula)) + +class HedKB(PrologKB): + def __init__(self, pseudo_label_list, pl_file): + super().__init__(pseudo_label_list, pl_file) + + def consist_rule(self, exs, rules): + rules = str(rules).replace("'", "") + pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) + return len(list(self.prolog.query(pl_query))) != 0 + +@pytest.fixture +def kb_add(): + return AddKB() + +@pytest.fixture +def kb_add_cache(): + return AddKB(use_cache=True) + +@pytest.fixture +def kb_add_ground(): + return AddGroundKB() + +@pytest.fixture +def kb_add_prolog(): + kb = PrologKB(pseudo_label_list=list(range(10)), + pl_file="examples/mnist_add/datasets/add.pl") + return kb + +@pytest.fixture +def kb_hed(): + kb = HedKB( + pseudo_label_list=[1, 0, "+", "="], + pl_file="examples/hed/datasets/learn_add.pl", + ) + return kb + +@pytest.fixture +def reasoner_instance(kb_add): + return ReasonerBase(kb_add, "confidence") \ No newline at end of file diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py new file mode 100644 index 0000000..d6b0a78 --- /dev/null +++ b/tests/test_reasoning.py @@ -0,0 +1,147 @@ +import pytest +from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase + +class TestKBBase(object): + def test_init(self, kb_add): + assert kb_add.pseudo_label_list == list(range(10)) + + def test_init_cache(self, kb_add_cache): + assert kb_add_cache.pseudo_label_list == list(range(10)) + assert kb_add_cache.use_cache == True + + def test_logic_forward(self, kb_add): + result = kb_add.logic_forward([1, 2]) + assert result == 3 + + def test_revise_at_idx(self, kb_add): + result = kb_add.revise_at_idx([1, 2], 2, [0]) + assert result == [[0, 2]] + + def test_abduce_candidates(self, kb_add): + result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2) + assert result == [[1, 0]] + +class TestGroundKB(object): + def test_init(self, kb_add_ground): + assert kb_add_ground.pseudo_label_list == list(range(10)) + assert kb_add_ground.GKB_len_list == [2] + assert kb_add_ground.GKB + + def test_logic_forward_ground(self, kb_add_ground): + result = kb_add_ground.logic_forward([1, 2]) + assert result == 3 + + def test_abduce_candidates_ground(self, kb_add_ground): + result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2) + assert result == [(1, 0)] + +class TestPrologKB(object): + def test_init_pl1(self, kb_add_prolog): + assert kb_add_prolog.pseudo_label_list == list(range(10)) + assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl" + + def test_init_pl2(self, kb_hed): + assert kb_hed.pseudo_label_list == [1, 0, "+", "="] + assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl" + + def test_prolog_file_not_exist(self): + pseudo_label_list = [1, 2] + non_existing_file = "path/to/non_existing_file.pl" + with pytest.raises(FileNotFoundError) as excinfo: + PrologKB(pseudo_label_list=pseudo_label_list, + pl_file=non_existing_file) + assert non_existing_file in str(excinfo.value) + + def test_logic_forward_pl1(self, kb_add_prolog): + result = kb_add_prolog.logic_forward([1, 2]) + assert result == 3 + + def test_logic_forward_pl2(self, kb_hed): + consist_exs = [ + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], + ] + inconsist_exs = [ + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], + [0, "+", 0, "=", 1], + ] + assert kb_hed.logic_forward(consist_exs) == True + assert kb_hed.logic_forward(inconsist_exs) == False + + def test_revise_at_idx(self, kb_add_prolog): + result = kb_add_prolog.revise_at_idx([1, 2], 2, [0]) + assert result == [[0, 2]] + +class TestReaonser(object): + def test_reasoner_init(self, reasoner_instance): + assert reasoner_instance.dist_func == "confidence" + + def test_invalid_dist_funce(kb_add): + with pytest.raises(NotImplementedError) as excinfo: + ReasonerBase(kb_add, "invalid_dist_func") + assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value) + + +class test_batch_abduce(object): + def test_batch_abduce_add(self, kb_add, data_samples_add): + reasoner = ReasonerBase(kb_add, "confidence") + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) + assert res == [[1, 7], [7, 1], [], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) + assert res == [[1, 7], [7, 1], [], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) + assert res == [[1, 7], [7, 1], [8, 9], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) + assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] + + def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): + reasoner = ReasonerBase(kb_add_ground, "confidence") + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) + assert res == [(1, 7), (7, 1), [], (1, 9)] + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) + assert res == [(1, 7), (7, 1), [], (1, 9)] + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) + assert res == [(1, 7), (7, 1), (8, 9), (1, 9)] + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) + assert res == [(1, 7), (7, 1), (8, 9), (7, 3)] + + def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): + reasoner = ReasonerBase(kb_add_prolog, "confidence") + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) + assert res == [[1, 7], [7, 1], [], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) + assert res == [[1, 7], [7, 1], [], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) + assert res == [[1, 7], [7, 1], [8, 9], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) + assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] + + def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): + reasoner = ReasonerBase(kb_add_prolog, "confidence", use_zoopt=True) + res = reasoner.batch_abduce(data_samples_add, max_revision=1) + assert res == [[1, 7], [7, 1], [], [1, 9]] + res = reasoner.batch_abduce(data_samples_add, max_revision=2) + assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] + + def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): + reasoner = ReasonerBase(kb_hwf1, "hamming") + res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) + assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) + assert res == [['1', '+', '2'], [], [], []] + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) + assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] + + def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): + reasoner = ReasonerBase(kb_hwf2, "hamming") + res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) + assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) + assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']] + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) + assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] + + \ No newline at end of file