
[to #42322933] Support restoring the best checkpoint after training

1. Support automatically restoring the best checkpoint after training finishes, which makes it convenient to evaluate on different test sets.
2. Turn build_optimizer/build_lr_scheduler into member functions so they can be overridden (e.g., for per-layer learning rates).
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10348255
master · pangda (yingda.chen) · 3 years ago · commit 4bd72e528a
2 changed files with 15 additions and 2 deletions
  1. modelscope/trainers/hooks/checkpoint_hook.py  +7 -0
  2. modelscope/trainers/trainer.py  +8 -2

modelscope/trainers/hooks/checkpoint_hook.py  (+7 -0)

@@ -216,6 +216,7 @@ class BestCkptSaverHook(CheckpointHook):
         by_epoch (bool): Save best checkpoints by epoch or by iteration.
         save_optimizer (bool): Whether to save optimizer state dict. Default: True.
         save_dir (str): Output directory to save best checkpoint.
+        restore_best (bool): Whether to restore the best checkpoint after training.
     """

     PRIORITY = Priority.LOW
@@ -228,6 +229,7 @@ class BestCkptSaverHook(CheckpointHook):
                  save_optimizer=True,
                  save_dir=None,
                  save_file_name=None,
+                 restore_best=False,
                  interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
@@ -241,6 +243,7 @@ class BestCkptSaverHook(CheckpointHook):
         self._best_metric = None
         self._best_ckpt_file = None
         self.save_file_name = save_file_name
+        self.restore_best = restore_best

     def _should_save(self, trainer):
         return self._is_best_metric(trainer.metric_values)
@@ -305,3 +308,7 @@ class BestCkptSaverHook(CheckpointHook):
             self.logger.warn(
                 'The state_dict is not available, the best metric value will be affected.'
             )
+
+    def after_run(self, trainer):
+        if self.restore_best:
+            self.load_checkpoint(self._best_ckpt_file, trainer)
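
With restore_best enabled, the hook reloads the best checkpoint in after_run, so any evaluation performed after training (for example on additional test sets) runs against the best weights rather than the last-epoch weights. A minimal sketch of the relevant piece of a trainer configuration is shown below; the hook-list location (train.hooks) and the metric_key value are assumptions for illustration, only the restore_best flag itself is added by this commit.

# Sketch of a hook entry that opts in to the new behaviour.
train_cfg = {
    'train': {
        'hooks': [{
            'type': 'BestCkptSaverHook',
            'metric_key': 'accuracy',  # assumed evaluation metric name
            'rule': 'max',
            'restore_best': True,      # new: reload the best ckpt in after_run
        }]
    }
}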

modelscope/trainers/trainer.py  (+8 -2)

@@ -664,6 +664,12 @@ class EpochBasedTrainer(BaseTrainer):
         dataset = self.to_task_dataset(torch_dataset, mode)
         return dataset

+    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
+        return build_optimizer(self.model, cfg=cfg, default_args=default_args)
+
+    def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
+        return build_lr_scheduler(cfg=cfg, default_args=default_args)
+
     def create_optimizer_and_scheduler(self):
         """ Create optimizer and lr scheduler

@@ -680,7 +686,7 @@ class EpochBasedTrainer(BaseTrainer):
         optim_options = {}
         if optimizer_cfg is not None:
             optim_options = optimizer_cfg.pop('options', {})
-            optimizer = build_optimizer(self.model, cfg=optimizer_cfg)
+            optimizer = self.build_optimizer(cfg=optimizer_cfg)

         if lr_scheduler is None:
             lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
@@ -691,7 +697,7 @@ class EpochBasedTrainer(BaseTrainer):
         if lr_scheduler_cfg is not None:
             assert optimizer is not None
             lr_options = lr_scheduler_cfg.pop('options', {})
-            lr_scheduler = build_lr_scheduler(
+            lr_scheduler = self.build_lr_scheduler(
                 cfg=lr_scheduler_cfg, default_args={'optimizer': optimizer})

         self.optimizer = optimizer
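
Because build_optimizer and build_lr_scheduler are now instance methods, a custom trainer can override them without copying create_optimizer_and_scheduler. The sketch below shows one way to use this for layer-wise learning rates; the subclass name, the 'backbone' parameter-name prefix and the learning-rate values are illustrative assumptions, not part of this commit.

from modelscope.trainers import EpochBasedTrainer
from torch.optim import AdamW


class LayerwiseLRTrainer(EpochBasedTrainer):
    """Hypothetical trainer that gives the backbone a smaller learning rate."""

    def build_optimizer(self, cfg, default_args=None):
        # Split parameters into two groups; the 'backbone' prefix is an
        # assumption about how the underlying model names its modules.
        backbone = [p for n, p in self.model.named_parameters()
                    if n.startswith('backbone')]
        head = [p for n, p in self.model.named_parameters()
                if not n.startswith('backbone')]
        # Ignore cfg here and build the optimizer directly from the groups.
        return AdamW([
            {'params': backbone, 'lr': 1e-5},
            {'params': head, 'lr': 1e-4},
        ], weight_decay=0.01)

The base-class signature (cfg: ConfigDict, default_args: dict = None) is kept so create_optimizer_and_scheduler can call the override unchanged.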

