Merge pull request !17761 from changzherui/add_ckpt_info (tags/v1.3.0)
| @@ -32,6 +32,7 @@ from ...common.tensor import Tensor | |||
| _cur_dir = os.getcwd() | |||
| _save_dir = _cur_dir | |||
| _info_list = ["epoch_num", "step_num"] | |||
| def _chg_ckpt_file_name_if_same_exist(directory, prefix): | |||
| @@ -82,6 +83,8 @@ class CheckpointConfig: | |||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False. | |||
| saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation | |||
| with the network in training, the initial value of saved_network will be saved. Default: None. | |||
| append_info (List): The information saved to the checkpoint file. Supports "epoch_num", "step_num", and dict. | |||
| The key of the dict must be str, and the value of the dict must be one of int, float and bool. Default: None. | |||
| enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption | |||
| is not required. Default: None. | |||
| enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption | |||
| @@ -131,6 +134,7 @@ class CheckpointConfig: | |||
| integrated_save=True, | |||
| async_save=False, | |||
| saved_network=None, | |||
| append_info=None, | |||
| enc_key=None, | |||
| enc_mode='AES-GCM'): | |||
| @@ -166,6 +170,7 @@ class CheckpointConfig: | |||
| self._integrated_save = Validator.check_bool(integrated_save) | |||
| self._async_save = Validator.check_bool(async_save) | |||
| self._saved_network = saved_network | |||
| self._append_dict = self._handle_append_info(append_info) | |||
| self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) | |||
| self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) | |||
| @@ -214,6 +219,11 @@ class CheckpointConfig: | |||
| """Get the value of _enc_mode""" | |||
| return self._enc_mode | |||
@property
def append_dict(self):
    """Get the value of append_dict (the normalized extra-info dict built from append_info, or None)."""
    return self._append_dict
| def get_checkpoint_policy(self): | |||
| """Get the policy of checkpoint.""" | |||
| checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, | |||
| @@ -224,6 +234,36 @@ class CheckpointConfig: | |||
| return checkpoint_policy | |||
@staticmethod
def _handle_append_info(append_info):
    """Validate checkpoint append info and normalize it into a dict.

    Args:
        append_info (list): Elements are the strings "epoch_num"/"step_num"
            (pre-seeded to 0 in the result) and/or at most one dict whose
            keys are str and whose values are int, float or bool.

    Returns:
        dict or None: The normalized info dict, or None when append_info is
        None or an empty list.

    Raises:
        TypeError: If append_info is not a list, an element has a wrong type,
            a string element is not in _info_list, more than one dict element
            is given, or a dict entry has a non-str key or unsupported value.
    """
    if append_info is None or append_info == []:
        return None
    if not isinstance(append_info, list):
        raise TypeError(f"The type of append_info must be list, but got {str(type(append_info))}.")
    handle_append_info = {}
    # Pre-seed the well-known counters with 0; ModelCheckpoint overwrites
    # them with the real epoch/step numbers at save time.
    if "epoch_num" in append_info:
        handle_append_info["epoch_num"] = 0
    if "step_num" in append_info:
        handle_append_info["step_num"] = 0
    dict_num = 0
    for element in append_info:
        if not isinstance(element, (str, dict)):
            raise TypeError(f"The type of append_info element must be str or dict, but got {str(type(element))}.")
        if isinstance(element, str) and element not in _info_list:
            # The check is about the string's value, not its type.
            raise TypeError(f"The value of append_info string element must be in {_info_list}, but got {element}.")
        if isinstance(element, dict):
            dict_num += 1
            if dict_num > 1:
                raise TypeError("The append_info must contain only one dict element.")
            for key, value in element.items():
                if isinstance(key, str) and isinstance(value, (int, float, bool)):
                    handle_append_info[key] = value
                else:
                    raise TypeError("The type of dict element in append_info must be key: str, value: int, float or bool.")
    return handle_append_info
