Browse Source

support mindir export in lite converter

pull/1/head
mengyuanli 4 years ago
parent
commit
ea4d45c368
19 changed files with 664 additions and 53 deletions
  1. +1
    -1
      mindspore/ccsrc/transform/express_ir/CMakeLists.txt
  2. +8
    -0
      mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
  3. +35
    -0
      mindspore/lite/src/common/file_utils.cc
  4. +2
    -0
      mindspore/lite/src/common/file_utils.h
  5. +4
    -0
      mindspore/lite/test/CMakeLists.txt
  6. +2
    -31
      mindspore/lite/tools/common/meta_graph_serializer.cc
  7. +5
    -0
      mindspore/lite/tools/converter/CMakeLists.txt
  8. +15
    -9
      mindspore/lite/tools/converter/converter.cc
  9. +23
    -1
      mindspore/lite/tools/converter/converter_flags.cc
  10. +4
    -0
      mindspore/lite/tools/converter/converter_flags.h
  11. +1
    -0
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
  12. +36
    -0
      mindspore/lite/tools/converter/parser/lite_model_parser_creator.h
  13. +1
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  14. +0
    -11
      mindspore/lite/tools/converter/parser/parser_utils.h
  15. +1
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  16. +1
    -0
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
  17. +32
    -0
      mindspore/lite/tools/mindir_serializer/CMakeLists.txt
  18. +415
    -0
      mindspore/lite/tools/mindir_serializer/mindir_serializer.cc
  19. +78
    -0
      mindspore/lite/tools/mindir_serializer/mindir_serializer.h

+ 1
- 1
mindspore/ccsrc/transform/express_ir/CMakeLists.txt View File

@@ -1,5 +1,5 @@
file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
file(STRINGS "${CMAKE_SOURCE_DIR}/version.txt" VERSION)
file(STRINGS "${TOP_DIR}/version.txt" VERSION)
add_definitions(-DVERSION=\"${VERSION}\")
set_property(SOURCE ${_EXPORTER_IR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_EXPRESS)

+ 8
- 0
mindspore/ccsrc/transform/express_ir/mindir_exporter.cc View File

@@ -31,7 +31,9 @@
#include "include/common/debug/dump_proto.h"
#include "utils/ms_utils.h"
#include "include/common/utils/utils.h"
#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#endif
#include "abstract/abstract_function.h"

namespace mindspore {
@@ -105,7 +107,9 @@ class IrExportBuilder {
bool BuildModel(const FuncGraphPtr &func_graph);
ModelProtoPtr Model() { return model_; }

#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
void BuildLayout(const FuncGraphPtr &func_graph);
#endif

bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildFuncGraphAttrs(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
@@ -256,10 +260,12 @@ ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, const Fun
return nullptr;
}

#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
// Export layout information
if (param_layout_fg) {
builder_->BuildLayout(param_layout_fg);
}
#endif
return builder_->Model();
}

@@ -278,6 +284,7 @@ void IrExportBuilder::BuildModelInfo() {
model_->set_mind_ir_version(mind_ir::Version_MAX);
}

#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
void IrExportBuilder::BuildLayout(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> graph_params = func_graph->parameters();
@@ -315,6 +322,7 @@ void IrExportBuilder::BuildLayout(const FuncGraphPtr &func_graph) {
}
}
}
#endif

bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);


+ 35
- 0
mindspore/lite/src/common/file_utils.cc View File

@@ -278,5 +278,40 @@ std::string GetDirectory(const std::string &path) {
}
return dir;
}

