Merge pull request !5482 from liuyang/md_save_checkpoint (tags/v1.0.0)
| @@ -23,7 +23,7 @@ import mindspore.context as context | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import check_bool, check_int_non_negative | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore.train.serialization import _exec_save_checkpoint, _save_graph | |||
| from mindspore.train.serialization import save_checkpoint, _save_graph | |||
| from ._callback import Callback, set_cur_net | |||
| @@ -306,8 +306,8 @@ class ModelCheckpoint(Callback): | |||
| set_cur_net(cb_params.train_network) | |||
| cb_params.train_network.exec_checkpoint_graph() | |||
| _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, | |||
| self._config.async_save) | |||
| save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, | |||
| self._config.async_save) | |||
| self._latest_ckpt_file_name = cur_file | |||
| @@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list): | |||
| raise RuntimeError(e.__str__()) | |||
| def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): | |||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False): | |||
| """ | |||
| Saves checkpoint info to a specified file. | |||
| Args: | |||
| parameter_list (list): Parameters list, each element is a dictionary | |||
| like {"name":xx, "type":xx, "shape":xx, "data":xx}. | |||
| save_obj (nn.Cell or list): The network to save, or a list of parameters in which each element is a dictionary | |||
| like {"name":xx, "type":xx, "shape":xx, "data":xx}. | |||
| ckpt_file_name (str): Checkpoint file name. | |||
| integrated_save (bool): Whether to merge the split parameters before saving in the automatic model parallel scenario. Default: True. | |||
| async_save (bool): Whether to save the checkpoint to a file asynchronously. Default: False. | |||
| Raises: | |||
| TypeError: If the parameter save_obj is not nn.Cell or list type. | |||
| RuntimeError: Failed to save the Checkpoint file. | |||
| """ | |||
| if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): | |||
| raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | |||
| logger.info("Execute save checkpoint process.") | |||
| if isinstance(save_obj, nn.Cell): | |||
| save_obj.init_parameters_data() | |||
| param_dict = {} | |||
| for _, param in save_obj.parameters_and_names(): | |||
| param_dict[param.name] = param | |||
| param_list = [] | |||
| for (key, value) in param_dict.items(): | |||
| each_param = {"name": key} | |||
| if isinstance(value.data, Tensor): | |||
| param_data = value.data | |||
| else: | |||
| param_data = Tensor(value.data) | |||
| # in the automatic model parallel scenario, some parameters are split across all the devices, | |||
| # and should be combined before saving | |||
| if integrated_save and key in save_obj.parameter_layout_dict: | |||
| param_data = _get_merged_param_data(save_obj, key, param_data) | |||
| each_param["data"] = param_data | |||
| param_list.append(each_param) | |||
| save_obj = param_list | |||
| data_list = {} | |||
| with _ckpt_mutex: | |||
| for param in parameter_list: | |||
| for param in save_obj: | |||
| key = param["name"] | |||
| data_list[key] = [] | |||
| if isinstance(param["data"], Parameter): | |||
| @@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): | |||
| thr.start() | |||
| else: | |||
| _exec_save(ckpt_file_name, data_list) | |||
| logger.info("Save checkpoint process finish.") | |||
| @@ -354,39 +383,6 @@ def _save_graph(network, file_name): | |||
| os.chmod(file_name, stat.S_IRUSR) | |||
| def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False): | |||
| """ | |||
| Saves checkpoint for 'ms' backend. | |||
| Args: | |||
| train_network (Network): The train network for training. | |||
| ckpt_file_name (str): The name of checkpoint file. | |||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. | |||
| async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False. | |||
| """ | |||
| train_network.init_parameters_data() | |||
| param_dict = {} | |||
| for _, param in train_network.parameters_and_names(): | |||
| param_dict[param.name] = param | |||
| param_list = [] | |||
| for (key, value) in param_dict.items(): | |||
| each_param = {"name": key} | |||
| if isinstance(value.data, Tensor): | |||
| param_data = value.data | |||
| else: | |||
| param_data = Tensor(value.data) | |||
| # in automatic model parallel scenario, some parameters were spliteds to all the devices, | |||
| # which should be combined before saving | |||
| if integrated_save and key in train_network.parameter_layout_dict: | |||
| param_data = _get_merged_param_data(train_network, key, param_data) | |||
| each_param["data"] = param_data | |||
| param_list.append(each_param) | |||
| save_checkpoint(param_list, ckpt_file_name, async_save) | |||
| def _get_merged_param_data(net, param_name, param_data): | |||
| """ | |||
| Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map. | |||
| @@ -18,7 +18,7 @@ import os | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint | |||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint | |||
| from src.config import GatConfig | |||
| from src.dataset import load_and_process | |||
| @@ -98,7 +98,7 @@ def train(): | |||
| val_loss_model = eval_loss | |||
| if os.path.exists("ckpts/gat.ckpt"): | |||
| os.remove("ckpts/gat.ckpt") | |||
| _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt") | |||
| save_checkpoint(train_net.network, "ckpts/gat.ckpt") | |||
| val_acc_max = np.max((val_acc_max, eval_acc)) | |||
| val_loss_min = np.min((val_loss_min, eval_loss)) | |||
| curr_step = 0 | |||
| @@ -20,7 +20,7 @@ import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.train.callback import Callback | |||
| from mindspore.train.serialization import _exec_save_checkpoint | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR | |||
| from .assessment_method import Accuracy | |||
| @@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback): | |||
| self.save_ckpt_step)) | |||
| if os.path.exists(path): | |||
| os.remove(path) | |||
| _exec_save_checkpoint(self.network, os.path.join(self.output_dir, | |||
| "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), | |||
| self.save_ckpt_step))) | |||
| save_checkpoint(self.network, os.path.join(self.output_dir, | |||
| "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), | |||
| self.save_ckpt_step))) | |||
| class LossCallBack(Callback): | |||
| """ | |||
| @@ -113,7 +113,7 @@ class EvalCallBack(Callback): | |||
| eval_model_ckpt_file = "eval_model.ckpt" | |||
| if os.path.exists(eval_model_ckpt_file): | |||
| os.remove(eval_model_ckpt_file) | |||
| _exec_save_checkpoint(self.network, eval_model_ckpt_file) | |||
| save_checkpoint(self.network, eval_model_ckpt_file) | |||
| class BertLearningRate(LearningRateSchedule): | |||
| """ | |||
| @@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train.callback import _CheckpointManager | |||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \ | |||
| _exec_save_checkpoint, export, _save_graph | |||
| export, _save_graph | |||
| from ..ut_filter import non_graph_engine | |||
| context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") | |||
| @@ -95,8 +95,8 @@ def test_save_graph(): | |||
| os.remove(output_file) | |||
| def test_save_checkpoint(): | |||
| """ test_save_checkpoint """ | |||
| def test_save_checkpoint_for_list(): | |||
| """ test save_checkpoint for list""" | |||
| parameter_list = [] | |||
| one_param = {} | |||
| param1 = {} | |||
| @@ -280,14 +280,15 @@ def test_load_param_into_net(): | |||
| assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1 | |||
| def test_exec_save_checkpoint(): | |||
| def test_save_checkpoint_for_network(): | |||
| """ test save_checkpoint for network""" | |||
| net = Net() | |||
| loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
| opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024) | |||
| loss_net = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(loss_net, opt) | |||
| _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") | |||
| save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") | |||
| load_checkpoint("new_ckpt.ckpt") | |||