
checkpoint.py 2.4 kB

# Copyright (c) Alibaba, Inc. and its affiliates.
import io
import time
from collections import OrderedDict
from typing import Optional

import torch
from torch.optim import Optimizer

from modelscope import __version__
from modelscope.fileio import File


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu


def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Optimizer] = None,
                    meta: Optional[dict] = None) -> None:
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(modelscope=__version__, time=time.asctime())

    # Unwrap DDP so the raw module's parameter names are saved.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class names to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }
    # save optimizer state dict in the checkpoint; a dict of named
    # optimizers is also supported
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    with io.BytesIO() as f:
        torch.save(checkpoint, f)
        File.write(f.getvalue(), filename)
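
For illustration, a minimal usage sketch follows. The linear model, SGD optimizer, output path and meta dict are assumptions made up for the example, not part of this file; it also assumes the filename is a local path so the saved bytes can be read back with torch.load.

# Usage sketch (illustrative assumptions: model, optimizer, path, meta).
import torch
from torch import nn
from torch.optim import SGD

model = nn.Linear(10, 2)                        # any torch.nn.Module works
optimizer = SGD(model.parameters(), lr=0.1)

# Writes a dict with 'meta', 'state_dict' and 'optimizer' fields.
save_checkpoint(model, 'linear_checkpoint.pth',
                optimizer=optimizer, meta={'epoch': 1})

# Assuming a local path, the file can be read back with torch.load.
checkpoint = torch.load('linear_checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])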