From ea4d45c368dbef31e839a5c4d44714bd4e12a2f4 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Tue, 29 Mar 2022 11:44:20 +0800 Subject: [PATCH] support mindir export in lite converter --- .../ccsrc/transform/express_ir/CMakeLists.txt | 2 +- .../transform/express_ir/mindir_exporter.cc | 8 + mindspore/lite/src/common/file_utils.cc | 35 ++ mindspore/lite/src/common/file_utils.h | 2 + mindspore/lite/test/CMakeLists.txt | 4 + .../tools/common/meta_graph_serializer.cc | 33 +- mindspore/lite/tools/converter/CMakeLists.txt | 5 + mindspore/lite/tools/converter/converter.cc | 24 +- .../lite/tools/converter/converter_flags.cc | 24 +- .../lite/tools/converter/converter_flags.h | 4 + .../parser/caffe/caffe_model_parser.cc | 1 + .../parser/lite_model_parser_creator.h | 36 ++ .../parser/onnx/onnx_model_parser.cc | 1 + .../tools/converter/parser/parser_utils.h | 11 - .../converter/parser/tf/tf_model_parser.cc | 1 + .../parser/tflite/tflite_model_parser.cc | 1 + .../tools/mindir_serializer/CMakeLists.txt | 32 ++ .../mindir_serializer/mindir_serializer.cc | 415 ++++++++++++++++++ .../mindir_serializer/mindir_serializer.h | 78 ++++ 19 files changed, 664 insertions(+), 53 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/lite_model_parser_creator.h create mode 100644 mindspore/lite/tools/mindir_serializer/CMakeLists.txt create mode 100644 mindspore/lite/tools/mindir_serializer/mindir_serializer.cc create mode 100644 mindspore/lite/tools/mindir_serializer/mindir_serializer.h diff --git a/mindspore/ccsrc/transform/express_ir/CMakeLists.txt b/mindspore/ccsrc/transform/express_ir/CMakeLists.txt index cf280efaff..7b4e415425 100644 --- a/mindspore/ccsrc/transform/express_ir/CMakeLists.txt +++ b/mindspore/ccsrc/transform/express_ir/CMakeLists.txt @@ -1,5 +1,5 @@ file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -file(STRINGS "${CMAKE_SOURCE_DIR}/version.txt" VERSION) +file(STRINGS "${TOP_DIR}/version.txt" VERSION) add_definitions(-DVERSION=\"${VERSION}\") set_property(SOURCE ${_EXPORTER_IR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_EXPRESS) diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index 16e94d9cc6..51e7afd9f0 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -31,7 +31,9 @@ #include "include/common/debug/dump_proto.h" #include "utils/ms_utils.h" #include "include/common/utils/utils.h" +#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP #include "frontend/parallel/tensor_layout/tensor_layout.h" +#endif #include "abstract/abstract_function.h" namespace mindspore { @@ -105,7 +107,9 @@ class IrExportBuilder { bool BuildModel(const FuncGraphPtr &func_graph); ModelProtoPtr Model() { return model_; } +#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP void BuildLayout(const FuncGraphPtr &func_graph); +#endif bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); bool BuildFuncGraphAttrs(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); @@ -256,10 +260,12 @@ ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, const Fun return nullptr; } +#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP // Export layout information if (param_layout_fg) { builder_->BuildLayout(param_layout_fg); } +#endif return builder_->Model(); } @@ -278,6 +284,7 @@ void IrExportBuilder::BuildModelInfo() { model_->set_mind_ir_version(mind_ir::Version_MAX); } +#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP void IrExportBuilder::BuildLayout(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); std::vector graph_params = func_graph->parameters(); @@ -315,6 +322,7 @@ void IrExportBuilder::BuildLayout(const FuncGraphPtr &func_graph) { } } } +#endif bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); diff --git a/mindspore/lite/src/common/file_utils.cc b/mindspore/lite/src/common/file_utils.cc index 46c5620713..337bbc0700 100644 --- a/mindspore/lite/src/common/file_utils.cc +++ b/mindspore/lite/src/common/file_utils.cc @@ -278,5 +278,40 @@ std::string GetDirectory(const std::string &path) { } return dir; } + +bool ParserPathAndModelName(const std::string &output_path, std::string *save_path, std::string *model_name) { + auto pos = output_path.find_last_of('/'); + if (pos == std::string::npos) { + pos = output_path.find_last_of('\\'); + } + std::string tmp_model_name; + if (pos == std::string::npos) { +#ifdef _WIN32 + *save_path = ".\\"; +#else + *save_path = "./"; +#endif + tmp_model_name = output_path; + } else { + *save_path = output_path.substr(0, pos + 1); + tmp_model_name = output_path.substr(pos + 1); + } + *save_path = RealPath(save_path->c_str()); + if (save_path->empty()) { + MS_LOG(DEBUG) << "File path not regular: " << save_path; + return false; + } + auto suffix_pos = tmp_model_name.find_last_of('.'); + if (suffix_pos == std::string::npos) { + *model_name = tmp_model_name; + } else { + if (tmp_model_name.substr(suffix_pos + 1) == "ms") { + *model_name = tmp_model_name.substr(0, suffix_pos); + } else { + *model_name = tmp_model_name; + } + } + return true; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/common/file_utils.h b/mindspore/lite/src/common/file_utils.h index 1b700cc369..6c025f2f9a 100644 --- a/mindspore/lite/src/common/file_utils.h +++ b/mindspore/lite/src/common/file_utils.h @@ -82,6 +82,8 @@ inline int WriteToBin(const std::string &file_path, void *data, const size_t siz } std::string GetDirectory(const std::string &path); + +bool ParserPathAndModelName(const std::string &output_path, std::string *save_path, std::string *model_name); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 2043b5bb9e..ac317628f7 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -222,6 +222,10 @@ if(MSLITE_ENABLE_CONVERTER AND (NOT MSLITE_ENABLE_RUNTIME_CONVERT)) preprocess_mid config_parser_mid coder_mid + ccsrc_debug_common_mid_ + mindir_proto_mid + _mindspore_transform_express_ir_obj + mindir_serializer_mid ) endif() diff --git a/mindspore/lite/tools/common/meta_graph_serializer.cc b/mindspore/lite/tools/common/meta_graph_serializer.cc index ffbf0f4910..d2c80565d2 100644 --- a/mindspore/lite/tools/common/meta_graph_serializer.cc +++ b/mindspore/lite/tools/common/meta_graph_serializer.cc @@ -65,39 +65,10 @@ std::fstream *ReopenFile(const std::string &file_path, std::ios_base::openmode o } // namespace bool MetaGraphSerializer::InitPath(const std::string &output_path) { - this->save_path_.clear(); - this->model_name_.clear(); - auto pos = output_path.find_last_of('/'); - if (pos == std::string::npos) { - pos = output_path.find_last_of('\\'); - } - std::string model_name; - if (pos == std::string::npos) { -#ifdef _WIN32 - this->save_path_ = ".\\"; -#else - this->save_path_ = "./"; -#endif - model_name = output_path; - } else { - this->save_path_ = output_path.substr(0, pos + 1); - model_name = output_path.substr(pos + 1); - } - this->save_path_ = RealPath(this->save_path_.c_str()); - if (this->save_path_.empty()) { - MS_LOG(DEBUG) << "File path not regular: " << this->save_path_; + if (!ParserPathAndModelName(output_path, &this->save_path_, &this->model_name_)) { + MS_LOG(ERROR) << "parser save path and model name from output_path failed."; return false; } - auto suffix_pos = model_name.find_last_of('.'); - if (suffix_pos == std::string::npos) { - this->model_name_ = model_name; - } else { - if (model_name.substr(suffix_pos + 1) == "ms") { - this->model_name_ = model_name.substr(0, suffix_pos); - } else { - this->model_name_ = model_name; - } - } #ifdef _WIN32 save_model_path_ = save_path_ + "\\" + model_name_ + ".ms"; save_data_path_ = save_path_ + "\\" + model_name_ + ".msw"; diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 56bebea377..5f3db33fbd 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -74,6 +74,7 @@ if((NOT WIN32) AND MSLITE_ENABLE_DPICO_ATC_ADAPTER) add_subdirectory(adapter/dpico) endif() add_subdirectory(../anf_exporter anf_exporter) +add_subdirectory(../mindir_serializer mindir_serializer) add_subdirectory(parser/caffe) add_subdirectory(parser/tflite) add_subdirectory(parser/onnx) @@ -301,6 +302,10 @@ target_link_libraries(converter_lite PRIVATE preprocess_mid config_parser_mid coder_mid + ccsrc_debug_common_mid_ + mindir_proto_mid + _mindspore_transform_express_ir_obj + mindir_serializer_mid ) if(MSLITE_ENABLE_ACL) diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 18b1183584..519873c60f 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -40,6 +40,7 @@ #include "src/common/version_manager.h" #include "tools/common/tensor_util.h" #include "include/api/model.h" +#include "tools/mindir_serializer/mindir_serializer.h" namespace mindspore { namespace lite { @@ -148,21 +149,26 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr & return nullptr; } + MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed"); + // funcgraph transform + graph = funcgraph_transform_->Transform(graph, flag.get()); + if (graph == nullptr) { + MS_LOG(ERROR) << "Transform anf graph return nullptr"; + return nullptr; + } + + // export protobuf + auto status = MindIRSerialize(flag, graph); + if (status != RET_OK) { + MS_LOG(WARNING) << "Export to mindir proto return nullptr."; + } + return TransferFuncGraph(flag, graph); } schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr &flag, FuncGraphPtr func_graph) { - MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed"); MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed"); - - // funcgraph compile - func_graph = funcgraph_transform_->Transform(func_graph, flag.get()); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Transform anf graph return nullptr"; - return nullptr; - } - #ifdef MSLITE_ENABLE_GRAPH_KERNEL if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { graphkernel::GraphKernelOptimize(func_graph); diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 92958ce809..6caa37919a 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -94,7 +94,11 @@ Flags::Flags() { ""); #endif AddFlag(&Flags::inferStr, "infer", - "Whether to do pre-inference after convert." + "Whether to do pre-inference after convert. " + "true | false", + "false"); + AddFlag(&Flags::exportMindIR, "exportMindIR", + "Whether to export MindIR pb. " "true | false", "false"); } @@ -353,6 +357,18 @@ int Flags::InitPreInference() { return RET_OK; } +int Flags::InitExportMindIR() { + if (this->exportMindIR == "true") { + this->export_mindir = true; + } else if (this->exportMindIR == "false") { + this->export_mindir = false; + } else { + std::cerr << "INPUT ILLEGAL: exportMindIR must be true|false " << std::endl; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} + int Flags::InitEncrypt() { if (this->encryptionStr == "true") { this->encryption = true; @@ -483,6 +499,12 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "Init pre inference failed." << std::endl; return RET_INPUT_PARAM_INVALID; } + + ret = InitExportMindIR(); + if (ret != RET_OK) { + std::cerr << "Init export mindir failed." << std::endl; + return RET_INPUT_PARAM_INVALID; + } return RET_OK; } Flags::~Flags() { diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 0ca19a0289..9da5bc3a62 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -75,6 +75,8 @@ class Flags : public virtual mindspore::lite::FlagParser { void InitAclDefaultOption(); + int InitExportMindIR(); + int Init(int argc, const char **argv); int PreInit(int argc, const char **argv); @@ -105,6 +107,7 @@ class Flags : public virtual mindspore::lite::FlagParser { std::string encKeyStr; std::string encMode = "AES-GCM"; std::string inferStr; + std::string exportMindIR; #ifdef ENABLE_OPENSSL std::string encryptionStr = "true"; bool encryption = true; @@ -115,6 +118,7 @@ class Flags : public virtual mindspore::lite::FlagParser { bool infer = false; unsigned char encKey[kEncMaxLen]; size_t keyLen = 0; + bool export_mindir = false; lite::quant::CommonQuantParam commonQuantParam; lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 5a5ddcfbfa..706dc0d34f 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -30,6 +30,7 @@ #include "tools/converter/converter_context.h" #include "tools/converter/quant_param_holder.h" #include "tools/converter/parser/parser_utils.h" +#include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/parser/unify_format.h" #include "nnacl/op_base.h" diff --git a/mindspore/lite/tools/converter/parser/lite_model_parser_creator.h b/mindspore/lite/tools/converter/parser/lite_model_parser_creator.h new file mode 100644 index 0000000000..46fe666a5a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/lite_model_parser_creator.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_LITE_MODEL_PARSER_CREATOR_H_ +#define MINDSPORE_LITE_TOOLS_LITE_MODEL_PARSER_CREATOR_H_ + +#include "include/registry/model_parser.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "src/common/log_adapter.h" + +namespace mindspore::lite { +template +converter::ModelParser *LiteModelParserCreator() { + auto *parser = new (std::nothrow) T(); + if (parser == nullptr) { + MS_LOG(ERROR) << "new model parser failed"; + return nullptr; + } + return parser; +} +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_TOOLS_LITE_MODEL_PARSER_CREATOR_H_ diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 9c060c86b2..c632f5ab7a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -34,6 +34,7 @@ #include "tools/converter/parser/onnx/onnx_pad_adjust.h" #include "tools/converter/parser/onnx/onnx_nonzero_adjust.h" #include "tools/converter/parser/parser_utils.h" +#include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/converter/parser/unify_format.h" #include "nnacl/op_base.h" #include "src/common/log_util.h" diff --git a/mindspore/lite/tools/converter/parser/parser_utils.h b/mindspore/lite/tools/converter/parser/parser_utils.h index 4145976a80..e07e6492e3 100644 --- a/mindspore/lite/tools/converter/parser/parser_utils.h +++ b/mindspore/lite/tools/converter/parser/parser_utils.h @@ -21,7 +21,6 @@ #include #include #include "ops/primitive_c.h" -#include "include/registry/model_parser.h" #include "ir/anf.h" #include "ir/func_graph.h" #include "src/common/log_adapter.h" @@ -43,16 +42,6 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod schema::Format dst_format, std::set *has_visited); int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format, schema::Format dst_format, std::set *has_visited); - -template -converter::ModelParser *LiteModelParserCreator() { - auto *parser = new (std::nothrow) T(); - if (parser == nullptr) { - MS_LOG(ERROR) << "new model parser failed"; - return nullptr; - } - return parser; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 271d5330c4..b397d0d539 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -35,6 +35,7 @@ #include "tools/converter/quant_param_holder.h" #include "tools/converter/parser/tf/functionalize_control_op_pass.h" #include "tools/converter/parser/parser_utils.h" +#include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/common/tensor_util.h" #include "src/common/log_util.h" #include "tools/converter/parser/unify_format.h" diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 745ed28ed6..2859d3602d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -30,6 +30,7 @@ #include "tools/converter/converter_flags.h" #include "tools/converter/parser/tflite/tflite_inputs_adjust.h" #include "tools/converter/parser/parser_utils.h" +#include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/converter/parser/unify_format.h" #include "nnacl/op_base.h" #include "src/common/log_util.h" diff --git a/mindspore/lite/tools/mindir_serializer/CMakeLists.txt b/mindspore/lite/tools/mindir_serializer/CMakeLists.txt new file mode 100644 index 0000000000..31b945bda8 --- /dev/null +++ b/mindspore/lite/tools/mindir_serializer/CMakeLists.txt @@ -0,0 +1,32 @@ +file(GLOB_RECURSE MINDIR_EXPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +set_property(SOURCE ${MINDIR_EXPORTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) + +file(GLOB PROTO_FILE "" + ${CORE_DIR}/proto/mind_ir.proto + ${CCSRC_DIR}/utils/*.proto + ) +ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) +add_library(mindir_proto_mid OBJECT ${PROTO_SRCS}) + +add_compile_definitions(MINDIR_EXPORT_TENSOR_LAYOUT_CLIP) +add_compile_definitions(COMMON_DLL) + +set(MINDIR_EXPORT_DIR ${CCSRC_DIR}/transform/express_ir) +add_subdirectory(${MINDIR_EXPORT_DIR} mindir_exporter) + +add_library(mindir_serializer_mid OBJECT + ${MINDIR_EXPORTER_SRC_LIST} + ) + +add_library(ccsrc_debug_common_mid_ OBJECT + ${CCSRC_DIR}/common/debug/common.cc + ) + +target_link_libraries(mindir_serializer_mid + _mindspore_transform_express_ir_obj + ccsrc_debug_common_mid_ + mindir_proto_mid + ) diff --git a/mindspore/lite/tools/mindir_serializer/mindir_serializer.cc b/mindspore/lite/tools/mindir_serializer/mindir_serializer.cc new file mode 100644 index 0000000000..c52e6f1acd --- /dev/null +++ b/mindspore/lite/tools/mindir_serializer/mindir_serializer.cc @@ -0,0 +1,415 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/mindir_serializer/mindir_serializer.h" +#include +#include +#include +#include +#include +#include "mindspore/ccsrc/include/common/debug/dump_proto.h" +#include "mindspore/ccsrc/include/common/utils/utils.h" +#include "src/common/file_utils.h" +#include "tools/converter/parser/parser_utils.h" + +namespace mindspore::lite { +// unit is byte. model size more than 1G need split. +constexpr const size_t TOTAL_SAVE = 1024 * 1024 * 1024; +constexpr const int64_t OFFSET = 64; + +namespace { +bool DeleteDirRecursively(const std::string &dir_name) { + DIR *dir = opendir(dir_name.c_str()); + dirent *dirent = nullptr; + std::vector file_names{}; + while ((dirent = readdir(dir)) != 0) { + if (strcmp(dirent->d_name, ".") != 0 && strcmp(dirent->d_name, "..") != 0) { + file_names.push_back(dirent->d_name); + } + } + for (auto &file_name : file_names) { + auto file_path = dir_name + "/" + file_name; + auto real_file_path = RealPath(file_path.c_str()); + auto result = unlink(real_file_path.c_str()); + if (result != 0) { + MS_LOG(ERROR) << "Delete the file(" << real_file_path << ") failed." << ErrnoToString(errno); + return false; + } + } + return true; +} +} // namespace + +int MindIRSerializer::RemoveQuantParameterHolder(FuncGraphPtr func_graph) { + std::set all_func_graphs = {}; + GetAllFuncGraph(func_graph, &all_func_graphs); + for (auto &graph : all_func_graphs) { + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + if (cnode->inputs().empty() || cnode->input(0) == nullptr) { + MS_LOG(ERROR) << "the cnode is invalid."; + return lite::RET_NULL_PTR; + } + if (utils::isa(cnode->input(0))) { + MS_LOG(DEBUG) << "call cnode no need to convert primitive."; + return lite::RET_NO_CHANGE; + } + auto value_node = cnode->input(0)->cast(); + if (value_node == nullptr || value_node->value() == nullptr) { + MS_LOG(ERROR) << "value node is invalid."; + return lite::RET_NULL_PTR; + } + auto primitive = value_node->value()->cast(); + if (primitive == nullptr) { + if (utils::isa(value_node->value())) { + MS_LOG(DEBUG) << "is a funcgraph."; + return lite::RET_NO_CHANGE; + } else { + MS_LOG(ERROR) << "the value is not primitive."; + return lite::RET_ERROR; + } + } + primitive->EraseAttr("quant_params"); + } + } + return RET_OK; +} + +int MindIRSerializer::Save(const std::unique_ptr &flag, const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "func_graph is nullptr."; + return RET_NULL_PTR; + } + auto output_file = flag->outputFile; + auto ret = ParserPath(output_file); + if (ret != RET_OK) { + MS_LOG(ERROR) << "parse path failed."; + return ret; + } + + ret = RemoveQuantParameterHolder(func_graph); + if (ret != RET_OK && ret != RET_NO_CHANGE) { + MS_LOG(ERROR) << "remove quant parameter holder failed."; + return ret; + } + + auto proto_string = GetBinaryProtoString(func_graph); + if (proto_string.empty()) { + MS_LOG(ERROR) << "parse proto string failed."; + return RET_NULL_PTR; + } + + if (!model_proto_.ParseFromString(proto_string)) { + MS_LOG(ERROR) << "parse model proto from string failed."; + return RET_NULL_PTR; + } + + ret = ParamDict(func_graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "parse param form funcgraph failed."; + return ret; + } + + ret = IfSaveTogether(&save_together_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "error occur when check condition of saving together."; + return ret; + } + + if (save_together_) { + ret = SaveMindIRTogether(); + } else { + ret = SplitSave(); + } + if (ret != RET_OK) { + MS_LOG(ERROR) << "save mindir weight failed."; + return ret; + } + return RET_OK; +} + +int MindIRSerializer::SaveMindIRTogether() { + for (auto ¶m_proto : *(model_proto_.mutable_graph()->mutable_parameter())) { + std::string proto_name = param_proto.name(); + auto para = GetFgParaAccordingToProtoName(proto_name); + if (para == nullptr) { + return RET_ERROR; + } + if (!para->has_default()) { + continue; + } + auto data = para->default_param()->cast(); + param_proto.clear_raw_data(); + param_proto.set_raw_data(data->data_c(), static_cast(data->data().nbytes())); + } + + return SaveProtoToFile(&model_proto_, save_model_path_); +} + +int MindIRSerializer::CreateParameterDir() { +#ifdef _WIN32 + dir_name_ = save_path_ + "\\" + model_name_ + "_variables"; +#else + dir_name_ = save_path_ + "/" + model_name_ + "_variables"; +#endif + fs_ = system::Env::GetFileSystem(); + if (fs_ == nullptr) { + MS_LOG(ERROR) << "create file system failed."; + return RET_NULL_PTR; + } + + if (fs_->FileExist(dir_name_)) { + if (!DeleteDirRecursively(dir_name_)) { + return RET_ERROR; + } + } + + if (!fs_->CreateDir(dir_name_)) { + MS_LOG(ERROR) << "create dir failed."; + return RET_ERROR; + } + + ChangeFileMode(dir_name_, S_IWUSR | S_IRUSR | S_IXUSR); + return RET_OK; +} + +std::shared_ptr MindIRSerializer::GetFgParaAccordingToProtoName(const std::string &proto_name) { + auto beg_pos = proto_name.find_first_of(':') + 1; + if (beg_pos >= proto_name.size()) { + MS_LOG(ERROR) << "begin pos exceed proto name length."; + return nullptr; + } + auto name = proto_name.substr(beg_pos); + if (param_dict_.find(name) == param_dict_.end()) { + MS_LOG(ERROR) << "param proto name: " << name << " is not in param dict."; + return nullptr; + } + return param_dict_.at(name); +} + +int MindIRSerializer::ChangeParaDataFile(const std::string &file) { + data_fs_->close(); + delete data_fs_; + auto real_path = CreateExternalPath(file); + if (fs_->FileExist(real_path)) { + fs_->DeleteFile(real_path); + } + ChangeFileMode(real_path, S_IWUSR); + data_fs_ = OpenFile(real_path, std::ios::app); + char front_info[OFFSET]{0}; + front_info[0] = IsSystemLittleEndidan(); + data_fs_->write(front_info, OFFSET); + return RET_OK; +} + +bool MindIRSerializer::IsSystemLittleEndidan() { + int check = 0x01; + auto address = reinterpret_cast(&check); + return *address == 0x01; +} + +int MindIRSerializer::GetDataFile(const std::string &data_file_name, std::ofstream *fout, int64_t *parameter_size, + int64_t *offset) { + *offset = OFFSET; + std::shared_ptr fs = system::Env::GetFileSystem(); + if (fs == nullptr) { + MS_LOG(ERROR) << "create file system failed."; + return RET_NULL_PTR; + } + if (fs->FileExist(data_file_name)) { + ChangeFileMode(data_file_name, S_IWUSR); + } + + std::byte place_holder[OFFSET]; + fout = new std::ofstream; + fout->write(reinterpret_cast(place_holder), *offset); + + return RET_OK; +} + +std::string MindIRSerializer::CreateExternalPath(const std::string &external_file) { + dir_path_ = RealPath(dir_name_.c_str()); + std::string external_local_path{}; +#ifdef _WIN32 + external_local_path = dir_path_ + "\\" + external_file; +#else + external_local_path = dir_path_ + "/" + external_file; +#endif + return external_local_path; +} + +int MindIRSerializer::SplitSave() { + MS_LOG(DEBUG) << "Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately."; + int ret = CreateParameterDir(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "create parameter dir failed."; + return RET_ERROR; + } + + int index = 0; + std::string external_local = model_name_ + "_data_" + std::to_string(index); + auto external_local_path = CreateExternalPath(external_local); + if (fs_->FileExist(external_local_path)) { + fs_->DeleteFile(external_local_path); + } + int64_t parameter_size = 0; + int64_t offset = OFFSET; + + data_fs_ = OpenFile(external_local_path, std::ios::out | std::ios::binary | std::ios::trunc); + if (data_fs_ == nullptr) { + MS_LOG(ERROR) << "Open " << external_local_path << " failed"; + return false; + } + + for (auto ¶m_proto : *(model_proto_.mutable_graph()->mutable_parameter())) { + std::string proto_name = param_proto.name(); + auto para = GetFgParaAccordingToProtoName(proto_name); + if (para == nullptr) { + return RET_ERROR; + } + if (!para->has_default()) { + continue; + } + auto data = para->default_param()->cast(); + int64_t data_length = static_cast(data->data().nbytes()); + int64_t append_size = 0; + if (data_length % OFFSET != 0) { + append_size = OFFSET - (data_length % OFFSET); + parameter_size += (append_size + data_length); + } else { + parameter_size += data_length; + } + if (parameter_size > static_cast(TOTAL_SAVE)) { + index++; + external_local = model_name_ + "data_" + std::to_string(index); + ret = ChangeParaDataFile(external_local); + if (ret != RET_OK) { + MS_LOG(ERROR) << "change parameter data file failed."; + return ret; + } + parameter_size = OFFSET; + } + *(param_proto.mutable_external_data()->mutable_location()) = external_local; + param_proto.mutable_external_data()->set_length(parameter_size); + param_proto.mutable_external_data()->set_offset(append_size); + data_fs_->write(static_cast(data->data_c()), data_length); + auto append_data = new char[append_size]; + if (append_data == nullptr) { + return RET_NULL_PTR; + } + offset += (data_length + append_size); + data_fs_->write(append_data, append_size); + delete[] append_data; + } + + return SaveProtoToFile(&model_proto_, save_model_path_); +} + +int MindIRSerializer::ParserPath(const std::string &output_path) { + if (!ParserPathAndModelName(output_path, &save_path_, &model_name_)) { + MS_LOG(ERROR) << "parser save path and model name from output_path failed."; + return RET_ERROR; + } +#ifdef _WIN32 + save_model_path_ = save_path_ + "\\" + model_name_ + ".mindir"; +#else + save_model_path_ = save_path_ + "/" + model_name_ + ".mindir"; +#endif + return RET_OK; +} + +int MindIRSerializer::ParamDict(const FuncGraphPtr &func_graph) { + std::set all_func_graphs = {}; + GetAllFuncGraph(func_graph, &all_func_graphs); + for (auto &fg : all_func_graphs) { + for (auto ¶ : fg->parameters()) { + if (!para->isa()) { + MS_LOG(ERROR) << "fg parameters contains non-parameter type node."; + return RET_ERROR; + } + auto para_node = para->cast(); + param_dict_[para->ToString()] = para_node; + } + } + return RET_OK; +} + +int MindIRSerializer::IfSaveTogether(bool *save_together) { + size_t data_total = model_proto_.ByteSizeLong(); + for (auto ¶m_proto : model_proto_.graph().parameter()) { + std::string proto_name = param_proto.name(); + auto para = GetFgParaAccordingToProtoName(proto_name); + if (para == nullptr) { + return RET_ERROR; + } + if (!para->has_default()) { + continue; + } + auto tensor = std::dynamic_pointer_cast(para->default_param()); + if (tensor == nullptr) { + MS_LOG(ERROR) << "param node default_param is not tensor."; + return RET_ERROR; + } + data_total += tensor->Size(); + } + if (data_total > TOTAL_SAVE) { + *save_together = false; + } else { + *save_together = true; + } + return RET_OK; +} + +int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) { + auto realpath = Common::CreatePrefixPath(output_file, true); + if (!realpath.has_value()) { + MS_LOG(ERROR) << "Get real path of file " << output_file << " failed."; + return RET_ERROR; + } + + ChangeFileMode(realpath.value(), S_IWUSR); + std::ofstream fout(realpath.value()); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno); + return RET_ERROR; + } + + if (!model_proto->SerializeToOstream(&fout)) { + MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value(); + fout.close(); + return RET_ERROR; + } + fout.close(); + ChangeFileMode(realpath.value(), S_IRUSR); + return RET_OK; +} + +int MindIRSerialize(const std::unique_ptr &flag, const FuncGraphPtr &func_graph) { + if (!flag->export_mindir) { + return RET_OK; + } +#if defined(SYSTEM_ENV_WINDOWS) + MS_LOG(WARNING) << "mindir serialize not support windows now."; + return RET_NOT_SUPPORT; +#endif + mindspore::lite::MindIRSerializer serializer; + return serializer.Save(flag, func_graph); +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/mindir_serializer/mindir_serializer.h b/mindspore/lite/tools/mindir_serializer/mindir_serializer.h new file mode 100644 index 0000000000..36b7df2432 --- /dev/null +++ b/mindspore/lite/tools/mindir_serializer/mindir_serializer.h @@ -0,0 +1,78 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_ +#define MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_ + +#include +#include +#include +#include +#include +#include "mindspore/core/ir/func_graph.h" +#include "tools/converter/converter_context.h" +#include "tools/converter/converter_flags.h" +#include "proto/mind_ir.pb.h" +#include "mindspore/core/utils/system/env.h" + +namespace mindspore::lite { +class MindIRSerializer { + public: + MindIRSerializer() = default; + virtual ~MindIRSerializer() { + if (data_fs_ != nullptr) { + data_fs_->close(); + delete data_fs_; + data_fs_ = nullptr; + } + } + int Save(const std::unique_ptr &flag, const FuncGraphPtr &func_graph); + + private: + int ParserPath(const std::string &output_path); + int IfSaveTogether(bool *save_together); + int SaveMindIRTogether(); + int SplitSave(); + int SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file); + + private: + int ParamDict(const FuncGraphPtr &func_graph); + int CreateParameterDir(); + std::shared_ptr GetFgParaAccordingToProtoName(const std::string &proto_name); + int ChangeParaDataFile(const std::string &file); + bool IsSystemLittleEndidan(); + int GetDataFile(const std::string &data_file_name, std::ofstream *fout, int64_t *parameter_size, int64_t *offset); + std::string CreateExternalPath(const std::string &external_file); + int RemoveQuantParameterHolder(FuncGraphPtr func_graph); + + private: + std::string model_name_; + std::string save_path_; + std::string save_model_path_; + std::string dir_name_; + std::string dir_path_; + bool save_together_ = true; + mind_ir::ModelProto model_proto_; + std::unordered_map param_dict_{}; + std::unordered_map para_proto_dict_{}; + std::fstream *data_fs_ = nullptr; + std::shared_ptr fs_{}; +}; + +// export func_graph +int MindIRSerialize(const std::unique_ptr &flag, const FuncGraphPtr &func_graph); +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_