Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10227608 (target branch: master)
default_config.py:

@@ -1,4 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from modelscope.utils.config import Config
+
+
 DEFAULT_CONFIG = {
     'train': {
         'hooks': [{
@@ -12,3 +15,19 @@ DEFAULT_CONFIG = {
         }]
     }
 }
+
+
+def merge_cfg(cfg: Config):
+    """Merge the default config into the input cfg.
+
+    This function pops the default CheckpointHook when a BestCkptSaverHook
+    already exists in the input cfg, so only one checkpoint-saving hook runs.
+
+    @param cfg: The input cfg to merge the defaults into.
+    """
+    cfg.merge_from_dict(DEFAULT_CONFIG, force=False)
+    # Pop the duplicate checkpoint hook.
+    if any(hook['type'] == 'BestCkptSaverHook' for hook in cfg.train.hooks):
+        cfg.train.hooks = list(
+            filter(lambda hook: hook['type'] != 'CheckpointHook',
+                   cfg.train.hooks))
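For reference, a minimal self-contained sketch of the dedup step above, using plain dicts in place of modelscope's Config object (the hook fields shown are illustrative, not taken from DEFAULT_CONFIG):

# Standalone sketch of the dedup step in merge_cfg, using plain dicts
# instead of a modelscope Config; hook fields are illustrative only.
hooks = [
    {'type': 'BestCkptSaverHook', 'metric_key': 'accuracy'},  # user-defined
    {'type': 'CheckpointHook', 'interval': 1},  # merged-in default, now redundant
]

# Same predicate and filter as in merge_cfg above.
if any(hook['type'] == 'BestCkptSaverHook' for hook in hooks):
    hooks = [hook for hook in hooks if hook['type'] != 'CheckpointHook']

print(hooks)  # -> [{'type': 'BestCkptSaverHook', 'metric_key': 'accuracy'}]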
trainer.py:

@@ -41,7 +41,7 @@ from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
                                           init_dist, set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
-from .default_config import DEFAULT_CONFIG
+from .default_config import merge_cfg
 from .hooks.hook import Hook
 from .parallel.builder import build_parallel
 from .parallel.utils import is_parallel

@@ -114,7 +114,7 @@ class EpochBasedTrainer(BaseTrainer):
         super().__init__(cfg_file, arg_parse_fn)

         # add default config
-        self.cfg.merge_from_dict(self._get_default_config(), force=False)
+        merge_cfg(self.cfg)
         self.cfg = self.rebuild_config(self.cfg)

         if 'cfg_options' in kwargs:

@@ -951,9 +951,6 @@ class EpochBasedTrainer(BaseTrainer):
             stage_hook_infos.append(info)
         return '\n'.join(stage_hook_infos)

-    def _get_default_config(self):
-        return DEFAULT_CONFIG
-

 def worker_init_fn(worker_id, num_workers, rank, seed):
     # The seed of each worker equals to
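A hedged usage sketch of the new call path (the Config-from-dict construction and the absolute import path are assumed from the relative import in trainer.py, not verified against the library):

# Usage sketch: merge_cfg replaces the overridable _get_default_config hook.
# The import paths and Config(dict) construction here are assumptions.
from modelscope.trainers.default_config import merge_cfg
from modelscope.utils.config import Config

cfg = Config({
    'train': {
        'hooks': [{'type': 'BestCkptSaverHook', 'metric_key': 'accuracy'}]
    }
})
merge_cfg(cfg)  # fills in DEFAULT_CONFIG; drops the now-duplicate CheckpointHook

assert all(hook['type'] != 'CheckpointHook' for hook in cfg.train.hooks)

Centralizing the merge in a module-level helper keeps the dedup logic next to DEFAULT_CONFIG itself, rather than split between the trainer and the config module.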
BestCkptSaverHookTest: since the default CheckpointHook is now dropped whenever a BestCkptSaverHook is configured, the per-epoch checkpoint files are no longer written, and the corresponding assertions are removed.

@@ -204,9 +204,6 @@ class BestCkptSaverHookTest(unittest.TestCase):
         trainer = build_trainer(trainer_name, kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
-        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
-        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
         self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth',
                       results_files)
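To make the remaining assertion concrete, here is how the expected filename resolves, assuming LogKeys.EPOCH == 'epoch' and MetricKeys.ACCURACY == 'accuracy' (stand-in values, not verified against modelscope):

# Sketch of the expected best-checkpoint filename; EPOCH and ACCURACY are
# assumed stand-ins for LogKeys.EPOCH and MetricKeys.ACCURACY.
EPOCH = 'epoch'
ACCURACY = 'accuracy'

best_ckpt_name = f'best_{EPOCH}1_{ACCURACY}0.1.pth'
print(best_ckpt_name)  # -> best_epoch1_accuracy0.1.pth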