From 8017d8c854ba5ad5181a51a609d29cd70fa71855 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 3 May 2022 18:08:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=AF=B9=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E8=BF=87=E7=9A=84=20Events=20=E7=9A=84=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 11 ++++--- tests/core/callbacks/test_callback_event.py | 36 --------------------- 2 files changed, 6 insertions(+), 41 deletions(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 9ead6024..7b04d8ad 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -127,11 +127,12 @@ class CallbackManager: :param callback: 一个具体的 callback 实例; """ self.all_callbacks.append(callback) - for name, member in Event.__members__.items(): - _fn = getattr(callback, member.value) - if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): - self.callback_fns[member.value].append(_fn) - self.extract_callback_filter_state(callback.callback_name, _fn) + for name, member in Event.__dict__.items(): + if isinstance(member, staticmethod): + _fn = getattr(callback, name) + if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)): + self.callback_fns[name].append(_fn) + self.extract_callback_filter_state(callback.callback_name, _fn) def extract_callback_filter_state(self, callback_name, callback_fn): r""" diff --git a/tests/core/callbacks/test_callback_event.py b/tests/core/callbacks/test_callback_event.py index 219ccafd..765c4432 100644 --- a/tests/core/callbacks/test_callback_event.py +++ b/tests/core/callbacks/test_callback_event.py @@ -6,42 +6,6 @@ from fastNLP.core.callbacks.callback_event import Event, Filter class TestFilter: - def test_params_check(self): - # 顺利通过 - _filter1 = Filter(every=10) - _filter2 = Filter(once=10) - _filter3 = Filter(filter_fn=lambda: None) - - # 触发 ValueError - with pytest.raises(ValueError) as e: - _filter4 = Filter() - exec_msg = e.value.args[0] - assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter." - - # 触发 ValueError - with pytest.raises(ValueError) as e: - _filter5 = Filter(every=10, once=10) - exec_msg = e.value.args[0] - assert exec_msg == "These three values should be only set one." - - # 触发 TypeError - with pytest.raises(ValueError) as e: - _filter6 = Filter(every="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument every should be integer and greater than zero" - - # 触发 TypeError - with pytest.raises(ValueError) as e: - _filter7 = Filter(once="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument once should be integer and positive" - - # 触发 TypeError - with pytest.raises(TypeError) as e: - _filter7 = Filter(filter_fn="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument event_filter should be a callable" - def test_every_filter(self): # every = 10 @Filter(every=10)