@@ -117,20 +117,20 @@ class CheckpointHook(Hook):
         for i, hook in enumerate(trainer.hooks):
             # hook: Hook
             key = f'{hook.__class__}-{i}'
-            if key in meta:
+            if key in meta and hasattr(hook, 'load_state_dict'):
                 hook.load_state_dict(meta[key])
             else:
-                trainer.logger(
+                trainer.logger.warn(
                     f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
                 )

         version = meta.get('modelscope')
         if version != __version__:
-            trainer.logger(
+            trainer.logger.warn(
                 f'The modelscope version of loaded checkpoint does not match the runtime version. '
                 f'The saved version: {version}, runtime version: {__version__}'
             )
-        trainer.logger(
+        trainer.logger.warn(
             f'Checkpoint {filename} saving time: {meta.get("time")}')
         return meta
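
With this hunk, the restore path only calls `load_state_dict` on hooks that both have an entry in `meta` and actually implement the method; everything else now just logs a warning through `trainer.logger.warn` instead of calling the logger object directly. A minimal, self-contained sketch of that behaviour, using hypothetical `PlainHook`/`CounterHook` classes and a plain `logging` logger as stand-ins for the trainer's hooks and logger (none of these names come from the diff itself):

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class PlainHook:
    """A hook with no persistent state; the guarded loop just skips it."""


class CounterHook:
    """A hypothetical stateful hook following the state_dict protocol."""

    def __init__(self):
        self.count = 0

    def state_dict(self):
        return {'count': self.count}

    def load_state_dict(self, state):
        self.count = state['count']


def restore_hooks(hooks, meta):
    """Restore hook state the way the patched load path does."""
    for i, hook in enumerate(hooks):
        key = f'{hook.__class__}-{i}'
        # Restore only when the checkpoint has an entry for this hook
        # AND the hook can actually consume it.
        if key in meta and hasattr(hook, 'load_state_dict'):
            hook.load_state_dict(meta[key])
        else:
            logger.warning(
                f'The state_dict of hook {hook.__class__} at index {i} '
                'is not found in the checkpoint file.')


hooks = [CounterHook(), PlainHook()]
meta = {f'{CounterHook}-0': {'count': 42}}
restore_hooks(hooks, meta)
assert hooks[0].count == 42  # stateful hook restored; PlainHook only warns
```
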
@@ -149,7 +149,8 @@ class CheckpointHook(Hook):
             'rng_state': self.rng_state,
         }
         for i, hook in enumerate(trainer.hooks):
-            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
+            if hasattr(hook, 'state_dict'):
+                meta[f'{hook.__class__}-{i}'] = hook.state_dict()

         save_checkpoint(
             trainer.model,
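
The save path mirrors the guard on the load side: a hook contributes an entry to `meta` only if it defines `state_dict`, so stateless hooks are simply left out of the checkpoint rather than being assumed to have one. A short sketch with hypothetical `StatefulHook`/`StatelessHook` classes (again, illustrative names, not part of the diff):

```python
class StatefulHook:
    def state_dict(self):
        return {'step': 7}


class StatelessHook:
    pass


def collect_hook_states(hooks):
    """Collect per-hook state the way the patched save path does."""
    meta = {}
    for i, hook in enumerate(hooks):
        # Only hooks that define state_dict end up in the checkpoint meta.
        if hasattr(hook, 'state_dict'):
            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
    return meta


saved = collect_hook_states([StatefulHook(), StatelessHook()])
assert list(saved.values()) == [{'step': 7}]  # only the stateful hook persists
```
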
@@ -239,6 +240,7 @@ class BestCkptSaverHook(CheckpointHook):
         self.rule = rule
         self._best_metric = None
         self._best_ckpt_file = None
+        self.save_file_name = save_file_name

     def _should_save(self, trainer):
         return self._is_best_metric(trainer.metric_values)