# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import time
from collections import OrderedDict
from typing import Optional

import torch
from torch.optim import Optimizer

from modelscope import __version__
from modelscope.fileio import File


def weights_to_cpu(state_dict):
    """Copy a model state_dict to CPU.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # keep the metadata attached to the state_dict (e.g. version info)
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu
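

# A minimal usage sketch (illustrative comment, not part of the original
# module): moving the weights to CPU before serialization keeps the
# resulting checkpoint loadable on CPU-only machines. The ``.cuda()``
# call assumes a GPU is available.
#
#     model = torch.nn.Linear(4, 2).cuda()
#     cpu_state = weights_to_cpu(model.state_dict())
#     assert all(v.device.type == 'cpu' for v in cpu_state.values())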


def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Optimizer] = None,
                    meta: Optional[dict] = None) -> None:
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(modelscope=__version__, time=time.asctime())

    if isinstance(model, (torch.nn.parallel.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        # unwrap the parallel wrapper so the plain module's weights are
        # saved without 'module.' key prefixes
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class names to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        # multiple optimizers: save each one's state dict under its name
        checkpoint['optimizer'] = {
            name: optim.state_dict()
            for name, optim in optimizer.items()
        }

    # serialize to an in-memory buffer first, then hand the raw bytes to
    # File.write, which takes care of writing to the target location
    with io.BytesIO() as f:
        torch.save(checkpoint, f)
        File.write(f.getvalue(), filename)
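

# A hedged usage sketch (assumes a plain local path; the file is written
# via modelscope.fileio.File, and loading back is plain ``torch.load``):
#
#     model = torch.nn.Linear(4, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     save_checkpoint(model, '/tmp/demo.pth', optimizer=optimizer,
#                     meta={'epoch': 1})
#
#     ckpt = torch.load('/tmp/demo.pth', map_location='cpu')
#     model.load_state_dict(ckpt['state_dict'])
#     optimizer.load_state_dict(ckpt['optimizer'])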