From b0070fe089159ab56af2dd2375945af67b058519 Mon Sep 17 00:00:00 2001
From: changzherui
Date: Wed, 24 Feb 2021 12:51:52 +0800
Subject: [PATCH] export large mindir model

---
 .../transform/express_ir/mindir_exporter.cc |   3 -
 mindspore/common/api.py                     |   5 +-
 mindspore/core/load_mindir/load_model.cc    | 111 ++++++++++++++++--
 mindspore/train/serialization.py            | 100 ++++++++++++++--
 4 files changed, 195 insertions(+), 24 deletions(-)

diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
index 35c10fe2a2..929bd98286 100644
--- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
@@ -193,9 +193,6 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
       parameter_proto->set_name(param_name);
       SetParamToTensorProto(param, parameter_proto);
       auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
-      if (tensor) {
-        parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
-      }
     } else {
       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
       input_proto->set_name(param_name);
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 1312e8ee4a..94dee2c62e 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -494,11 +494,12 @@ class _Executor:
 
         if graph is None:
             logger.error("%r graph compile failed.", phase)
-        if not do_convert:
-            return phase, True
 
         self._auto_parallel_process(obj, phase, is_sink_mode, auto_parallel_mode, *args)
 
+        if not do_convert:
+            return phase, True
+
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             self._build_data_graph(obj, phase)
diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc
index ed34247ddc..37742f2a1e 100644
--- a/mindspore/core/load_mindir/load_model.cc
+++ b/mindspore/core/load_mindir/load_model.cc
@@ -14,11 +14,17 @@
  * limitations under the License.
  */
-#include "load_mindir/load_model.h"
+#include <dirent.h>
+#include <sys/stat.h>
+#include <climits>
+#include <cstdlib>
+#include <cstring>
 #include <memory>
 #include <string>
 #include <vector>
+#include <fstream>
+#include "load_mindir/load_model.h"
 #include "load_mindir/anf_model_parser.h"
 
 using std::string;
@@ -71,20 +77,112 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
   return buf;
 }
 
+bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
+  if (dir_in.empty()) {
+    return false;
+  }
+  struct stat s;
+  if (stat(dir_in.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
+    return false;
+  }
+  DIR *open_dir = opendir(dir_in.c_str());
+  if (open_dir == nullptr) {
+    MS_LOG(ERROR) << "Open dir " << dir_in << " failed.";
+    return false;
+  }
+  dirent *p = nullptr;
+  while ((p = readdir(open_dir)) != nullptr) {
+    struct stat st;
+    if (p->d_name[0] != '.') {
+      std::string name = dir_in + std::string("/") + std::string(p->d_name);
+      if (stat(name.c_str(), &st) != 0) {
+        continue;
+      }
+      if (S_ISDIR(st.st_mode)) {
+        get_all_files(name, files);
+      } else if (S_ISREG(st.st_mode)) {
+        files->push_back(name);
+      }
+    }
+  }
+  closedir(open_dir);
+  return true;
+}
+
+bool endsWith(const string &s, const string &sub) {
+  return s.size() >= sub.size() && s.rfind(sub) == s.size() - sub.size();
+}
+
 std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
-  auto graphBuf = ReadProtoFile(file_name);
-  if (graphBuf == nullptr) {
-    MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
-    return nullptr;
-  }
-
-  try {
-    auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size(), is_lite);
-    return graph;
-  } catch (std::exception &e) {
-    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
-    return nullptr;
-  }
+  const char *file_path = file_name.c_str();
+  char abs_path_buff[PATH_MAX];
+  char abs_path[PATH_MAX];
+
+  vector<string> files;
+
+#ifdef _WIN32
+  _fullpath(abs_path_buff, file_path, PATH_MAX);
+#else
+  if (!realpath(file_path, abs_path_buff)) {
+    MS_LOG(ERROR) << "Get absolute path of MindIR file failed.";
+    return nullptr;
+  }
+#endif
+  // Read graph
+  std::fstream input_graph(abs_path_buff, std::ios::in | std::ios::binary);
+  mind_ir::ModelProto origin_model;
+
+  if (!input_graph || !origin_model.ParseFromIstream(&input_graph)) {
+    MS_LOG(ERROR) << "Load MindIR file failed.";
+    return nullptr;
+  }
+
+  // Load parameters into the graph when only the graph was exported
+  if (endsWith(abs_path_buff, "_graph.mindir")) {
+    char *mindir_name = strrchr(abs_path_buff, '/');
+    int path_len = strlen(abs_path_buff) - strlen(mindir_name) + 1;
+    memcpy(abs_path, abs_path_buff, path_len);
+    abs_path[path_len] = '\0';
+    // The parameter files live in the "variables" directory beside the graph file
+    strncat(abs_path, "variables", sizeof(abs_path) - path_len - 1);
+    std::ifstream ifs(abs_path);
+    if (ifs.good()) {
+      MS_LOG(DEBUG) << "MindIR file has a variables path, loading parameters into the graph.";
+      string path = abs_path;
+      get_all_files(path, &files);
+    } else {
+      MS_LOG(ERROR) << "MindIR graph file has no variables path.";
+    }
+
+    int file_size = files.size();
+    mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
+    for (auto file_index = 0; file_index < file_size; file_index++) {
+      std::fstream input_param(files[file_index], std::ios::in | std::ios::binary);
+      mind_ir::GraphProto param_graph;
+      if (!input_param || !param_graph.ParseFromIstream(&input_param)) {
+        MS_LOG(ERROR) << "Load param proto file failed.";
+        return nullptr;
+      }
+
+      // Merge every parameter stored in the data file back into the graph proto
+      for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) {
+        mind_ir::TensorProto *param_proto = mod_graph->add_parameter();
+        param_proto->set_name(param_graph.parameter(param_index).name());
+        param_proto->set_data_type(param_graph.parameter(param_index).data_type());
+        param_proto->set_raw_data(param_graph.parameter(param_index).raw_data());
+        for (const auto &dim : param_graph.parameter(param_index).dims()) {
+          param_proto->add_dims(dim);
+        }
+      }
+    }
+  }
+
+  MSANFModelParser model_parser;
+  if (is_lite) {
+    model_parser.SetLite();
+  }
+  FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
+  return dstgraph_ptr;
 }
 
 std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index cc50b8fd53..cde19714c1 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -14,16 +14,21 @@
 # ============================================================================
 """Model and parameters serialization."""
 import os
+import sys
 import stat
 import math
+import shutil
 from threading import Thread, Lock
 import numpy as np
+
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import log as logger
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.print_pb2 import Print
 from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
+from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
+from mindspore.train.mind_ir_pb2 import GraphProto as graph_proto
 from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
@@ -46,6 +51,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
 
 _ckpt_mutex = Lock()
 SLICE_SIZE = 512 * 1024 * 1024
+TOTAL_SAVE = 1024 * 1024
 
 
 def _special_process_par(par, new_par):
@@ -423,10 +429,10 @@ def _save_graph(network, file_name):
     """
     logger.info("Execute the process of saving graph.")
 
-    graph_proto = network.get_func_graph_proto()
-    if graph_proto:
+    graph_pb = network.get_func_graph_proto()
+    if graph_pb:
         with open(file_name, "wb") as f:
-            f.write(graph_proto)
+            f.write(graph_pb)
         os.chmod(file_name, stat.S_IRUSR)
 
 
@@ -569,7 +575,6 @@ def _export(net, file_name, file_format, *inputs):
     if is_dump_onnx_in_training:
         net.set_train(mode=False)
 
-    net.init_parameters_data()
    if file_format == 'AIR':
         phase_name = 'export.air'
         graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
@@ -586,17 +591,98 @@ def _export(net, file_name, file_format, *inputs):
             os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
             f.write(onnx_stream)
     elif file_format == 'MINDIR':
+        _save_mindir(net, file_name, *inputs)
+
+    if is_dump_onnx_in_training:
+        net.set_train(mode=True)
+
+
+def _save_mindir(net, file_name, *inputs):
+    """Save MindIR format file."""
+    model = mindir_model()
+    if net._auto_parallel_mode:
+        phase_name = "predict"
+    else:
         phase_name = 'export.mindir'
-        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
-        onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
+    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name,
+                                    do_convert=False, auto_parallel_mode=net._auto_parallel_mode)
+    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
+
+    net_dict = net.parameters_dict()
+    data_total = 0
+    save_together = True
+    # Estimate the total parameter size in KB; TOTAL_SAVE is 1 GB expressed in KB
+    for i in net_dict.values():
+        data_total += sys.getsizeof(i.data.asnumpy().tobytes()) / 1024
+        if data_total > TOTAL_SAVE:
+            save_together = False
+            break
+
+    model.ParseFromString(mindir_stream)
+    if save_together:
+        for param_proto in model.graph.parameter:
+            param_name = param_proto.name[param_proto.name.find(":") + 1:]
+            if param_name in net_dict.keys():
+                param_data = net_dict[param_name].data.asnumpy().tobytes()
+                param_proto.raw_data = param_data
+            else:
+                logger.error("The parameter %s in the graph is not in the network.", param_name)
+                raise ValueError("The parameter in the graph must be in the network.")
         if not file_name.endswith('.mindir'):
             file_name += ".mindir"
+        current_path = os.path.abspath(file_name)
+        dirname = os.path.dirname(current_path)
+        os.makedirs(dirname, exist_ok=True)
         with open(file_name, 'wb') as f:
             os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
-            f.write(onnx_stream)
-
-    if is_dump_onnx_in_training:
-        net.set_train(mode=True)
+            f.write(model.SerializeToString())
+    else:
+        logger.warning("The total size of parameters in the net exceeds 1 GB, save the MindIR model "
+                       "and parameters separately.")
+        # save parameter
+        current_path = os.path.abspath(file_name)
+        dirname = os.path.dirname(current_path)
+        data_path = dirname + "/variables"
+        if os.path.exists(data_path):
+            shutil.rmtree(data_path)
+        os.makedirs(data_path, exist_ok=True)
+        index = 0
+        graphproto = graph_proto()
+        data_size = 0
+        for name, param in net_dict.items():
+            byte_data = param.data.asnumpy().tobytes()
+            data_size += sys.getsizeof(byte_data) / 1024
+            parameter = graphproto.parameter.add()
+            # Take name, dtype and dims from the graph proto; the raw data comes from the network
+            for param_proto in model.graph.parameter:
+                if name == param_proto.name[param_proto.name.find(":") + 1:]:
+                    parameter.name = param_proto.name
+                    parameter.data_type = param_proto.data_type
+                    for dim in param_proto.dims:
+                        parameter.dims.append(dim)
+                    break
+
+            parameter.raw_data = byte_data
+            # Roll over to a new data file once the accumulated size exceeds 1 GB
+            if data_size > TOTAL_SAVE:
+                data_file_name = data_path + "/" + "data_" + str(index)
+                with open(data_file_name, "ab") as f:
+                    f.write(graphproto.SerializeToString())
+                index += 1
+                data_size = 0
+                del graphproto.parameter[:]
+
+        if graphproto.parameter:
+            data_file_name = data_path + "/" + "data_" + str(index)
+            with open(data_file_name, "ab") as f:
+                f.write(graphproto.SerializeToString())
+
+        # save graph
+        del model.graph.parameter[:]
+        graph_file_name = file_name + "_graph.mindir"
+        with open(graph_file_name, 'wb') as f:
+            os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR)
+            f.write(model.SerializeToString())
 
 
 def _quant_export(network, *inputs, file_format, **kwargs):
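
A minimal usage sketch of the flow this patch enables, assuming the public `export`/`load`/`nn.GraphCell` Python APIs available since MindSpore 1.1; the `Net` class, file names, and shapes below are illustrative placeholders, not part of the patch:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, export, load

    # Tiny stand-in network; any Cell works. A model whose parameters total
    # more than 1 GB (TOTAL_SAVE) takes the split-save path added above.
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(16, 8)

        def construct(self, x):
            return self.dense(x)

    net = Net()
    inputs = Tensor(np.ones([1, 16]).astype(np.float32))

    # Small model: a single self-contained "net.mindir".
    # Large model: "net_graph.mindir" plus "variables/data_0", "variables/data_1", ...
    export(net, inputs, file_name="net", file_format="MINDIR")

    # LoadMindIR recognizes the "_graph.mindir" suffix and merges the parameters
    # from "variables/" back into the graph proto before parsing it.
    graph = load("net_graph.mindir")  # or "net.mindir" in the single-file case
    print(ms.nn.GraphCell(graph)(inputs).shape)

The 1 GB cutoff corresponds to `TOTAL_SAVE` in serialization.py: parameter data is accumulated in KB, so 1024 * 1024 KB decides between embedding `raw_data` in one `.mindir` file and spilling it into `variables/data_<n>` chunks that `LoadMindIR` stitches back together.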