yuze.zyz committed 3 years ago
commit ddf8daf0a0
3 changed files with 11 additions and 10 deletions
1. modelscope/trainers/hooks/checkpoint_hook.py  +7 -5
2. modelscope/utils/checkpoint.py  +2 -3
3. modelscope/utils/hub.py  +2 -2

modelscope/trainers/hooks/checkpoint_hook.py  +7 -5

@@ -117,20 +117,20 @@ class CheckpointHook(Hook):
         for i, hook in enumerate(trainer.hooks):
             # hook: Hook
             key = f'{hook.__class__}-{i}'
-            if key in meta:
+            if key in meta and hasattr(hook, 'load_state_dict'):
                 hook.load_state_dict(meta[key])
             else:
-                trainer.logger(
+                trainer.logger.warn(
                     f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
                 )

         version = meta.get('modelscope')
         if version != __version__:
-            trainer.logger(
+            trainer.logger.warn(
                 f'The modelscope version of loaded checkpoint does not match the runtime version. '
                 f'The saved version: {version}, runtime version: {__version__}')
-        trainer.logger(
+        trainer.logger.warn(
             f'Checkpoint {filename} saving time: {meta.get("time")}')
         return meta

@@ -149,7 +149,8 @@ class CheckpointHook(Hook):
             'rng_state': self.rng_state,
         }
         for i, hook in enumerate(trainer.hooks):
-            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
+            if hasattr(hook, 'state_dict'):
+                meta[f'{hook.__class__}-{i}'] = hook.state_dict()

         save_checkpoint(
             trainer.model,
@@ -239,6 +240,7 @@ class BestCkptSaverHook(CheckpointHook):
         self.rule = rule
         self._best_metric = None
         self._best_ckpt_file = None
+        self.save_file_name = save_file_name

     def _should_save(self, trainer):
         return self._is_best_metric(trainer.metric_values)
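Taken together, the checkpoint_hook.py changes move hook state persistence to duck typing: state is saved and restored only for hooks that actually implement the state_dict/load_state_dict protocol, and diagnostics go through trainer.logger.warn (the bare trainer.logger(...) calls were bugs, since a Logger instance is not callable). A minimal self-contained sketch of that pattern, with hypothetical hook classes standing in for ModelScope's real hooks:

import logging

logger = logging.getLogger(__name__)

class LrSchedulerHook:
    """A hook that carries resumable state."""

    def __init__(self):
        self.last_lr = 0.1

    def state_dict(self):
        return {'last_lr': self.last_lr}

    def load_state_dict(self, state):
        self.last_lr = state['last_lr']

class TextLoggerHook:
    """A stateless hook: defines neither state_dict nor load_state_dict."""

def save_hook_meta(hooks):
    meta = {}
    for i, hook in enumerate(hooks):
        # Serialize only hooks that expose state_dict; stateless hooks are skipped.
        if hasattr(hook, 'state_dict'):
            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
    return meta

def load_hook_meta(hooks, meta):
    for i, hook in enumerate(hooks):
        key = f'{hook.__class__}-{i}'
        if key in meta and hasattr(hook, 'load_state_dict'):
            hook.load_state_dict(meta[key])
        else:
            # Warn instead of crashing so checkpoints stay loadable
            # even when the hook list changes between save and resume.
            logger.warning(
                f'The state_dict of hook {hook.__class__} at index {i} '
                f'is not found in the checkpoint file.')

hooks = [LrSchedulerHook(), TextLoggerHook()]
meta = save_hook_meta(hooks)  # only the LrSchedulerHook entry is recorded
load_hook_meta(hooks, meta)   # TextLoggerHook logs a warning, no exception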


modelscope/utils/checkpoint.py  +2 -3

@@ -87,9 +87,8 @@ def save_checkpoint(model: torch.nn.Module,
                 checkpoint['optimizer'][name] = optim.state_dict()

         # save lr_scheduler state dict in the checkpoint
-        assert isinstance(lr_scheduler, _LRScheduler), \
-            f'lr_scheduler to be saved should be a subclass of _LRScheduler, current is : {lr_scheduler.__class__}'
-        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
+        if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
+            checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
     else:
         checkpoint = weights_to_cpu(model.state_dict())
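The checkpoint.py change replaces a hard isinstance check against torch's _LRScheduler with duck typing, so custom schedulers that merely expose state_dict() can be checkpointed, and a missing scheduler no longer aborts the save. A sketch under those assumptions; save_scheduler_state is a hypothetical stand-in for the touched branch of save_checkpoint:

class WarmupScheduler:
    """A custom scheduler that does NOT subclass torch's _LRScheduler."""

    def __init__(self, warmup_steps=100):
        self.warmup_steps = warmup_steps
        self.step_count = 0

    def state_dict(self):
        return {'warmup_steps': self.warmup_steps,
                'step_count': self.step_count}

def save_scheduler_state(checkpoint, lr_scheduler):
    # The old code asserted isinstance(lr_scheduler, _LRScheduler), which
    # raised for WarmupScheduler and for lr_scheduler=None; duck typing
    # accepts the former and silently skips the latter.
    if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
    return checkpoint

print(save_scheduler_state({}, WarmupScheduler()))  # scheduler state is saved
print(save_scheduler_state({}, None))               # skipped, no crash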



modelscope/utils/hub.py  +2 -2

@@ -142,8 +142,8 @@ def parse_label_mapping(model_dir):
             id2label = config[ConfigFields.preprocessor].id2label
             label2id = {label: id for id, label in id2label.items()}

-    if label2id is None:
-        config_path = os.path.join(model_dir, 'config.json')
+    config_path = os.path.join(model_dir, 'config.json')
+    if label2id is None and os.path.exists(config_path):
         config = Config.from_file(config_path)
         if hasattr(config, 'label2id'):
             label2id = config.label2id
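The hub.py change computes config_path unconditionally and guards the fallback read with os.path.exists, so a model directory that ships no config.json no longer raises when the loader tries to open it. A minimal sketch of the same fallback order, using json.load as a stand-in for ModelScope's Config.from_file:

import json
import os
import tempfile

def parse_label_mapping(model_dir, label2id=None):
    config_path = os.path.join(model_dir, 'config.json')
    # Fall back to config.json only if no mapping was found earlier
    # AND the file actually exists, avoiding a FileNotFoundError.
    if label2id is None and os.path.exists(config_path):
        with open(config_path) as f:
            config = json.load(f)
        label2id = config.get('label2id')
    return label2id

with tempfile.TemporaryDirectory() as d:
    print(parse_label_mapping(d))  # None: no config.json present, and no crash
    with open(os.path.join(d, 'config.json'), 'w') as f:
        json.dump({'label2id': {'negative': 0, 'positive': 1}}, f)
    print(parse_label_mapping(d))  # {'negative': 0, 'positive': 1}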

