From a60c944c330b0281583cf3f321d0e1f877ffb1a7 Mon Sep 17 00:00:00 2001 From: changzherui Date: Thu, 4 Mar 2021 17:40:40 +0800 Subject: [PATCH] modify export mindir model --- mindspore/core/load_mindir/load_model.cc | 9 ++++----- mindspore/train/serialization.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc index 37742f2a1e..1b457e5c08 100644 --- a/mindspore/core/load_mindir/load_model.cc +++ b/mindspore/core/load_mindir/load_model.cc @@ -133,19 +133,18 @@ std::shared_ptr LoadMindIR(const std::string &file_name, bool is_lite // Load parameter into graph if (endsWith(abs_path_buff, "_graph.mindir")) { - char *mindir_name, delimiter = '/'; - mindir_name = strrchr(abs_path_buff, delimiter); - int path_len = strlen(abs_path_buff) - strlen(mindir_name) + 1; + int path_len = strlen(abs_path_buff) - strlen("graph.mindir"); memcpy(abs_path, abs_path_buff, path_len); abs_path[path_len] = '\0'; - snprintf(abs_path, sizeof(abs_path), "variables"); + snprintf(abs_path + path_len, sizeof(abs_path), "variables"); std::ifstream ifs(abs_path); if (ifs.good()) { MS_LOG(DEBUG) << "MindIR file has variables path, load parameter into graph."; string path = abs_path; get_all_files(path, &files); } else { - MS_LOG(ERROR) << "MindIR graph has not variable path. "; + MS_LOG(ERROR) << "MindIR graph has not variable path, load failed"; + return nullptr; } int file_size = files.size(); diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index cde19714c1..498e54de3f 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -638,9 +638,12 @@ def _save_mindir(net, file_name, *inputs): else: logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") # save parameter + file_prefix = file_name.split("/")[-1] + if file_prefix.endswith(".mindir"): + file_prefix = file_prefix[:-7] current_path = os.path.abspath(file_name) dirname = os.path.dirname(current_path) - data_path = dirname + "/variables" + data_path = dirname + "/" + file_prefix + "_variables" if os.path.exists(data_path): shutil.rmtree(data_path) os.makedirs(data_path, exist_ok=True) @@ -675,7 +678,7 @@ def _save_mindir(net, file_name, *inputs): # save graph del model.graph.parameter[:] - graph_file_name = file_name + "_graph.mindir" + graph_file_name = dirname + "/" + file_prefix + "_graph.mindir" with open(graph_file_name, 'wb') as f: os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(model.SerializeToString()) @@ -1147,7 +1150,6 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): def _load_single_param(ckpt_file_name, param_name): """Load a parameter from checkpoint.""" - logger.info("Execute the process of loading checkpoint files.") checkpoint_list = Checkpoint() try: @@ -1155,7 +1157,8 @@ def _load_single_param(ckpt_file_name, param_name): pb_content = f.read() checkpoint_list.ParseFromString(pb_content) except BaseException as e: - logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) + logger.error("Failed to read the checkpoint file `%s` during load single parameter," + " please check the correct of the file.", ckpt_file_name) raise ValueError(e.__str__()) parameter = None @@ -1189,8 +1192,7 @@ def _load_single_param(ckpt_file_name, param_name): param_dim.append(dim) param_value = param_data.reshape(param_dim) parameter = Parameter(Tensor(param_value, ms_type), name=element.tag) - break - logger.info("Loading checkpoint files process is finished.") + break except BaseException as e: logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)