bool ParserPathAndModelName(const std::string &output_path, std::string *save_path, std::string *model_name) {
auto pos = output_path.find_last_of('/');
if (pos == std::string::npos) {
pos = output_path.find_last_of('\\');
}
std::string tmp_model_name;
if (pos == std::string::npos) {
#ifdef _WIN32
*save_path = ".\\";
#else
*save_path = "./";
#endif
tmp_model_name = output_path;
} else {
*save_path = output_path.substr(0, pos + 1);
tmp_model_name = output_path.substr(pos + 1);
}
*save_path = RealPath(save_path->c_str());
if (save_path->empty()) {
MS_LOG(DEBUG) << "File path not regular: " << save_path;
return false;
}
auto suffix_pos = tmp_model_name.find_last_of('.');
if (suffix_pos == std::string::npos) {
*model_name = tmp_model_name;
} else {
if (tmp_model_name.substr(suffix_pos + 1) == "ms") {
*model_name = tmp_model_name.substr(0, suffix_pos);
} else {
*model_name = tmp_model_name;
}
}
return true;
}
} // namespace lite
} // namespace mindspore

+ 2
- 0
mindspore/lite/src/common/file_utils.h View File

@@ -82,6 +82,8 @@ inline int WriteToBin(const std::string &file_path, void *data, const size_t siz
}

std::string GetDirectory(const std::string &path);

bool ParserPathAndModelName(const std::string &output_path, std::string *save_path, std::string *model_name);
} // namespace lite
} // namespace mindspore



+ 4
- 0
mindspore/lite/test/CMakeLists.txt View File

@@ -222,6 +222,10 @@ if(MSLITE_ENABLE_CONVERTER AND (NOT MSLITE_ENABLE_RUNTIME_CONVERT))
preprocess_mid
config_parser_mid
coder_mid
ccsrc_debug_common_mid_
mindir_proto_mid
_mindspore_transform_express_ir_obj
mindir_serializer_mid
)
endif()



+ 2
- 31
mindspore/lite/tools/common/meta_graph_serializer.cc View File

@@ -65,39 +65,10 @@ std::fstream *ReopenFile(const std::string &file_path, std::ios_base::openmode o
} // namespace

bool MetaGraphSerializer::InitPath(const std::string &output_path) {
this->save_path_.clear();
this->model_name_.clear();
auto pos = output_path.find_last_of('/');
if (pos == std::string::npos) {
pos = output_path.find_last_of('\\');
}
std::string model_name;
if (pos == std::string::npos) {
#ifdef _WIN32
this->save_path_ = ".\\";
#else
this->save_path_ = "./";
#endif
model_name = output_path;
} else {
this->save_path_ = output_path.substr(0, pos + 1);
model_name = output_path.substr(pos + 1);
}
this->save_path_ = RealPath(this->save_path_.c_str());
if (this->save_path_.empty()) {
MS_LOG(DEBUG) << "File path not regular: " << this->save_path_;
if (!ParserPathAndModelName(output_path, &this->save_path_, &this->model_name_)) {
MS_LOG(ERROR) << "parser save path and model name from output_path failed.";
return false;
}
auto suffix_pos = model_name.find_last_of('.');
if (suffix_pos == std::string::npos) {
this->model_name_ = model_name;
} else {
if (model_name.substr(suffix_pos + 1) == "ms") {
this->model_name_ = model_name.substr(0, suffix_pos);
} else {
this->model_name_ = model_name;
}
}
#ifdef _WIN32
save_model_path_ = save_path_ + "\\" + model_name_ + ".ms";
save_data_path_ = save_path_ + "\\" + model_name_ + ".msw";


+ 5
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -74,6 +74,7 @@ if((NOT WIN32) AND MSLITE_ENABLE_DPICO_ATC_ADAPTER)
add_subdirectory(adapter/dpico)
endif()
add_subdirectory(../anf_exporter anf_exporter)
add_subdirectory(../mindir_serializer mindir_serializer)
add_subdirectory(parser/caffe)
add_subdirectory(parser/tflite)
add_subdirectory(parser/onnx)
@@ -301,6 +302,10 @@ target_link_libraries(converter_lite PRIVATE
preprocess_mid
config_parser_mid
coder_mid
ccsrc_debug_common_mid_
mindir_proto_mid
_mindspore_transform_express_ir_obj
mindir_serializer_mid
)

if(MSLITE_ENABLE_ACL)


+ 15
- 9
mindspore/lite/tools/converter/converter.cc View File

@@ -40,6 +40,7 @@
#include "src/common/version_manager.h"
#include "tools/common/tensor_util.h"
#include "include/api/model.h"
#include "tools/mindir_serializer/mindir_serializer.h"

namespace mindspore {
namespace lite {
@@ -148,21 +149,26 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
return nullptr;
}

MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
// funcgraph transform
graph = funcgraph_transform_->Transform(graph, flag.get());
if (graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
}

// export protobuf
auto status = MindIRSerialize(flag, graph);
if (status != RET_OK) {
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
}

return TransferFuncGraph(flag, graph);
}

schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag,
FuncGraphPtr func_graph) {
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");

// funcgraph compile
func_graph = funcgraph_transform_->Transform(func_graph, flag.get());
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
}

#ifdef MSLITE_ENABLE_GRAPH_KERNEL
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
graphkernel::GraphKernelOptimize(func_graph);


+ 23
- 1
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -94,7 +94,11 @@ Flags::Flags() {
"");
#endif
AddFlag(&Flags::inferStr, "infer",
"Whether to do pre-inference after convert."
"Whether to do pre-inference after convert. "
"true | false",
"false");
AddFlag(&Flags::exportMindIR, "exportMindIR",
"Whether to export MindIR pb. "
"true | false",
"false");
}
@@ -353,6 +357,18 @@ int Flags::InitPreInference() {
return RET_OK;
}

