diff --git a/tests/test_abl_model.py b/tests/test_abl_model.py index f6e78f2..320aa44 100644 --- a/tests/test_abl_model.py +++ b/tests/test_abl_model.py @@ -52,8 +52,8 @@ class TestABLModel(object): """Test the train method of the ABLModel class.""" model = ABLModel(base_model_instance) list_data_instance.abduced_idx = [[1, 2], [3, 4], [5, 6]] - loss = model.train(list_data_instance) - assert isinstance(loss, float), "Training should return a float value indicating the loss." + ins = model.train(list_data_instance) + assert ins == model.base_model, "Training should return the base model instance." def test_ablmodel_save_load(self, base_model_instance, tmp_path): """Test the save method of the ABLModel class.""" diff --git a/tests/test_basic_nn.py b/tests/test_basic_nn.py index fa642e7..0a6cafb 100644 --- a/tests/test_basic_nn.py +++ b/tests/test_basic_nn.py @@ -5,7 +5,6 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset - class TestBasicNN(object): @pytest.fixture def sample_data(self): @@ -40,12 +39,12 @@ class TestBasicNN(object): # Test fit with direct data X, y = sample_data - loss = basic_nn_instance.fit(X=list(X), y=list(y)) - assert isinstance(loss, float) + ins = basic_nn_instance.fit(X=list(X), y=list(y)) + assert ins == basic_nn_instance # Test fit with DataLoader - loss = basic_nn_instance.fit(data_loader=sample_data_loader_with_label) - assert isinstance(loss, float) + ins = basic_nn_instance.fit(data_loader=sample_data_loader_with_label) + assert ins == basic_nn_instance # Test invalid fit method input with pytest.raises(ValueError): diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index c4d29f5..3651cc3 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -14,19 +14,21 @@ class TestKBBase(object): def test_logic_forward(self, kb_add): result = kb_add.logic_forward([1, 2]) assert result == 3 + with pytest.raises(TypeError): + kb_add.logic_forward([1, 2], [0.1, -0.2, 0.2, -0.3]) def test_revise_at_idx(self, kb_add): - result = kb_add.revise_at_idx([0, 2], 2, []) + result = kb_add.revise_at_idx([0, 2], 2, [0.1, -0.2, 0.2, -0.3], []) assert result == [[0, 2]] - result = kb_add.revise_at_idx([1, 2], 2, []) + result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], []) assert result == [] - result = kb_add.revise_at_idx([1, 2], 2, [0, 1]) + result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0, 1]) assert result == [[0, 2], [1, 1], [2, 0]] def test_abduce_candidates(self, kb_add): - result = kb_add.abduce_candidates([0, 1], 1, max_revision_num=2, require_more_revision=0) + result = kb_add.abduce_candidates([0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0) assert result == [[0, 1]] - result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0) + result = kb_add.abduce_candidates([1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0) assert result == [[1, 0]] @@ -42,7 +44,7 @@ class TestGroundKB(object): def test_abduce_candidates_ground(self, kb_add_ground): result = kb_add_ground.abduce_candidates( - [1, 2], 1, max_revision_num=2, require_more_revision=0 + [1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0 ) assert result == [(1, 0)]