diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 9ead6024..7b04d8ad 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -127,11 +127,12 @@ class CallbackManager: :param callback: 一个具体的 callback 实例; """ self.all_callbacks.append(callback) - for name, member in Event.__members__.items(): - _fn = getattr(callback, member.value) - if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): - self.callback_fns[member.value].append(_fn) - self.extract_callback_filter_state(callback.callback_name, _fn) + for name, member in Event.__dict__.items(): + if isinstance(member, staticmethod): + _fn = getattr(callback, name) + if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)): + self.callback_fns[name].append(_fn) + self.extract_callback_filter_state(callback.callback_name, _fn) def extract_callback_filter_state(self, callback_name, callback_fn): r""" diff --git a/tests/core/callbacks/test_callback_event.py b/tests/core/callbacks/test_callback_event.py index 219ccafd..765c4432 100644 --- a/tests/core/callbacks/test_callback_event.py +++ b/tests/core/callbacks/test_callback_event.py @@ -6,42 +6,6 @@ from fastNLP.core.callbacks.callback_event import Event, Filter class TestFilter: - def test_params_check(self): - # 顺利通过 - _filter1 = Filter(every=10) - _filter2 = Filter(once=10) - _filter3 = Filter(filter_fn=lambda: None) - - # 触发 ValueError - with pytest.raises(ValueError) as e: - _filter4 = Filter() - exec_msg = e.value.args[0] - assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter." - - # 触发 ValueError - with pytest.raises(ValueError) as e: - _filter5 = Filter(every=10, once=10) - exec_msg = e.value.args[0] - assert exec_msg == "These three values should be only set one." - - # 触发 TypeError - with pytest.raises(ValueError) as e: - _filter6 = Filter(every="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument every should be integer and greater than zero" - - # 触发 TypeError - with pytest.raises(ValueError) as e: - _filter7 = Filter(once="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument once should be integer and positive" - - # 触发 TypeError - with pytest.raises(TypeError) as e: - _filter7 = Filter(filter_fn="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument event_filter should be a callable" - def test_every_filter(self): # every = 10 @Filter(every=10)