int Flags::InitExportMindIR() {
if (this->exportMindIR == "true") {
this->export_mindir = true;
} else if (this->exportMindIR == "false") {
this->export_mindir = false;
} else {
std::cerr << "INPUT ILLEGAL: exportMindIR must be true|false " << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}

int Flags::InitEncrypt() {
if (this->encryptionStr == "true") {
this->encryption = true;
@@ -483,6 +499,12 @@ int Flags::Init(int argc, const char **argv) {
std::cerr << "Init pre inference failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}

ret = InitExportMindIR();
if (ret != RET_OK) {
std::cerr << "Init export mindir failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
Flags::~Flags() {


+ 4
- 0
mindspore/lite/tools/converter/converter_flags.h View File

@@ -75,6 +75,8 @@ class Flags : public virtual mindspore::lite::FlagParser {

void InitAclDefaultOption();

int InitExportMindIR();

int Init(int argc, const char **argv);

int PreInit(int argc, const char **argv);
@@ -105,6 +107,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string encKeyStr;
std::string encMode = "AES-GCM";
std::string inferStr;
std::string exportMindIR;
#ifdef ENABLE_OPENSSL
std::string encryptionStr = "true";
bool encryption = true;
@@ -115,6 +118,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
bool infer = false;
unsigned char encKey[kEncMaxLen];
size_t keyLen = 0;
bool export_mindir = false;

lite::quant::CommonQuantParam commonQuantParam;
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;


+ 1
- 0
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc View File

@@ -30,6 +30,7 @@
#include "tools/converter/converter_context.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/lite_model_parser_creator.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/unify_format.h"
#include "nnacl/op_base.h"


+ 36
- 0
mindspore/lite/tools/converter/parser/lite_model_parser_creator.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_LITE_MODEL_PARSER_CREATOR_H_
#define MINDSPORE_LITE_TOOLS_LITE_MODEL_PARSER_CREATOR_H_

#include "include/registry/model_parser.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/log_adapter.h"

namespace mindspore::lite {
template <class T>
converter::ModelParser *LiteModelParserCreator() {
auto *parser = new (std::nothrow) T();
if (parser == nullptr) {
MS_LOG(ERROR) << "new model parser failed";
return nullptr;
}
return parser;
}
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_LITE_MODEL_PARSER_CREATOR_H_

+ 1
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -34,6 +34,7 @@
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
#include "tools/converter/parser/onnx/onnx_nonzero_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/lite_model_parser_creator.h"
#include "tools/converter/parser/unify_format.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"


+ 0
- 11
mindspore/lite/tools/converter/parser/parser_utils.h View File

@@ -21,7 +21,6 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "include/registry/model_parser.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/log_adapter.h"
@@ -43,16 +42,6 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);

template <class T>
converter::ModelParser *LiteModelParserCreator() {
auto *parser = new (std::nothrow) T();
if (parser == nullptr) {
MS_LOG(ERROR) << "new model parser failed";
return nullptr;
}
return parser;
}
} // namespace lite
} // namespace mindspore



+ 1
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -35,6 +35,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/lite_model_parser_creator.h"
#include "tools/common/tensor_util.h"
#include "src/common/log_util.h"
#include "tools/converter/parser/unify_format.h"


+ 1
- 0
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -30,6 +30,7 @@
#include "tools/converter/converter_flags.h"
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/lite_model_parser_creator.h"
#include "tools/converter/parser/unify_format.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"


+ 32
- 0
mindspore/lite/tools/mindir_serializer/CMakeLists.txt View File

@@ -0,0 +1,32 @@
file(GLOB_RECURSE MINDIR_EXPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.cc
)
set_property(SOURCE ${MINDIR_EXPORTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)

file(GLOB PROTO_FILE ""
${CORE_DIR}/proto/mind_ir.proto
${CCSRC_DIR}/utils/*.proto
)
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
add_library(mindir_proto_mid OBJECT ${PROTO_SRCS})

add_compile_definitions(MINDIR_EXPORT_TENSOR_LAYOUT_CLIP)
add_compile_definitions(COMMON_DLL)

set(MINDIR_EXPORT_DIR ${CCSRC_DIR}/transform/express_ir)
add_subdirectory(${MINDIR_EXPORT_DIR} mindir_exporter)

add_library(mindir_serializer_mid OBJECT
${MINDIR_EXPORTER_SRC_LIST}
)

add_library(ccsrc_debug_common_mid_ OBJECT
${CCSRC_DIR}/common/debug/common.cc
)

target_link_libraries(mindir_serializer_mid
_mindspore_transform_express_ir_obj
ccsrc_debug_common_mid_
mindir_proto_mid
)

+ 415
- 0
mindspore/lite/tools/mindir_serializer/mindir_serializer.cc View File

@@ -0,0 +1,415 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/mindir_serializer/mindir_serializer.h"
#include <sys/stat.h>
#include <dirent.h>
#include <fstream>
#include <vector>
#include <set>
#include "mindspore/ccsrc/include/common/debug/dump_proto.h"
#include "mindspore/ccsrc/include/common/utils/utils.h"
#include "src/common/file_utils.h"
#include "tools/converter/parser/parser_utils.h"

namespace mindspore::lite {
// unit is byte. model size more than 1G need split.
constexpr const size_t TOTAL_SAVE = 1024 * 1024 * 1024;
constexpr const int64_t OFFSET = 64;

namespace {
bool DeleteDirRecursively(const std::string &dir_name) {
DIR *dir = opendir(dir_name.c_str());
dirent *dirent = nullptr;
std::vector<std::string> file_names{};
while ((dirent = readdir(dir)) != 0) {
if (strcmp(dirent->d_name, ".") != 0 && strcmp(dirent->d_name, "..") != 0) {
file_names.push_back(dirent->d_name);
}
}
for (auto &file_name : file_names) {
auto file_path = dir_name + "/" + file_name;
auto real_file_path = RealPath(file_path.c_str());
auto result = unlink(real_file_path.c_str());
if (result != 0) {
MS_LOG(ERROR) << "Delete the file(" << real_file_path << ") failed." << ErrnoToString(errno);
return false;
}
}
return true;
}
} // namespace

int MindIRSerializer::RemoveQuantParameterHolder(FuncGraphPtr func_graph) {
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(func_graph, &all_func_graphs);
for (auto &graph : all_func_graphs) {
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->inputs().empty() || cnode->input(0) == nullptr) {
MS_LOG(ERROR) << "the cnode is invalid.";
return lite::RET_NULL_PTR;
}
if (utils::isa<CNodePtr>(cnode->input(0))) {
MS_LOG(DEBUG) << "call cnode no need to convert primitive.";
return lite::RET_NO_CHANGE;
}
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr || value_node->value() == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return lite::RET_NULL_PTR;
}
auto primitive = value_node->value()->cast<PrimitivePtr>();
if (primitive == nullptr) {
if (utils::isa<FuncGraphPtr>(value_node->value())) {
MS_LOG(DEBUG) << "is a funcgraph.";
return lite::RET_NO_CHANGE;
} else {
MS_LOG(ERROR) << "the value is not primitive.";
return lite::RET_ERROR;
}
}
primitive->EraseAttr("quant_params");
}
}
return RET_OK;
}

int MindIRSerializer::Save(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is nullptr.";
return RET_NULL_PTR;
}
auto output_file = flag->outputFile;
auto ret = ParserPath(output_file);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse path failed.";
return ret;
}

ret = RemoveQuantParameterHolder(func_graph);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "remove quant parameter holder failed.";
return ret;
}

auto proto_string = GetBinaryProtoString(func_graph);
if (proto_string.empty()) {
MS_LOG(ERROR) << "parse proto string failed.";
return RET_NULL_PTR;
}

if (!model_proto_.ParseFromString(proto_string)) {
MS_LOG(ERROR) << "parse model proto from string failed.";
return RET_NULL_PTR;
}

ret = ParamDict(func_graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse param form funcgraph failed.";
return ret;
}

ret = IfSaveTogether(&save_together_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "error occur when check condition of saving together.";
return ret;
}

if (save_together_) {
ret = SaveMindIRTogether();
} else {
ret = SplitSave();
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "save mindir weight failed.";
return ret;
}
return RET_OK;
}

int MindIRSerializer::SaveMindIRTogether() {
for (auto &param_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
std::string proto_name = param_proto.name();
auto para = GetFgParaAccordingToProtoName(proto_name);
if (para == nullptr) {
return RET_ERROR;
}
if (!para->has_default()) {
continue;
}
auto data = para->default_param()->cast<tensor::TensorPtr>();
param_proto.clear_raw_data();
param_proto.set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
}

return SaveProtoToFile(&model_proto_, save_model_path_);
}

int MindIRSerializer::CreateParameterDir() {
#ifdef _WIN32
dir_name_ = save_path_ + "\\" + model_name_ + "_variables";
#else
dir_name_ = save_path_ + "/" + model_name_ + "_variables";
#endif
fs_ = system::Env::GetFileSystem();
if (fs_ == nullptr) {
MS_LOG(ERROR) << "create file system failed.";
return RET_NULL_PTR;
}

if (fs_->FileExist(dir_name_)) {
if (!DeleteDirRecursively(dir_name_)) {
return RET_ERROR;
}
}

if (!fs_->CreateDir(dir_name_)) {
MS_LOG(ERROR) << "create dir failed.";
return RET_ERROR;
}

ChangeFileMode(dir_name_, S_IWUSR | S_IRUSR | S_IXUSR);
return RET_OK;
}

std::shared_ptr<Parameter> MindIRSerializer::GetFgParaAccordingToProtoName(const std::string &proto_name) {
auto beg_pos = proto_name.find_first_of(':') + 1;
if (beg_pos >= proto_name.size()) {
MS_LOG(ERROR) << "begin pos exceed proto name length.";
return nullptr;
}
auto name = proto_name.substr(beg_pos);
if (param_dict_.find(name) == param_dict_.end()) {
MS_LOG(ERROR) << "param proto name: " << name << " is not in param dict.";
return nullptr;
}
return param_dict_.at(name);
}

int MindIRSerializer::ChangeParaDataFile(const std::string &file) {
data_fs_->close();
delete data_fs_;
auto real_path = CreateExternalPath(file);
if (fs_->FileExist(real_path)) {
fs_->DeleteFile(real_path);
}
ChangeFileMode(real_path, S_IWUSR);
data_fs_ = OpenFile(real_path, std::ios::app);
char front_info[OFFSET]{0};
front_info[0] = IsSystemLittleEndidan();
data_fs_->write(front_info, OFFSET);
return RET_OK;
}

bool MindIRSerializer::IsSystemLittleEndidan() {
int check = 0x01;
auto address = reinterpret_cast<char *>(&check);
return *address == 0x01;
}

int MindIRSerializer::GetDataFile(const std::string &data_file_name, std::ofstream *fout, int64_t *parameter_size,
int64_t *offset) {
*offset = OFFSET;
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
if (fs == nullptr) {
MS_LOG(ERROR) << "create file system failed.";
return RET_NULL_PTR;
}
if (fs->FileExist(data_file_name)) {
ChangeFileMode(data_file_name, S_IWUSR);
}

std::byte place_holder[OFFSET];
fout = new std::ofstream;
fout->write(reinterpret_cast<const char *>(place_holder), *offset);

return RET_OK;
}

std::string MindIRSerializer::CreateExternalPath(const std::string &external_file) {
dir_path_ = RealPath(dir_name_.c_str());
std::string external_local_path{};
#ifdef _WIN32
external_local_path = dir_path_ + "\\" + external_file;
#else
external_local_path = dir_path_ + "/" + external_file;
#endif
return external_local_path;
}

int MindIRSerializer::SplitSave() {
MS_LOG(DEBUG) << "Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.";
int ret = CreateParameterDir();
if (ret != RET_OK) {
MS_LOG(ERROR) << "create parameter dir failed.";
return RET_ERROR;
}

int index = 0;
std::string external_local = model_name_ + "_data_" + std::to_string(index);
auto external_local_path = CreateExternalPath(external_local);
if (fs_->FileExist(external_local_path)) {
fs_->DeleteFile(external_local_path);
}
int64_t parameter_size = 0;
int64_t offset = OFFSET;

data_fs_ = OpenFile(external_local_path, std::ios::out | std::ios::binary | std::ios::trunc);
if (data_fs_ == nullptr) {
MS_LOG(ERROR) << "Open " << external_local_path << " failed";
return false;
}

for (auto &param_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
std::string proto_name = param_proto.name();
auto para = GetFgParaAccordingToProtoName(proto_name);
if (para == nullptr) {
return RET_ERROR;
}
if (!para->has_default()) {
continue;
}
auto data = para->default_param()->cast<tensor::TensorPtr>();
int64_t data_length = static_cast<int64_t>(data->data().nbytes());
int64_t append_size = 0;
if (data_length % OFFSET != 0) {
append_size = OFFSET - (data_length % OFFSET);
parameter_size += (append_size + data_length);
} else {
parameter_size += data_length;
}
if (parameter_size > static_cast<int64_t>(TOTAL_SAVE)) {
index++;
external_local = model_name_ + "data_" + std::to_string(index);
ret = ChangeParaDataFile(external_local);
if (ret != RET_OK) {
MS_LOG(ERROR) << "change parameter data file failed.";
return ret;
}
parameter_size = OFFSET;
}
*(param_proto.mutable_external_data()->mutable_location()) = external_local;
param_proto.mutable_external_data()->set_length(parameter_size);
param_proto.mutable_external_data()->set_offset(append_size);
data_fs_->write(static_cast<const char *>(data->data_c()), data_length);
auto append_data = new char[append_size];
if (append_data == nullptr) {
return RET_NULL_PTR;
}
offset += (data_length + append_size);
data_fs_->write(append_data, append_size);
delete[] append_data;
}

return SaveProtoToFile(&model_proto_, save_model_path_);
}

int MindIRSerializer::ParserPath(const std::string &output_path) {
if (!ParserPathAndModelName(output_path, &save_path_, &model_name_)) {
MS_LOG(ERROR) << "parser save path and model name from output_path failed.";
return RET_ERROR;
}
#ifdef _WIN32
save_model_path_ = save_path_ + "\\" + model_name_ + ".mindir";
#else
save_model_path_ = save_path_ + "/" + model_name_ + ".mindir";
#endif
return RET_OK;
}

int MindIRSerializer::ParamDict(const FuncGraphPtr &func_graph) {
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(func_graph, &all_func_graphs);
for (auto &fg : all_func_graphs) {
for (auto &para : fg->parameters()) {
if (!para->isa<Parameter>()) {
MS_LOG(ERROR) << "fg parameters contains non-parameter type node.";
return RET_ERROR;
}
auto para_node = para->cast<ParameterPtr>();
param_dict_[para->ToString()] = para_node;
}
}
return RET_OK;
}

int MindIRSerializer::IfSaveTogether(bool *save_together) {
size_t data_total = model_proto_.ByteSizeLong();
for (auto &param_proto : model_proto_.graph().parameter()) {
std::string proto_name = param_proto.name();
auto para = GetFgParaAccordingToProtoName(proto_name);
if (para == nullptr) {
return RET_ERROR;
}
if (!para->has_default()) {
continue;
}
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(para->default_param());
if (tensor == nullptr) {
MS_LOG(ERROR) << "param node default_param is not tensor.";
return RET_ERROR;
}
data_total += tensor->Size();
}
if (data_total > TOTAL_SAVE) {
*save_together = false;
} else {
*save_together = true;
}
return RET_OK;
}

int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
auto realpath = Common::CreatePrefixPath(output_file, true);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path of file " << output_file << " failed.";
return RET_ERROR;
}

ChangeFileMode(realpath.value(), S_IWUSR);
std::ofstream fout(realpath.value());
if (!fout.is_open()) {
MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
return RET_ERROR;
}

if (!model_proto->SerializeToOstream(&fout)) {
MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
fout.close();
return RET_ERROR;
}
fout.close();
ChangeFileMode(realpath.value(), S_IRUSR);
return RET_OK;
}

int MindIRSerialize(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph) {
if (!flag->export_mindir) {
return RET_OK;
}
#if defined(SYSTEM_ENV_WINDOWS)
MS_LOG(WARNING) << "mindir serialize not support windows now.";
return RET_NOT_SUPPORT;
#endif
mindspore::lite::MindIRSerializer serializer;
return serializer.Save(flag, func_graph);
}
} // namespace mindspore::lite

+ 78
- 0
mindspore/lite/tools/mindir_serializer/mindir_serializer.h View File

@@ -0,0 +1,78 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_
#define MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <fstream>
#include <set>
#include "mindspore/core/ir/func_graph.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
#include "proto/mind_ir.pb.h"
#include "mindspore/core/utils/system/env.h"

namespace mindspore::lite {
class MindIRSerializer {
public:
MindIRSerializer() = default;
virtual ~MindIRSerializer() {
if (data_fs_ != nullptr) {
data_fs_->close();
delete data_fs_;
data_fs_ = nullptr;
}
}
int Save(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph);

private:
int ParserPath(const std::string &output_path);
int IfSaveTogether(bool *save_together);
int SaveMindIRTogether();
int SplitSave();
int SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file);

private:
int ParamDict(const FuncGraphPtr &func_graph);
int CreateParameterDir();
std::shared_ptr<Parameter> GetFgParaAccordingToProtoName(const std::string &proto_name);
int ChangeParaDataFile(const std::string &file);
bool IsSystemLittleEndidan();
int GetDataFile(const std::string &data_file_name, std::ofstream *fout, int64_t *parameter_size, int64_t *offset);
std::string CreateExternalPath(const std::string &external_file);
int RemoveQuantParameterHolder(FuncGraphPtr func_graph);

private:
std::string model_name_;
std::string save_path_;
std::string save_model_path_;
std::string dir_name_;
std::string dir_path_;
bool save_together_ = true;
mind_ir::ModelProto model_proto_;
std::unordered_map<std::string, ParameterPtr> param_dict_{};
std::unordered_map<tensor::TensorPtr, mind_ir::TensorProto *> para_proto_dict_{};
std::fstream *data_fs_ = nullptr;
std::shared_ptr<system::FileSystem> fs_{};
};

// export func_graph
int MindIRSerialize(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_

Loading…
Cancel
Save