From 357a233ee32bbaec7eaef58f383d86219b3f9cd3 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Tue, 27 Sep 2022 23:03:00 +0800 Subject: [PATCH] [to #42322933] fix bug: CheckpointHook and BestCkptSaverHook exist at the same time Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10227608 --- modelscope/trainers/default_config.py | 19 +++++++++++++++++++ modelscope/trainers/trainer.py | 7 ++----- tests/trainers/hooks/test_checkpoint_hook.py | 3 --- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/modelscope/trainers/default_config.py b/modelscope/trainers/default_config.py index 69fdd400..c8f0c7b0 100644 --- a/modelscope/trainers/default_config.py +++ b/modelscope/trainers/default_config.py @@ -1,4 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.config import Config + DEFAULT_CONFIG = { 'train': { 'hooks': [{ @@ -12,3 +15,19 @@ DEFAULT_CONFIG = { }] } } + + +def merge_cfg(cfg: Config): + """Merge the default config into the input cfg. + + This function will pop the default CheckpointHook when the BestCkptSaverHook exists in the input cfg. + + @param cfg: The input config which the default config will be merged into.
+ """ + cfg.merge_from_dict(DEFAULT_CONFIG, force=False) + # pop duplicate hook + + if any(['BestCkptSaverHook' == hook['type'] for hook in cfg.train.hooks]): + cfg.train.hooks = list( + filter(lambda hook: hook['type'] != 'CheckpointHook', + cfg.train.hooks)) diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index d3675720..a01d9b59 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -41,7 +41,7 @@ from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, init_dist, set_random_seed) from .base import BaseTrainer from .builder import TRAINERS -from .default_config import DEFAULT_CONFIG +from .default_config import merge_cfg from .hooks.hook import Hook from .parallel.builder import build_parallel from .parallel.utils import is_parallel @@ -114,7 +114,7 @@ class EpochBasedTrainer(BaseTrainer): super().__init__(cfg_file, arg_parse_fn) # add default config - self.cfg.merge_from_dict(self._get_default_config(), force=False) + merge_cfg(self.cfg) self.cfg = self.rebuild_config(self.cfg) if 'cfg_options' in kwargs: @@ -951,9 +951,6 @@ class EpochBasedTrainer(BaseTrainer): stage_hook_infos.append(info) return '\n'.join(stage_hook_infos) - def _get_default_config(self): - return DEFAULT_CONFIG - def worker_init_fn(worker_id, num_workers, rank, seed): # The seed of each worker equals to diff --git a/tests/trainers/hooks/test_checkpoint_hook.py b/tests/trainers/hooks/test_checkpoint_hook.py index c694ece6..e7f2d33c 100644 --- a/tests/trainers/hooks/test_checkpoint_hook.py +++ b/tests/trainers/hooks/test_checkpoint_hook.py @@ -204,9 +204,6 @@ class BestCkptSaverHookTest(unittest.TestCase): trainer = build_trainer(trainer_name, kwargs) trainer.train() results_files = os.listdir(self.tmp_dir) - self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) - self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) - self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) 
self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', results_files)