From 4bd72e528ad9938e131908c5a67920666fcdcae1 Mon Sep 17 00:00:00 2001
From: pangda
Date: Tue, 11 Oct 2022 11:14:34 +0800
Subject: [PATCH] [to #42322933] support restoring the best checkpoint after
 training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Support automatically restoring the best checkpoint after training, so the best model can be conveniently evaluated on different test sets.
2. Turn build_optimizer/build_lr_scheduler into member functions so they are easy to override (e.g., for layer-wise learning rates).

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10348255
---
 modelscope/trainers/hooks/checkpoint_hook.py |  7 +++++++
 modelscope/trainers/trainer.py               | 10 ++++++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py
index 220929b8..c9f51a88 100644
--- a/modelscope/trainers/hooks/checkpoint_hook.py
+++ b/modelscope/trainers/hooks/checkpoint_hook.py
@@ -216,6 +216,7 @@ class BestCkptSaverHook(CheckpointHook):
         by_epoch (bool): Save best checkpoints by epoch or by iteration.
         save_optimizer (bool): Whether to save optimizer state dict. Default: True.
         save_dir (str): Output directory to save best checkpoint.
+        restore_best (bool): Whether to restore the best checkpoint after training.
     """
 
     PRIORITY = Priority.LOW
@@ -228,6 +229,7 @@ class BestCkptSaverHook(CheckpointHook):
                  save_optimizer=True,
                  save_dir=None,
                  save_file_name=None,
+                 restore_best=False,
                  interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
@@ -241,6 +243,7 @@ class BestCkptSaverHook(CheckpointHook):
         self._best_metric = None
         self._best_ckpt_file = None
         self.save_file_name = save_file_name
+        self.restore_best = restore_best
 
     def _should_save(self, trainer):
         return self._is_best_metric(trainer.metric_values)
@@ -305,3 +308,7 @@ class BestCkptSaverHook(CheckpointHook):
             self.logger.warn(
                 'The state_dict is not available, the best metric value will be affected.'
             )
+
+    def after_run(self, trainer):
+        if self.restore_best:
+            self.load_checkpoint(self._best_ckpt_file, trainer)
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index a01d9b59..4c21d63f 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -664,6 +664,12 @@ class EpochBasedTrainer(BaseTrainer):
             dataset = self.to_task_dataset(torch_dataset, mode)
         return dataset
 
+    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
+        return build_optimizer(self.model, cfg=cfg, default_args=default_args)
+
+    def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
+        return build_lr_scheduler(cfg=cfg, default_args=default_args)
+
     def create_optimizer_and_scheduler(self):
         """ Create optimizer and lr scheduler
@@ -680,7 +686,7 @@ class EpochBasedTrainer(BaseTrainer):
         optim_options = {}
         if optimizer_cfg is not None:
             optim_options = optimizer_cfg.pop('options', {})
-            optimizer = build_optimizer(self.model, cfg=optimizer_cfg)
+            optimizer = self.build_optimizer(cfg=optimizer_cfg)
 
         if lr_scheduler is None:
             lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
@@ -691,7 +697,7 @@ class EpochBasedTrainer(BaseTrainer):
         if lr_scheduler_cfg is not None:
             assert optimizer is not None
             lr_options = lr_scheduler_cfg.pop('options', {})
-            lr_scheduler = build_lr_scheduler(
+            lr_scheduler = self.build_lr_scheduler(
                 cfg=lr_scheduler_cfg, default_args={'optimizer': optimizer})
 
         self.optimizer = optimizer
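
A rough usage sketch of the two changes above. The trainer subclass name, the metric key, and the 'encoder' vs. head parameter split are illustrative assumptions and not part of this patch; only the build_optimizer signature and the restore_best option come from the diff.

    # Sketch only: subclass name, metric key and the parameter split are
    # hypothetical; build_optimizer() and restore_best come from this patch.
    import torch

    from modelscope.trainers import EpochBasedTrainer
    from modelscope.utils.config import ConfigDict


    class LayerwiseLRTrainer(EpochBasedTrainer):
        """Overrides the new build_optimizer member function to give the
        backbone a smaller learning rate than the task head."""

        def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
            base_lr = cfg.get('lr', 1e-4)
            backbone, head = [], []
            for name, param in self.model.named_parameters():
                # Hypothetical split: treat 'encoder.*' parameters as the backbone.
                (backbone if name.startswith('encoder') else head).append(param)
            return torch.optim.AdamW([
                {'params': backbone, 'lr': base_lr * 0.1},
                {'params': head, 'lr': base_lr},
            ])


    # With restore_best enabled, BestCkptSaverHook reloads the best checkpoint
    # in after_run, so evaluation runs after training use the best weights
    # instead of the last ones.
    best_ckpt_hook = {
        'type': 'BestCkptSaverHook',
        'metric_key': 'accuracy',   # assumed metric name
        'rule': 'max',
        'restore_best': True,
    }

Because build_optimizer is now a plain member function, such an override needs no changes to the optimizer registry; the default implementation still delegates to the existing build_optimizer builder.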