From a96cdfd6c2e019413c0f1d1d84373ebc5e3c4d0f Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Fri, 15 Dec 2023 13:21:22 +0800 Subject: [PATCH] [ENH] Change dist_func to four parameters --- tests/test_reasoning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index 2373c72..9c060fa 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -101,7 +101,7 @@ class TestReaonser(object): excinfo.value ) - def random_dist(self, data_sample, candidates, reasoning_results): + def random_dist(self, data_sample, candidates, candidate_idxs, reasoning_results): cost_list = [np.random.rand() for _ in candidates] return cost_list @@ -113,14 +113,14 @@ class TestReaonser(object): cost_list = np.array([np.random.rand() for _ in candidates]) return cost_list - def invalid_dist2(self, data_sample, candidates, reasoning_results): + def invalid_dist2(self, data_sample, candidates, candidate_idxs, reasoning_results): cost_list = np.array([np.random.rand() for _ in candidates]) return np.append(cost_list, np.random.rand()) def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add): with pytest.raises(ValueError) as excinfo: Reasoner(kb_add, self.invalid_dist1) - assert 'User-defined dist_func must have exactly three parameters' in str( + assert 'User-defined dist_func must have exactly four parameters' in str( excinfo.value ) with pytest.raises(ValueError) as excinfo: