| @@ -12,6 +12,34 @@ from fastNLP.core.callbacks.callback_events import _SingleEventState | |||||
| class Callback: | class Callback: | ||||
| r""" | r""" | ||||
| 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | ||||
| callback 调用时机顺序大概如下 | |||||
| Trainer.__init__(): | |||||
| on_after_trainer_initialized() | |||||
| Trainer.run(): | |||||
| if num_eval_sanity_batch>0: | |||||
| on_sanity_check_begin() # 如果设置了num_eval_sanity_batch | |||||
| on_sanity_check_end() | |||||
| try: | |||||
| on_train_begin() | |||||
| while cur_epoch_idx < n_epochs: | |||||
| on_train_epoch_begin() | |||||
| while batch_idx_in_epoch<=num_batches_per_epoch: | |||||
| on_fetch_data_begin() | |||||
| on_fetch_data_end() | |||||
| on_train_batch_begin() | |||||
| on_before_backward() | |||||
| on_after_backward() | |||||
| on_before_zero_grad() # 实际调用受到 accumulation_steps 影响 | |||||
| on_after_zero_grad() # 实际调用受到 accumulation_steps 影响 | |||||
| on_before_optimizers_step() # 实际调用受到 accumulation_steps 影响 | |||||
| on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响 | |||||
| on_train_batch_end() | |||||
| on_train_epoch_end() | |||||
| except BaseException: | |||||
| self.on_exception() | |||||
| finally: | |||||
| on_train_end() | |||||
| 其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将 | |||||
| """ | """ | ||||
| def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
| @@ -221,9 +249,9 @@ class Callback: | |||||
| """ | """ | ||||
| pass | pass | ||||
| def on_validate_begin(self, trainer): | |||||
| def on_evaluate_begin(self, trainer): | |||||
| """ | """ | ||||
| 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | |||||
| 在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | |||||
| 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | ||||
| :param trainer: | :param trainer: | ||||
| @@ -231,9 +259,9 @@ class Callback: | |||||
| """ | """ | ||||
| pass | pass | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| """ | """ | ||||
| 结束 validate 时调用,并把 validate 的结果传入。 | |||||
| 结束 evaluate 时调用,并把 evaluate 的结果传入。 | |||||
| :param trainer: | :param trainer: | ||||
| :param results: Evaluate 的结果,一般是个 dict 。 | :param results: Evaluate 的结果,一般是个 dict 。 | ||||
| @@ -96,8 +96,8 @@ class Events(EventEnum): | |||||
| on_after_optimizers_step = "on_after_optimizers_step" | on_after_optimizers_step = "on_after_optimizers_step" | ||||
| on_before_zero_grad = "on_before_zero_grad" | on_before_zero_grad = "on_before_zero_grad" | ||||
| on_after_zero_grad = "on_after_zero_grad" | on_after_zero_grad = "on_after_zero_grad" | ||||
| on_validate_begin = "on_validate_begin" | |||||
| on_validate_end = "on_validate_end" | |||||
| on_evaluate_begin = "on_evaluate_begin" | |||||
| on_evaluate_end = "on_evaluate_end" | |||||
| class EventsList: | class EventsList: | ||||
| @@ -281,9 +281,9 @@ class CallbackManager: | |||||
| pass | pass | ||||
| @_transfer | @_transfer | ||||
| def on_validate_begin(self, trainer): | |||||
| def on_evaluate_begin(self, trainer): | |||||
| pass | pass | ||||
| @_transfer | @_transfer | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| pass | pass | ||||
| @@ -114,7 +114,7 @@ class CheckpointCallback(Callback): | |||||
| if self.topk_saver.topk_queue and trainer.evaluator is None: | if self.topk_saver.topk_queue and trainer.evaluator is None: | ||||
| logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| # 如果发生了保存,则返回的 folder 不为 None | # 如果发生了保存,则返回的 folder 不为 None | ||||
| folder = self.topk_saver.save_topk(trainer, results) | folder = self.topk_saver.save_topk(trainer, results) | ||||
| @@ -16,13 +16,13 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | 果(字典类型),返回一个 float 值作为 monitor 的结果。 | ||||
| :param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
| :param patience: 多少次 validate 不没有提升就停止。 | |||||
| :param patience: 多少次 evaluate 不没有提升就停止。 | |||||
| """ | """ | ||||
| super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | ||||
| self.wait = 0 | self.wait = 0 | ||||
| self.patience = patience | self.patience = patience | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| monitor_value = self.get_monitor_value(results) | monitor_value = self.get_monitor_value(results) | ||||
| if monitor_value is None: | if monitor_value is None: | ||||
| return | return | ||||
| @@ -32,13 +32,13 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
| self.wait += 1 | self.wait += 1 | ||||
| def on_fetch_data_begin(self, trainer): | def on_fetch_data_begin(self, trainer): | ||||
| # 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
| # 当是 step evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
| if self.wait >= self.patience: | if self.wait >= self.patience: | ||||
| raise EarlyStopException(f"After {self.wait} validations, no improvement for " | raise EarlyStopException(f"After {self.wait} validations, no improvement for " | ||||
| f"metric `{self._real_monitor}`") | f"metric `{self._real_monitor}`") | ||||
| def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
| # 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
| # 当是 epoch evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
| if self.wait >= self.patience: | if self.wait >= self.patience: | ||||
| raise EarlyStopException(f"After {self.wait} validations, no improvement for " | raise EarlyStopException(f"After {self.wait} validations, no improvement for " | ||||
| f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | ||||
| @@ -216,6 +216,6 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||||
| _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | ||||
| self.execute_fn = execute_fn | self.execute_fn = execute_fn | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| if self.is_better_results(results): | if self.is_better_results(results): | ||||
| self.execute_fn() | self.execute_fn() | ||||
| @@ -76,7 +76,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
| if self.real_save_folder: | if self.real_save_folder: | ||||
| trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
| @@ -95,27 +95,14 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| self.buffer.seek(0) | self.buffer.seek(0) | ||||
| trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
| trainer.driver.barrier() | |||||
| self._delete_after_after(trainer) | |||||
| def _delete_after_after(self, trainer): | |||||
| trainer.driver.barrier() | |||||
| if self.delete_after_after: | if self.delete_after_after: | ||||
| if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
| # 只需要 rank 0 执行删除。 | |||||
| logger.info(f"Deleting {self.real_save_folder}...") | |||||
| shutil.rmtree(self.real_save_folder) | |||||
| try: | |||||
| # 如果是 emtpy 的,就会被删除掉 | |||||
| os.rmdir(self.save_folder) | |||||
| except: | |||||
| pass | |||||
| elif hasattr(self, 'buffer'): | |||||
| self.buffer.close() | |||||
| del self.buffer | |||||
| def on_exception(self, trainer, exception): | |||||
| if self.delete_after_after: | |||||
| if self.real_save_folder: # 这里,谁处异常,谁删除 | |||||
| if self.real_save_folder: | |||||
| logger.info(f"Deleting {self.real_save_folder}...") | logger.info(f"Deleting {self.real_save_folder}...") | ||||
| shutil.rmtree(self.real_save_folder) | |||||
| shutil.rmtree(self.real_save_folder, ignore_errors=True) | |||||
| try: | try: | ||||
| # 如果是 emtpy 的,就会被删除掉 | # 如果是 emtpy 的,就会被删除掉 | ||||
| os.rmdir(self.save_folder) | os.rmdir(self.save_folder) | ||||
| @@ -31,8 +31,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| :param dataloaders: 需要评估的数据 | :param dataloaders: 需要评估的数据 | ||||
| :param metrics: 使用的 metrics 。 | :param metrics: 使用的 metrics 。 | ||||
| :param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch validate 一次;(2) 为正整数则表示每隔几个 batch | |||||
| evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 validate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 | |||||
| :param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch evaluate 一次;(2) 为正整数则表示每隔几个 batch | |||||
| evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 | |||||
| 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | ||||
| :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | ||||
| 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | ||||
| @@ -128,7 +128,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) | results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) | ||||
| self.topk_saver.get_monitor_value(results) | self.topk_saver.get_monitor_value(results) | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
| results = self.evaluator.run() | results = self.evaluator.run() | ||||
| self.topk_saver.save_topk(trainer, results) | self.topk_saver.save_topk(trainer, results) | ||||
| @@ -137,8 +137,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| if self.watch_monitor is not None: | if self.watch_monitor is not None: | ||||
| return | return | ||||
| if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | ||||
| validate_every = -self.evaluate_every | |||||
| if trainer.cur_epoch_idx % validate_every == 0: | |||||
| evaluate_every = -self.evaluate_every | |||||
| if trainer.cur_epoch_idx % evaluate_every == 0: | |||||
| results = self.evaluator.run() | results = self.evaluator.run() | ||||
| self.topk_saver.save_topk(trainer, results) | self.topk_saver.save_topk(trainer, results) | ||||
| @@ -100,7 +100,7 @@ class RichCallback(ProgressCallback): | |||||
| self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | ||||
| advance=self.epoch_bar_update_advance, refresh=True) | advance=self.epoch_bar_update_advance, refresh=True) | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| if len(results)==0: | if len(results)==0: | ||||
| return | return | ||||
| rule_style = '' | rule_style = '' | ||||
| @@ -122,9 +122,6 @@ class RichCallback(ProgressCallback): | |||||
| else: | else: | ||||
| self.progress_bar.print(results) | self.progress_bar.print(results) | ||||
| def on_exception(self, trainer, exception): | |||||
| self.clear_tasks() | |||||
| def clear_tasks(self): | def clear_tasks(self): | ||||
| for key, taskid in self.task2id.items(): | for key, taskid in self.task2id.items(): | ||||
| self.progress_bar.destroy_task(taskid) | self.progress_bar.destroy_task(taskid) | ||||
| @@ -178,7 +175,7 @@ class RawTextCallback(ProgressCallback): | |||||
| f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | ||||
| logger.info(text) | logger.info(text) | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| if len(results)==0: | if len(results)==0: | ||||
| return | return | ||||
| base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
| @@ -43,7 +43,7 @@ class TrainBatchLoop(Loop): | |||||
| trainer.check_batch_step_fn() | trainer.check_batch_step_fn() | ||||
| trainer.on_train_batch_end() | trainer.on_train_batch_end() | ||||
| trainer.step_validate() | |||||
| trainer.step_evaluate() | |||||
| trainer.batch_idx_in_epoch = 0 | trainer.batch_idx_in_epoch = 0 | ||||
| @staticmethod | @staticmethod | ||||
| @@ -339,11 +339,11 @@ class Trainer(TrainerEventTrigger): | |||||
| self.num_batches_per_epoch = len(self.dataloader) | self.num_batches_per_epoch = len(self.dataloader) | ||||
| self.total_batches = self.num_batches_per_epoch * self.n_epochs | self.total_batches = self.num_batches_per_epoch * self.n_epochs | ||||
| self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | ||||
| self.on_train_begin() | |||||
| self.driver.barrier() | |||||
| self.driver.zero_grad(self.set_grad_to_none) | |||||
| try: | try: | ||||
| self.on_train_begin() | |||||
| self.driver.barrier() | |||||
| self.driver.zero_grad(self.set_grad_to_none) | |||||
| while self.cur_epoch_idx < self.n_epochs: | while self.cur_epoch_idx < self.n_epochs: | ||||
| # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | ||||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | ||||
| @@ -356,10 +356,8 @@ class Trainer(TrainerEventTrigger): | |||||
| self.cur_epoch_idx += 1 | self.cur_epoch_idx += 1 | ||||
| self.on_train_epoch_end() | self.on_train_epoch_end() | ||||
| self.driver.barrier() | self.driver.barrier() | ||||
| self.epoch_validate() | |||||
| self.epoch_evaluate() | |||||
| self.driver.barrier() | self.driver.barrier() | ||||
| self.on_train_end() | |||||
| self.driver.barrier() | |||||
| except EarlyStopException as e: | except EarlyStopException as e: | ||||
| logger.info(f"Catch early stop exception: {e.msg}.") | logger.info(f"Catch early stop exception: {e.msg}.") | ||||
| @@ -373,17 +371,20 @@ class Trainer(TrainerEventTrigger): | |||||
| self.driver.on_exception() | self.driver.on_exception() | ||||
| self.on_exception(e) | self.on_exception(e) | ||||
| raise e | raise e | ||||
| finally: | |||||
| self.on_train_end() | |||||
| self.driver.barrier() | |||||
| def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | ||||
| def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: | |||||
| trainer.on_validate_begin() | |||||
| _validate_res: dict = validate_fn() | |||||
| trainer.on_validate_end(_validate_res) | |||||
| def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: | |||||
| trainer.on_evaluate_begin() | |||||
| _evaluate_res: dict = evaluate_fn() | |||||
| trainer.on_evaluate_end(_evaluate_res) | |||||
| if self.evaluator is not None: | if self.evaluator is not None: | ||||
| self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
| self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
| def step_validate(self): | |||||
| def step_evaluate(self): | |||||
| """ | """ | ||||
| 在每个 batch 结束后调用,根据设置执行 evaluate 。 | 在每个 batch 结束后调用,根据设置执行 evaluate 。 | ||||
| @@ -396,7 +397,7 @@ class Trainer(TrainerEventTrigger): | |||||
| elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | ||||
| self.run_evaluate() | self.run_evaluate() | ||||
| def epoch_validate(self): | |||||
| def epoch_evaluate(self): | |||||
| """ | """ | ||||
| 在每个 epoch 结束后调用,根据设置执行 evaluate 。 | 在每个 epoch 结束后调用,根据设置执行 evaluate 。 | ||||
| @@ -404,8 +405,8 @@ class Trainer(TrainerEventTrigger): | |||||
| """ | """ | ||||
| if self.evaluator is not None: | if self.evaluator is not None: | ||||
| if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | ||||
| validate_every = -self.evaluate_every | |||||
| if self.cur_epoch_idx % validate_every == 0: | |||||
| evaluate_every = -self.evaluate_every | |||||
| if self.cur_epoch_idx % evaluate_every == 0: | |||||
| self.run_evaluate() | self.run_evaluate() | ||||
| def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | ||||
| @@ -81,12 +81,12 @@ class TrainerEventTrigger: | |||||
| def on_after_zero_grad(self, optimizers): | def on_after_zero_grad(self, optimizers): | ||||
| self.callback_manager.on_after_zero_grad(self, optimizers) | self.callback_manager.on_after_zero_grad(self, optimizers) | ||||
| def on_validate_begin(self): | |||||
| self.callback_manager.on_validate_begin(self) | |||||
| def on_evaluate_begin(self): | |||||
| self.callback_manager.on_evaluate_begin(self) | |||||
| def on_validate_end(self, results): | |||||
| def on_evaluate_end(self, results): | |||||
| self.trainer_state.save_on_this_step = True | self.trainer_state.save_on_this_step = True | ||||
| self.callback_manager.on_validate_end(self, results) | |||||
| self.callback_manager.on_evaluate_end(self, results) | |||||
| class _TruncatedDataLoader: | class _TruncatedDataLoader: | ||||
| @@ -126,8 +126,8 @@ class _TruncatedDataLoader: | |||||
| return getattr(self.dataloader, item) | return getattr(self.dataloader, item) | ||||
| def check_evaluate_every(validate_every): | |||||
| if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||||
| def check_evaluate_every(evaluate_every): | |||||
| if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | |||||
| raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
| if callable(validate_every): | |||||
| _check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||||
| if callable(evaluate_every): | |||||
| _check_valid_parameters_number(evaluate_every, expected_params=['trainer']) | |||||
| @@ -63,7 +63,7 @@ class JittorDriver(Driver): | |||||
| def check_evaluator_mode(self, mode: str): | def check_evaluator_mode(self, mode: str): | ||||
| model = self.unwrap_model() | model = self.unwrap_model() | ||||
| if mode == "validate": | |||||
| if mode == "evaluate": | |||||
| if not hasattr(model, "evaluate_step"): | if not hasattr(model, "evaluate_step"): | ||||
| if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
| logger.warning_once( | logger.warning_once( | ||||
| @@ -173,6 +173,19 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
| kwargs["extra"] = extra | kwargs["extra"] = extra | ||||
| return kwargs | return kwargs | ||||
| def setLevel(self, level) -> None: | |||||
| """ | |||||
| 设置当前 logger 以及其 handler 的 log 级别 | |||||
| :param level: | |||||
| :return: | |||||
| """ | |||||
| if isinstance(level, str): | |||||
| level = level.upper() | |||||
| super().setLevel(level) | |||||
| for handler in self.handlers: | |||||
| handler.setLevel(level) | |||||
| def _get_level(level): | def _get_level(level): | ||||
| if not isinstance(level, int): | if not isinstance(level, int): | ||||
| @@ -38,7 +38,7 @@ class RecordMetricCallback(Callback): | |||||
| self.metric_threshold = metric_threshold | self.metric_threshold = metric_threshold | ||||
| self.metric_begin_value = None | self.metric_begin_value = None | ||||
| def on_validate_end(self, trainer, results): | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| self.metric = results[self.monitor] | self.metric = results[self.monitor] | ||||
| if self.metric_begin_value is None: | if self.metric_begin_value is None: | ||||
| self.metric_begin_value = self.metric | self.metric_begin_value = self.metric | ||||
| @@ -113,11 +113,11 @@ class RecordTrainerEventTriggerCallback(Callback): | |||||
| def on_after_zero_grad(self, trainer, optimizers): | def on_after_zero_grad(self, trainer, optimizers): | ||||
| print("on_after_zero_grad") | print("on_after_zero_grad") | ||||
| def on_validate_begin(self, trainer): | |||||
| print("on_validate_begin") | |||||
| def on_evaluate_begin(self, trainer): | |||||
| print("on_evaluate_begin") | |||||
| def on_validate_end(self, trainer, results): | |||||
| print("on_validate_end") | |||||
| def on_evaluate_end(self, trainer, results): | |||||
| print("on_evaluate_end") | |||||