| class ModelCheckpoint(Callback): | |||
| """ | |||
| @@ -273,6 +313,9 @@ class ModelCheckpoint(Callback): | |||
| # get existing checkpoint files | |||
| self._manager = CheckpointManager() | |||
| self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) | |||
| self._append_dict = self._config.append_dict or {} | |||
| self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0 | |||
| self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0 | |||
| self._graph_saved = False | |||
| self._need_flush_from_cache = True | |||
| @@ -370,10 +413,13 @@ class ModelCheckpoint(Callback): | |||
| if context.get_context("enable_ge"): | |||
| set_cur_net(cb_params.train_network) | |||
| cb_params.train_network.exec_checkpoint_graph() | |||
| if "epoch_num" in self._append_dict: | |||
| self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num | |||
| if "step_num" in self._append_dict: | |||
| self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num | |||
| network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network | |||
| save_checkpoint(network, cur_file, self._config.integrated_save, | |||
| self._config.async_save, self._config.enc_key, self._config.enc_mode) | |||
| save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, | |||
| self._append_dict, self._config.enc_key, self._config.enc_mode) | |||
| self._latest_ckpt_file_name = cur_file | |||
| @@ -442,19 +488,19 @@ class CheckpointManager: | |||
def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
    """Keep only one ckpt file per `minutes` window; remove the other files generated in [last_time, cur_time].

    The survivor is the file with the smallest modification time inside the
    window (the earliest one), which matches the original selection rule.

    Fix: the original initialized its "oldest" tracker to cur_time, so when
    every in-window file had a modification time >= cur_time (e.g. clock
    skew or future timestamps) no survivor was selected and ALL in-window
    files were removed. Taking min() over the window always keeps one.

    Args:
        minutes (int): Size of the retention window in minutes.
        cur_time (float): Current time as a POSIX timestamp.
    """
    # Files modified within the last `minutes` minutes relative to cur_time.
    window_files = [ck_file for ck_file in self._ckpoint_filelist
                    if cur_time - os.path.getmtime(ck_file) < 60 * minutes]
    if not window_files:
        return
    # min() returns the first minimal element, matching the original
    # strict-< tie-breaking (first file with the smallest mtime wins).
    keep_file = min(window_files, key=os.path.getmtime)
    for ck_file in window_files:
        if ck_file != keep_file:
            self.remove_ckpoint_file(ck_file)
| @@ -14,15 +14,16 @@ | |||
| # ============================================================================ | |||
| """Model and parameters serialization.""" | |||
| import os | |||
| import sys | |||
| import stat | |||
| import math | |||
| import shutil | |||
| import time | |||
| import copy | |||
| import threading | |||
| from threading import Thread, Lock | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| @@ -189,7 +190,8 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): | |||
| raise e | |||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"): | |||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, | |||
| async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"): | |||
| """ | |||
| Saves checkpoint info to a specified file. | |||
| @@ -201,6 +203,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | |||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | |||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | |||
| append_dict (dict): Additional information that needs to be saved. The key of the dict must be str, | |||
| and the value of the dict must be one of int, float and bool. Default: None | |||
| enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption | |||
| is not required. Default: None. | |||
| enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption | |||
| @@ -221,6 +225,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | |||
| integrated_save = Validator.check_bool(integrated_save) | |||
| async_save = Validator.check_bool(async_save) | |||
| append_dict = _check_append_dict(append_dict) | |||
| enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) | |||
| enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) | |||
| @@ -245,6 +250,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| param_list.append(each_param) | |||
| save_obj = param_list | |||
| if append_dict: | |||
| append_info_list = [] | |||
| for k_name, value in append_dict.items(): | |||
| append_info_list.append({"name": k_name, "data": Tensor(value)}) | |||
| save_obj.extend(append_info_list) | |||
| data_list = {} | |||
| with _ckpt_mutex: | |||
| for param in save_obj: | |||
| @@ -282,6 +293,17 @@ def _check_param_prefix(filter_prefix, param_name): | |||
| return False | |||
| def _check_append_dict(append_dict): | |||
| if append_dict is None: | |||
| return append_dict | |||
| if not isinstance(append_dict, dict): | |||
| raise TypeError(f"The type of append_dict must dict, but got {str(type(append_dict))}.") | |||
| if not all(isinstance(ele, str) for ele in append_dict.keys()) or \ | |||
| not all(isinstance(ele, (int, float, bool)) for ele in append_dict.values()): | |||
| raise TypeError(f"The type of element in append_dict must be key: str, value: int or float.") | |||
| return append_dict | |||
| def load(file_name): | |||
| """ | |||
| Load MindIR. | |||
| @@ -456,8 +478,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False): | |||
| Args: | |||
| net (Cell): Cell network. | |||
| parameter_dict (dict): Parameter dictionary. | |||
| strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter | |||
| in the param_dict into net with the same suffix. Default: False | |||
| strict_load (bool): Whether to strictly load the parameters into the net. If False, when some parameters | |||
| in the net are not loaded, it will remove the parameters' prefix name and continue to load. Default: False | |||
| Raises: | |||
| TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | |||
| @@ -1270,6 +1292,18 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||
| load_param_into_net(network, param_dict) | |||
def async_ckpt_thread_status():
    """
    Get async save checkpoint thread status.

    Returns:
        bool: True if the asynchronous save checkpoint thread (named
        "asyn_save_ckpt") is running, False otherwise.
    """
    # Thread.getName() is deprecated since Python 3.10; read the `name`
    # attribute instead. any() with a generator short-circuits and avoids
    # building an intermediate list.
    return any(thread.name == "asyn_save_ckpt" for thread in threading.enumerate())
| def _check_predict_strategy(predict_strategy): | |||
| """Check predict strategy.""" | |||
| def _check_int_list(arg): | |||
| @@ -120,6 +120,30 @@ def test_save_checkpoint_for_list(): | |||
| save_checkpoint(parameter_list, ckpt_file_name) | |||
def test_save_checkpoint_for_list_append_info():
    """Test save_checkpoint for a parameter list together with append_dict info."""
    specs = [("param_test", [1, 3, 224, 224]),
             ("param", [12, 1024]),
             ("new_param", [12, 1024, 1])]
    parameter_list = [{"name": name,
                       "data": Tensor(np.random.randint(0, 255, shape), dtype=mstype.float32)}
                      for name, shape in specs]
    append_dict = {"lr": 0.01, "epoch": 20, "train": True}
    # Remove any stale checkpoint so the save starts from a clean slate.
    if os.path.exists('./parameters.ckpt'):
        os.chmod('./parameters.ckpt', stat.S_IWRITE)
        os.remove('./parameters.ckpt')
    ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
    save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict)
| def test_load_checkpoint_error_filename(): | |||
| ckpt_file_name = 1 | |||
| with pytest.raises(ValueError): | |||
| @@ -130,7 +154,7 @@ def test_load_checkpoint(): | |||
| ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') | |||
| par_dict = load_checkpoint(ckpt_file_name) | |||
| assert len(par_dict) == 3 | |||
| assert len(par_dict) == 6 | |||
| assert par_dict['param_test'].name == 'param_test' | |||
| assert par_dict['param_test'].data.dtype == mstype.float32 | |||
| assert par_dict['param_test'].data.shape == (1, 3, 224, 224) | |||