| @@ -263,7 +263,7 @@ class Trainer(TrainerEventTrigger): | |||||
| def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | ||||
| num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | ||||
| catch_KeyboardInterrupt=True): | |||||
| catch_KeyboardInterrupt=None): | |||||
| """ | """ | ||||
| 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint | 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint | ||||
| 去保存断点重训的文件; | 去保存断点重训的文件; | ||||
| @@ -273,15 +273,17 @@ class Trainer(TrainerEventTrigger): | |||||
| :param resume_from: 从哪个路径下恢复 trainer 的状态 | :param resume_from: 从哪个路径下恢复 trainer 的状态 | ||||
| :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 | :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 | ||||
| :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出异常,trainer.run()之后的代码会继续运 | :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出异常,trainer.run()之后的代码会继续运 | ||||
| 行。 | |||||
| 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self.driver.is_distributed(): | |||||
| if catch_KeyboardInterrupt: | |||||
| logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||||
| "driver. And we are gonna to set it to False.") | |||||
| catch_KeyboardInterrupt = False | |||||
| if catch_KeyboardInterrupt is None: | |||||
| catch_KeyboardInterrupt = not self.driver.is_distributed() | |||||
| else: | |||||
| if self.driver.is_distributed(): | |||||
| if catch_KeyboardInterrupt: | |||||
| logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||||
| "driver. And we are gonna to set it to False.") | |||||
| catch_KeyboardInterrupt = False | |||||
| self._set_num_eval_batch_per_dl(num_eval_batch_per_dl) | self._set_num_eval_batch_per_dl(num_eval_batch_per_dl) | ||||