From ddf8daf0a0c0ee296cea6df704648c176074df0a Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Mon, 19 Sep 2022 20:41:39 +0800 Subject: [PATCH] [to #42322933] Fix bug in release Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10173975 --- modelscope/trainers/hooks/checkpoint_hook.py | 12 +++++++----- modelscope/utils/checkpoint.py | 5 ++--- modelscope/utils/hub.py | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py index a9b793d4..220929b8 100644 --- a/modelscope/trainers/hooks/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -117,20 +117,20 @@ class CheckpointHook(Hook): for i, hook in enumerate(trainer.hooks): # hook: Hook key = f'{hook.__class__}-{i}' - if key in meta: + if key in meta and hasattr(hook, 'load_state_dict'): hook.load_state_dict(meta[key]) else: - trainer.logger( + trainer.logger.warn( f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' ) version = meta.get('modelscope') if version != __version__: - trainer.logger( + trainer.logger.warn( f'The modelscope version of loaded checkpoint does not match the runtime version. ' f'The saved version: {version}, runtime version: {__version__}' ) - trainer.logger( + trainer.logger.warn( f'Checkpoint {filename} saving time: {meta.get("time")}') return meta @@ -149,7 +149,8 @@ class CheckpointHook(Hook): 'rng_state': self.rng_state, } for i, hook in enumerate(trainer.hooks): - meta[f'{hook.__class__}-{i}'] = hook.state_dict() + if hasattr(hook, 'state_dict'): + meta[f'{hook.__class__}-{i}'] = hook.state_dict() save_checkpoint( trainer.model, @@ -239,6 +240,7 @@ class BestCkptSaverHook(CheckpointHook): self.rule = rule self._best_metric = None self._best_ckpt_file = None + self.save_file_name = save_file_name def _should_save(self, trainer): return self._is_best_metric(trainer.metric_values) diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py index 8d8c2b2f..a9d7f396 100644 --- a/modelscope/utils/checkpoint.py +++ b/modelscope/utils/checkpoint.py @@ -87,9 +87,8 @@ def save_checkpoint(model: torch.nn.Module, checkpoint['optimizer'][name] = optim.state_dict() # save lr_scheduler state dict in the checkpoint - assert isinstance(lr_scheduler, _LRScheduler), \ - f'lr_scheduler to be saved should be a subclass of _LRScheduler, current is : {lr_scheduler.__class__}' - checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'): + checkpoint['lr_scheduler'] = lr_scheduler.state_dict() else: checkpoint = weights_to_cpu(model.state_dict()) diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index cf114b5e..2dbe7045 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -142,8 +142,8 @@ def parse_label_mapping(model_dir): id2label = config[ConfigFields.preprocessor].id2label label2id = {label: id for id, label in id2label.items()} - if label2id is None: - config_path = os.path.join(model_dir, 'config.json') + config_path = os.path.join(model_dir, 'config.json') + if label2id is None and os.path.exists(config_path): config = Config.from_file(config_path) if hasattr(config, 'label2id'): label2id = config.label2id