From 1379d041912a2f48edf37bb7e080602fd48738a2 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 27 Dec 2023 23:02:09 +0800 Subject: [PATCH] [FIX] pass pytest in python 3.11 --- tests/test_abl_model.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_abl_model.py b/tests/test_abl_model.py index 320aa44..c9e256e 100644 --- a/tests/test_abl_model.py +++ b/tests/test_abl_model.py @@ -11,24 +11,22 @@ class TestABLModel(object): """Test the initialization method of the ABLModel class.""" invalid_base_model = Mock(spec=[]) - with pytest.raises(NotImplementedError): ABLModel(invalid_base_model) - fit = Mock(return_value=1.0) - predict = Mock(return_value=np.array(1.0)) - - invalid_base_model = Mock(spec=fit) + invalid_base_model = Mock(spec=["fit"]) + invalid_base_model.fit.return_value = 1.0 with pytest.raises(NotImplementedError): ABLModel(invalid_base_model) - invalid_base_model = Mock(spec=predict) + invalid_base_model = Mock(spec=["predict"]) + invalid_base_model.predict.return_value = np.array(1.0) with pytest.raises(NotImplementedError): ABLModel(invalid_base_model) base_model = Mock(spec=["fit", "predict"]) - base_model.fit = fit - base_model.predict = predict + base_model.fit.return_value = 1.0 + base_model.predict.return_value = np.array(1.0) model = ABLModel(base_model) assert hasattr(model, "base_model"), "The model should have a 'base_model' attribute."