diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 6d35729..7e3af8a 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -76,7 +76,7 @@ class Reasoner: if isinstance(dist_func, str): if dist_func not in ["hamming", "confidence", "avg_confidence"]: raise NotImplementedError( - 'Valid options for predefined dist_func include "hamming" ' + 'Valid options for predefined dist_func include "hamming", ' + f'"confidence" and "avg_confidence", but got {dist_func}.' ) return diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index d3a08b6..04b1f84 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -114,8 +114,10 @@ class TestReaonser(object): def test_invalid_predefined_dist_func(self, kb_add): with pytest.raises(NotImplementedError) as excinfo: Reasoner(kb_add, "invalid_dist_func") - assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in str( - excinfo.value + assert ( + 'Valid options for predefined dist_func include "hamming", "confidence" ' + + 'and "avg_confidence"' + in str(excinfo.value) ) def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results):