From c8e8ff4a8cd80672422ce6463ae575b2aa56d17d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 09:47:08 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95=E4=BE=8B?= =?UTF-8?q?=E4=B8=AD=E7=9A=84Events=E4=B8=BAEvent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/callbacks/test_callback_event.py | 8 ++++---- tests/core/controllers/test_trainer_other_things.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/core/callbacks/test_callback_event.py b/tests/core/callbacks/test_callback_event.py index 8a38670a..219ccafd 100644 --- a/tests/core/callbacks/test_callback_event.py +++ b/tests/core/callbacks/test_callback_event.py @@ -162,7 +162,7 @@ class TestCallbackEvents: def test_every(self): # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; - event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; + event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1; @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) def _fn(data): return data @@ -174,7 +174,7 @@ class TestCallbackEvents: _res.append(cu_res) assert _res == list(range(100)) - event_state = Events.on_train_begin(every=10) + event_state = Event.on_train_begin(every=10) @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) def _fn(data): return data @@ -187,7 +187,7 @@ class TestCallbackEvents: assert _res == [w - 1 for w in range(10, 101, 10)] def test_once(self): - event_state = Events.on_train_begin(once=10) + event_state = Event.on_train_begin(once=10) @Filter(once=event_state.once) def _fn(data): @@ -220,7 +220,7 @@ def test_callback_events_torch(): return True return False - event_state = Events.on_train_begin(filter_fn=filter_fn) + event_state = Event.on_train_begin(filter_fn=filter_fn) @Filter(filter_fn=event_state.filter_fn) def _fn(trainer, data): diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py index b010058b..3d9a5037 100644 --- a/tests/core/controllers/test_trainer_other_things.py +++ b/tests/core/controllers/test_trainer_other_things.py @@ -1,22 +1,22 @@ import pytest from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.callbacks import Events +from fastNLP.core.callbacks import Event from tests.helpers.utils import magic_argv_env_context @magic_argv_env_context def test_trainer_torch_without_evaluator(): - @Trainer.on(Events.on_train_epoch_begin(every=10), marker="test_trainer_other_things") + @Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things") def fn1(trainer): pass - @Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things") + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn2(trainer, batch, indices): pass with pytest.raises(BaseException): - @Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things") + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn3(trainer, batch): pass