|
|
|
@@ -0,0 +1,147 @@ |
|
|
|
import pytest |
|
|
|
from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase |
|
|
|
|
|
|
|
class TestKBBase(object): |
|
|
|
def test_init(self, kb_add): |
|
|
|
assert kb_add.pseudo_label_list == list(range(10)) |
|
|
|
|
|
|
|
def test_init_cache(self, kb_add_cache): |
|
|
|
assert kb_add_cache.pseudo_label_list == list(range(10)) |
|
|
|
assert kb_add_cache.use_cache == True |
|
|
|
|
|
|
|
def test_logic_forward(self, kb_add): |
|
|
|
result = kb_add.logic_forward([1, 2]) |
|
|
|
assert result == 3 |
|
|
|
|
|
|
|
def test_revise_at_idx(self, kb_add): |
|
|
|
result = kb_add.revise_at_idx([1, 2], 2, [0]) |
|
|
|
assert result == [[0, 2]] |
|
|
|
|
|
|
|
def test_abduce_candidates(self, kb_add): |
|
|
|
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2) |
|
|
|
assert result == [[1, 0]] |
|
|
|
|
|
|
|
class TestGroundKB(object): |
|
|
|
def test_init(self, kb_add_ground): |
|
|
|
assert kb_add_ground.pseudo_label_list == list(range(10)) |
|
|
|
assert kb_add_ground.GKB_len_list == [2] |
|
|
|
assert kb_add_ground.GKB |
|
|
|
|
|
|
|
def test_logic_forward_ground(self, kb_add_ground): |
|
|
|
result = kb_add_ground.logic_forward([1, 2]) |
|
|
|
assert result == 3 |
|
|
|
|
|
|
|
def test_abduce_candidates_ground(self, kb_add_ground): |
|
|
|
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2) |
|
|
|
assert result == [(1, 0)] |
|
|
|
|
|
|
|
class TestPrologKB(object): |
|
|
|
def test_init_pl1(self, kb_add_prolog): |
|
|
|
assert kb_add_prolog.pseudo_label_list == list(range(10)) |
|
|
|
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl" |
|
|
|
|
|
|
|
def test_init_pl2(self, kb_hed): |
|
|
|
assert kb_hed.pseudo_label_list == [1, 0, "+", "="] |
|
|
|
assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl" |
|
|
|
|
|
|
|
def test_prolog_file_not_exist(self): |
|
|
|
pseudo_label_list = [1, 2] |
|
|
|
non_existing_file = "path/to/non_existing_file.pl" |
|
|
|
with pytest.raises(FileNotFoundError) as excinfo: |
|
|
|
PrologKB(pseudo_label_list=pseudo_label_list, |
|
|
|
pl_file=non_existing_file) |
|
|
|
assert non_existing_file in str(excinfo.value) |
|
|
|
|
|
|
|
def test_logic_forward_pl1(self, kb_add_prolog): |
|
|
|
result = kb_add_prolog.logic_forward([1, 2]) |
|
|
|
assert result == 3 |
|
|
|
|
|
|
|
def test_logic_forward_pl2(self, kb_hed): |
|
|
|
consist_exs = [ |
|
|
|
[1, 1, "+", 0, "=", 1, 1], |
|
|
|
[1, "+", 1, "=", 1, 0], |
|
|
|
[0, "+", 0, "=", 0], |
|
|
|
] |
|
|
|
inconsist_exs = [ |
|
|
|
[1, 1, "+", 0, "=", 1, 1], |
|
|
|
[1, "+", 1, "=", 1, 0], |
|
|
|
[0, "+", 0, "=", 0], |
|
|
|
[0, "+", 0, "=", 1], |
|
|
|
] |
|
|
|
assert kb_hed.logic_forward(consist_exs) == True |
|
|
|
assert kb_hed.logic_forward(inconsist_exs) == False |
|
|
|
|
|
|
|
def test_revise_at_idx(self, kb_add_prolog): |
|
|
|
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0]) |
|
|
|
assert result == [[0, 2]] |
|
|
|
|
|
|
|
class TestReaonser(object): |
|
|
|
def test_reasoner_init(self, reasoner_instance): |
|
|
|
assert reasoner_instance.dist_func == "confidence" |
|
|
|
|
|
|
|
def test_invalid_dist_funce(kb_add): |
|
|
|
with pytest.raises(NotImplementedError) as excinfo: |
|
|
|
ReasonerBase(kb_add, "invalid_dist_func") |
|
|
|
assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value) |
|
|
|
|
|
|
|
|
|
|
|
class test_batch_abduce(object): |
|
|
|
def test_batch_abduce_add(self, kb_add, data_samples_add): |
|
|
|
reasoner = ReasonerBase(kb_add, "confidence") |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) |
|
|
|
assert res == [[1, 7], [7, 1], [], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) |
|
|
|
assert res == [[1, 7], [7, 1], [], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) |
|
|
|
assert res == [[1, 7], [7, 1], [8, 9], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) |
|
|
|
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] |
|
|
|
|
|
|
|
def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): |
|
|
|
reasoner = ReasonerBase(kb_add_ground, "confidence") |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) |
|
|
|
assert res == [(1, 7), (7, 1), [], (1, 9)] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) |
|
|
|
assert res == [(1, 7), (7, 1), [], (1, 9)] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) |
|
|
|
assert res == [(1, 7), (7, 1), (8, 9), (1, 9)] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) |
|
|
|
assert res == [(1, 7), (7, 1), (8, 9), (7, 3)] |
|
|
|
|
|
|
|
def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): |
|
|
|
reasoner = ReasonerBase(kb_add_prolog, "confidence") |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) |
|
|
|
assert res == [[1, 7], [7, 1], [], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) |
|
|
|
assert res == [[1, 7], [7, 1], [], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) |
|
|
|
assert res == [[1, 7], [7, 1], [8, 9], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) |
|
|
|
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] |
|
|
|
|
|
|
|
def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): |
|
|
|
reasoner = ReasonerBase(kb_add_prolog, "confidence", use_zoopt=True) |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1) |
|
|
|
assert res == [[1, 7], [7, 1], [], [1, 9]] |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2) |
|
|
|
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]] |
|
|
|
|
|
|
|
def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): |
|
|
|
reasoner = ReasonerBase(kb_hwf1, "hamming") |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) |
|
|
|
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) |
|
|
|
assert res == [['1', '+', '2'], [], [], []] |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) |
|
|
|
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] |
|
|
|
|
|
|
|
def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): |
|
|
|
reasoner = ReasonerBase(kb_hwf2, "hamming") |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) |
|
|
|
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) |
|
|
|
assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']] |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) |
|
|
|
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] |
|
|
|
|
|
|
|
|