diff --git a/tests/conftest.py b/tests/conftest.py index fd427ab..088ce50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.optim as optim from abl.learning import BasicNN -from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase +from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner from abl.structures import ListData from examples.models.nn import LeNet5 @@ -138,4 +138,4 @@ def kb_hed(): @pytest.fixture def reasoner_instance(kb_add): - return ReasonerBase(kb_add, "confidence") \ No newline at end of file + return Reasoner(kb_add, "confidence") \ No newline at end of file diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index 0eed8cb..4c2820b 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -1,5 +1,5 @@ import pytest -from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase +from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner class TestKBBase(object): def test_init(self, kb_add): @@ -18,7 +18,8 @@ class TestKBBase(object): assert result == [[0, 2]] def test_abduce_candidates(self, kb_add): - result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2) + result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, + require_more_revision=0) assert result == [[1, 0]] class TestGroundKB(object): @@ -32,7 +33,8 @@ class TestGroundKB(object): assert result == 3 def test_abduce_candidates_ground(self, kb_add_ground): - result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2) + result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2, + require_more_revision=0) assert result == [(1, 0)] class TestPrologKB(object): @@ -58,8 +60,9 @@ class TestPrologKB(object): def test_logic_forward_pl2(self, kb_hed): consist_exs = [ - [1, "+", 1, "=", 0], - [1, "+", 1, "=", 1], + [1, 1, "+", 0, "=", 1, 1], + [1, "+", 1, "=", 1, 0], + [0, "+", 0, "=", 0], ] inconsist_exs = [ [1, 1, "+", 0, "=", 1, 1], @@ -80,67 +83,67 @@ class TestReaonser(object): def test_invalid_dist_funce(kb_add): with pytest.raises(NotImplementedError) as excinfo: - ReasonerBase(kb_add, "invalid_dist_func") + Reasoner(kb_add, "invalid_dist_func") assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value) class test_batch_abduce(object): def test_batch_abduce_add(self, kb_add, data_samples_add): - reasoner = ReasonerBase(kb_add, "confidence") - res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) - assert res == [[1, 7], [7, 1], [], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) - assert res == [[1, 7], [7, 1], [], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) - assert res == [[1, 7], [7, 1], [8, 9], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) - assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] + reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0) + reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1) + reasoner3 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=0) + reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1) + assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] + assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] + assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]] + assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): - reasoner = ReasonerBase(kb_add_ground, "confidence") - res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) - assert res == [(1, 7), (7, 1), [], (1, 9)] - res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) - assert res == [(1, 7), (7, 1), [], (1, 9)] - res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) - assert res == [(1, 7), (7, 1), (8, 9), (1, 9)] - res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) - assert res == [(1, 7), (7, 1), (8, 9), (7, 3)] + reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0) + reasoner2 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=1) + reasoner3 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=0) + reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1) + assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] + assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] + assert reasoner3.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (1, 9)] + assert reasoner4.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (7, 3)] def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): - reasoner = ReasonerBase(kb_add_prolog, "confidence") - res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) - assert res == [[1, 7], [7, 1], [], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) - assert res == [[1, 7], [7, 1], [], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) - assert res == [[1, 7], [7, 1], [8, 9], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) - assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] + reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) + reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1) + reasoner3 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=0) + reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1) + assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] + assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] + assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]] + assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): - reasoner = ReasonerBase(kb_add_prolog, "confidence", use_zoopt=True) - res = reasoner.batch_abduce(data_samples_add, max_revision=1) - assert res == [[1, 7], [7, 1], [], [1, 9]] - res = reasoner.batch_abduce(data_samples_add, max_revision=2) - assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] + reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) + reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2) + assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] + assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): - reasoner = ReasonerBase(kb_hwf1, "hamming") - res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) + reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0) + reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0) + reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0) + res = reasoner1.batch_abduce(data_samples_hwf) assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] - res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) + res = reasoner2.batch_abduce(data_samples_hwf) assert res == [['1', '+', '2'], [], [], []] - res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) + res = reasoner3.batch_abduce(data_samples_hwf) assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): - reasoner = ReasonerBase(kb_hwf2, "hamming") - res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) + reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0) + reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0) + reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0) + res = reasoner1.batch_abduce(data_samples_hwf) assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] - res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) + res = reasoner2.batch_abduce(data_samples_hwf) assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']] - res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) + res = reasoner3.batch_abduce(data_samples_hwf) assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] \ No newline at end of file