1. Support automatically restoring the best ckpt after training finishes, so the model can be conveniently evaluated on different test sets.
2. Change build_optimizer/build_lr_scheduler into member functions so they are easy to override (e.g., for layer-wise learning rates).
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10348255
Branch: master
@@ -216,6 +216,7 @@ class BestCkptSaverHook(CheckpointHook):
         by_epoch (bool): Save best checkpoints by epoch or by iteration.
         save_optimizer (bool): Whether to save optimizer state dict. Default: True.
         save_dir (str): Output directory to save best checkpoint.
+        restore_best (bool): Whether to restore the best checkpoint after training.
     """
 
     PRIORITY = Priority.LOW
@@ -228,6 +229,7 @@ class BestCkptSaverHook(CheckpointHook):
                  save_optimizer=True,
                  save_dir=None,
                  save_file_name=None,
+                 restore_best=False,
                  interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
@@ -241,6 +243,7 @@ class BestCkptSaverHook(CheckpointHook):
         self._best_metric = None
         self._best_ckpt_file = None
         self.save_file_name = save_file_name
+        self.restore_best = restore_best
 
     def _should_save(self, trainer):
         return self._is_best_metric(trainer.metric_values)
@@ -305,3 +308,7 @@ class BestCkptSaverHook(CheckpointHook):
             self.logger.warn(
                 'The state_dict is not available, the best metric value will be affected.'
             )
+
+    def after_run(self, trainer):
+        if self.restore_best:
+            self.load_checkpoint(self._best_ckpt_file, trainer)
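
For reference, a minimal usage sketch of the new flag (the hook type string follows the class name in this diff; the metric key and the evaluate call are illustrative assumptions, not part of the change):

    cfg.train.hooks.append(
        dict(
            type='BestCkptSaverHook',
            metric_key='accuracy',  # assumed metric name, depends on the task
            rule='max',
            restore_best=True))     # after_run() reloads the best checkpoint

    trainer.train()
    # Training has finished and the best checkpoint has been restored, so
    # any further evaluation runs against the best weights, not the last.
    results = trainer.evaluate()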
@@ -664,6 +664,12 @@ class EpochBasedTrainer(BaseTrainer):
         dataset = self.to_task_dataset(torch_dataset, mode)
         return dataset
 
+    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
+        return build_optimizer(self.model, cfg=cfg, default_args=default_args)
+
+    def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
+        return build_lr_scheduler(cfg=cfg, default_args=default_args)
+
     def create_optimizer_and_scheduler(self):
         """ Create optimizer and lr scheduler
@@ -680,7 +686,7 @@ class EpochBasedTrainer(BaseTrainer):
         optim_options = {}
         if optimizer_cfg is not None:
             optim_options = optimizer_cfg.pop('options', {})
-            optimizer = build_optimizer(self.model, cfg=optimizer_cfg)
+            optimizer = self.build_optimizer(cfg=optimizer_cfg)
 
         if lr_scheduler is None:
             lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
@@ -691,7 +697,7 @@ class EpochBasedTrainer(BaseTrainer):
         if lr_scheduler_cfg is not None:
             assert optimizer is not None
             lr_options = lr_scheduler_cfg.pop('options', {})
-            lr_scheduler = build_lr_scheduler(
+            lr_scheduler = self.build_lr_scheduler(
                 cfg=lr_scheduler_cfg, default_args={'optimizer': optimizer})
 
         self.optimizer = optimizer
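
The point of making the builders member functions is that a subclass can now override them. A hedged sketch of a trainer with layer-wise learning rates (the import path, the parameter-name prefix, the lr values, and the optimizer choice are illustrative assumptions):

    import torch
    from modelscope.trainers import EpochBasedTrainer  # assumed import path

    class LayerwiseLRTrainer(EpochBasedTrainer):

        def build_optimizer(self, cfg, default_args=None):
            # Assumed split: parameters whose names start with 'backbone'
            # get a 10x smaller lr than the rest (e.g. the task head).
            base_lr = cfg.get('lr', 1e-4)
            backbone, head = [], []
            for name, param in self.model.named_parameters():
                (backbone if name.startswith('backbone') else head).append(param)
            return torch.optim.AdamW([
                {'params': backbone, 'lr': base_lr * 0.1},
                {'params': head, 'lr': base_lr},
            ])

create_optimizer_and_scheduler() only consumes the returned optimizer object, so returning any torch.optim.Optimizer from the override is enough; build_lr_scheduler can be overridden the same way.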