diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 561f5dbee8..9045581c66 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -125,7 +125,8 @@ if (ENABLE_DUMP_PROTO) "utils/lineage.proto" "utils/checkpoint.proto" "utils/print.proto" - "utils/node_strategy.proto" + "utils/node_strategy.proto" + "utils/mind_ir.proto" ) ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) @@ -156,7 +157,7 @@ endif() ## make sub objects set(SUB_COMP transform/graph_ir - transform/onnx + transform/express_ir backend/optimizer backend/kernel_compiler backend/session @@ -344,13 +345,13 @@ if (ENABLE_MINDDATA) endif () # build inference -set(LOAD_ONNX_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc - ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc +set(LOAD_MINDIR_SRC + ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc + ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc ) add_library(inference SHARED ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc - ${LOAD_ONNX_SRC} + ${LOAD_MINDIR_SRC} ) set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) diff --git a/mindspore/ccsrc/backend/session/infer_session.cc b/mindspore/ccsrc/backend/session/infer_session.cc index d82a9f3008..5e897f8606 100644 --- a/mindspore/ccsrc/backend/session/infer_session.cc +++ b/mindspore/ccsrc/backend/session/infer_session.cc @@ -20,11 +20,11 @@ #include #include "include/inference.h" -#include "utils/load_onnx/anf_converter.h" #include "backend/session/session_basic.h" #include "backend/session/session_factory.h" #include "backend/session/executor_manager.h" #include "base/base_ref_utils.h" +#include "load_mindir/load_model.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "utils/context/context_extends.h" #include "runtime/device/kernel_runtime_manager.h" @@ -58,46 +58,9 @@ std::shared_ptr InferSession::CreateSession(const std::string &dev MSInferSession::MSInferSession() = default; MSInferSession::~MSInferSession() = default; -std::shared_ptr> MSInferSession::ReadFile(const std::string &file) { - if (file.empty()) { - MS_LOG(ERROR) << "file is nullptr"; - return nullptr; - } - std::string realPath = file; - std::ifstream ifs(realPath); - if (!ifs.good()) { - MS_LOG(ERROR) << "file: " << realPath << " is not exist"; - return nullptr; - } - - if (!ifs.is_open()) { - MS_LOG(ERROR) << "file: " << realPath << "open failed"; - return nullptr; - } - - ifs.seekg(0, std::ios::end); - size_t size = ifs.tellg(); - std::shared_ptr> buf(new (std::nothrow) std::vector(size)); - if (buf == nullptr) { - MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; - ifs.close(); - return nullptr; - } - - ifs.seekg(0, std::ios::beg); - ifs.read(buf->data(), size); - ifs.close(); - - return buf; -} - Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) { - auto graphBuf = ReadFile(file_name); - if (graphBuf == nullptr) { - MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); - return FAILED; - } - auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_); + Py_Initialize(); + auto graph = RunLoadMindIR(file_name); if (graph == nullptr) { MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); return FAILED; @@ -213,6 +176,7 @@ Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &reques } inputs.push_back(input); } + auto ret = CheckModelInputs(model_id, inputs); if (ret != SUCCESS) { MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed"; @@ -250,16 +214,6 @@ Status MSInferSession::FinalizeEnv() { return SUCCESS; } -std::shared_ptr MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) { - try { - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); - return anf_graph; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference LoadModel failed"; - return nullptr; - } -} - void MSInferSession::RegAllOp() { static std::mutex init_mutex; static bool Initialized = false; diff --git a/mindspore/ccsrc/backend/session/infer_session.h b/mindspore/ccsrc/backend/session/infer_session.h index 2a50619ae8..1c8fd2ded1 100644 --- a/mindspore/ccsrc/backend/session/infer_session.h +++ b/mindspore/ccsrc/backend/session/infer_session.h @@ -54,8 +54,6 @@ class MSInferSession : public InferSession { rtContext_t context_ = nullptr; #endif - std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device); - std::shared_ptr> ReadFile(const std::string &file); static void RegAllOp(); string AjustTargetName(const std::string &device); Status CompileGraph(std::shared_ptr funcGraphPtr, uint32_t &model_id); diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt index 2f4954c89b..c74b5cf038 100644 --- a/mindspore/ccsrc/cxx_api/CMakeLists.txt +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -1,7 +1,7 @@ # build mindspore_shared_lib -set(LOAD_ONNX_SRC - ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_converter.cc - ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +set(LOAD_MINDIR_SRC + ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc + ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc ) file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") @@ -18,7 +18,7 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc ${API_MS_INFER_SRC} ${API_ACL_SRC} ${API_OPS_SRC} - ${LOAD_ONNX_SRC}) + ${LOAD_MINDIR_SRC}) add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc index 3fde01a4d6..ae50686254 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc @@ -17,9 +17,9 @@ #include "cxx_api/model/acl/model_converter.h" #include #include "pybind11/pybind11.h" -#include "utils/load_onnx/anf_converter.h" #include "transform/graph_ir/convert.h" #include "transform/graph_ir/graph_runner.h" +#include "core/load_mindir/load_model.h" #include "mindspore/core/utils/ms_context.h" #include "backend/kernel_compiler/oplib/oplib.h" @@ -79,8 +79,7 @@ bool CreateSessionAndGraphRunner() { std::shared_ptr ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) { try { - auto anf_graph = - lite::AnfConverter::RunAnfConverter(reinterpret_cast(model_data.Data()), model_data.DataSize()); + auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast(model_data.Data()), model_data.DataSize()); return anf_graph; } catch (std::exception &e) { MS_LOG(ERROR) << "Load MindIR failed."; @@ -364,6 +363,7 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { RegAllOp(); + Py_Initialize(); auto func_graph = ConvertMindIrToFuncGraph(model_data); if (func_graph == nullptr) { MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index e348415380..d136f3e962 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -19,7 +19,7 @@ #include #include -#include "utils/load_onnx/anf_converter.h" +#include "load_mindir/load_model.h" #include "backend/session/session_basic.h" #include "backend/session/session_factory.h" #include "backend/session/executor_manager.h" @@ -117,9 +117,9 @@ Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::m return FAILED; } std::shared_ptr anf_graph; + Py_Initialize(); try { - anf_graph = - lite::AnfConverter::RunAnfConverter(static_cast(model_data.Data()), model_data.DataSize()); + anf_graph = ConvertStreamToFuncGraph(static_cast(model_data.Data()), model_data.DataSize()); } catch (std::exception &e) { MS_LOG(ERROR) << "Inference LoadModel failed"; return FAILED; @@ -290,9 +290,10 @@ Status MsModel::FinalizeEnv() { } std::shared_ptr MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { + Py_Initialize(); MS_EXCEPTION_IF_NULL(model_buf); try { - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + auto anf_graph = ConvertStreamToFuncGraph(model_buf, size); return anf_graph; } catch (std::exception &e) { MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what(); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index a6b677a50a..fd93c32a67 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -74,6 +74,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") + .def("convert_funcgraph_to_mindir", &ExecutorPy::ConvertFuncGraphToMindIR, py::arg("graph"), + "Convert FuncGraph to MindIR proto.") .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"), @@ -108,6 +110,7 @@ PYBIND11_MODULE(_c_expression, m) { (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); + (py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load MindIR as Graph."); (void)py::class_>(m, "MpiConfig") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index df64ea6c67..07746912f8 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -45,6 +45,7 @@ #include "debug/draw.h" #include "pipeline/pynative/pynative_execute.h" #include "frontend/optimizer/py_pass_manager.h" +#include "load_mindir/load_model.h" #include "pybind_api/pybind_patch.h" #include "utils/shape_utils.h" #include "utils/info.h" @@ -103,6 +104,16 @@ void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) { } } // namespace +py::bytes ExecutorPy::ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr) { + std::string proto_str = GetBinaryProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; +} + +FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::RunLoadMindIR(file_name); } + py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); abstract::AbstractBasePtrList args_spec; diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 2e7502ec90..ff14b47725 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -82,6 +82,7 @@ class ExecutorPy : public std::enable_shared_from_this { ResourcePtr GetResource(const std::string &phase); FuncGraphPtr GetFuncGraph(const std::string &phase); py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); + py::bytes ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr); compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); bool HasCompiled(const std::string &phase) const; @@ -138,6 +139,7 @@ void ClearResAtexit(); void ReleaseGeTsd(); void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); +FuncGraphPtr LoadMindIR(const std::string &file_name); // init and exec dataset sub graph bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, diff --git a/mindspore/ccsrc/transform/express_ir/CMakeLists.txt b/mindspore/ccsrc/transform/express_ir/CMakeLists.txt new file mode 100644 index 0000000000..b871857b4b --- /dev/null +++ b/mindspore/ccsrc/transform/express_ir/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) +add_library(_mindspore_transform_express_ir_obj OBJECT ${_EXPORTER_IR_SRC_FILES}) \ No newline at end of file diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc similarity index 62% rename from mindspore/ccsrc/transform/onnx/ir_exporter.cc rename to mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index a6b7cbe38b..96d184f02f 100644 --- a/mindspore/ccsrc/transform/onnx/ir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -25,33 +25,40 @@ #include "ir/param_info.h" #include "ir/func_graph.h" #include "base/core_ops.h" -#include "proto/onnx.pb.h" +#include "proto/mind_ir.pb.h" namespace mindspore { using FloatPtr = std::shared_ptr; using IntPtr = std::shared_ptr; -// anf type to onnx type map -static std::unordered_map g_data_type_map = { - {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, - {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, - {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, - {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, - {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, - {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, - {kObjectTypeString, onnx::TensorProto_DataType_STRING}, +// anf type to mindir type map +static std::unordered_map g_data_type_map = { + {kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL}, + {kNumberTypeInt8, mind_ir::TensorProto_DataType_INT8}, + {kNumberTypeInt16, mind_ir::TensorProto_DataType_INT16}, + {kNumberTypeInt32, mind_ir::TensorProto_DataType_INT32}, + {kNumberTypeInt64, mind_ir::TensorProto_DataType_INT64}, + {kNumberTypeUInt8, mind_ir::TensorProto_DataType_UINT8}, + {kNumberTypeUInt16, mind_ir::TensorProto_DataType_UINT16}, + {kNumberTypeUInt32, mind_ir::TensorProto_DataType_UINT32}, + {kNumberTypeUInt64, mind_ir::TensorProto_DataType_UINT64}, + {kNumberTypeFloat16, mind_ir::TensorProto_DataType_FLOAT16}, + {kNumberTypeFloat32, mind_ir::TensorProto_DataType_FLOAT}, + {kNumberTypeFloat64, mind_ir::TensorProto_DataType_DOUBLE}, + {kObjectTypeString, mind_ir::TensorProto_DataType_STRING}, }; -static std::unordered_map g_data_bits_int_map = { - {8, onnx::TensorProto_DataType_INT8}, - {16, onnx::TensorProto_DataType_INT16}, - {32, onnx::TensorProto_DataType_INT32}, - {64, onnx::TensorProto_DataType_INT64}, +static std::unordered_map g_data_bits_int_map = { + {8, mind_ir::TensorProto_DataType_INT8}, + {16, mind_ir::TensorProto_DataType_INT16}, + {32, mind_ir::TensorProto_DataType_INT32}, + {64, mind_ir::TensorProto_DataType_INT64}, }; -static std::unordered_map g_data_bits_float_map = { - {16, onnx::TensorProto_DataType_FLOAT16}, - {32, onnx::TensorProto_DataType_FLOAT}, +static std::unordered_map g_data_bits_float_map = { + {16, mind_ir::TensorProto_DataType_FLOAT16}, + {32, mind_ir::TensorProto_DataType_FLOAT}, + {64, mind_ir::TensorProto_DataType_FLOAT64}, }; // Can build different builder according to format @@ -77,34 +84,34 @@ class IrExportBuilder { void BuildModel(const FuncGraphPtr &func_graph); private: - void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); - void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); - void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); - void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto); - void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto); - std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto); - - void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto); - void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto); - void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); - void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); - void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, + void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); + void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); + void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); + void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); + void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); + std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto); + + void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); + void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto); + void SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); + void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto); + void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto); + void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto, std::string *const seq_string); - void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); - void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, + void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); + void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); + void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); + void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); + void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto, std::string *const seq_string); - void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, + void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, std::string *const seq_string); - onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); - onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits); + mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id); + mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits); + mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits); std::string GetNodeName(const AnfNodePtr &node); std::string GetUniqueNodeName(const AnfNodePtr &node); std::string GetOpTypeName(const AnfNodePtr &node); @@ -114,8 +121,8 @@ class IrExportBuilder { void ResetTupleIndex() { shape_index_ = 0; } private: - onnx::ModelProto model_; - onnx::NodeProto *last_node_{nullptr}; + mind_ir::ModelProto model_; + mind_ir::NodeProto *last_node_{nullptr}; std::list todo_; std::map node_index_map_; size_t node_index_{0}; @@ -144,13 +151,13 @@ std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { } void IrExportBuilder::BuildModelInfo() { - model_.set_ir_version(onnx::IR_VERSION_2019_1_22); + model_.set_ir_version("0.1.0"); model_.set_producer_name("MindSpore"); - model_.set_model_version(1); + model_.set_model_version("1.1.0"); } void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { - onnx::GraphProto *graph_proto = model_.mutable_graph(); + mind_ir::GraphProto *graph_proto = model_.mutable_graph(); graph_proto->set_name(func_graph->ToString()); ResetNodeIndex(); todo_.clear(); @@ -162,7 +169,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { } } -void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { +void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { // Export parameters // 1. parameters should be mapped to ValueInfoProto // 2. parameters with default value should be mapped to Initializer @@ -172,33 +179,31 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::Graph BuildNodes(func_graph, graph_proto); } -void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { +void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { for (auto &item : func_graph->parameters()) { auto param = item->cast(); if (param == nullptr) { MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; } - onnx::ValueInfoProto *input_proto = graph_proto->add_input(); std::string param_name = GetUniqueNodeName(param); - input_proto->set_name(param_name); - SetValueInfoProto(param, input_proto); - if (!param->has_default()) { + if (param->has_default()) { MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; - continue; - } - - // Using ONNX initializer to set parameter's default value - onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); - initializer_proto->set_name(param_name); - SetParamToTensorProto(param, initializer_proto); - auto tensor = std::dynamic_pointer_cast(param->default_param()); - if (tensor) { - initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); + mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter(); + parameter_proto->set_name(param_name); + SetParamToTensorProto(param, parameter_proto); + auto tensor = std::dynamic_pointer_cast(param->default_param()); + if (tensor) { + parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); + } + } else { + mind_ir::ValueInfoProto *input_proto = graph_proto->add_input(); + input_proto->set_name(param_name); + SetValueInfoProto(param, input_proto); } } } -onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { +mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) { auto iter = g_data_type_map.find(type_id); if (iter == g_data_type_map.end()) { MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; @@ -206,7 +211,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { return iter->second; } -onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { +mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) { auto iter = g_data_bits_int_map.find(bits); if (iter == g_data_bits_int_map.end()) { MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; @@ -214,7 +219,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { return iter->second; } -onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { +mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) { auto iter = g_data_bits_float_map.find(bits); if (iter == g_data_bits_float_map.end()) { MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; @@ -222,73 +227,70 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { return iter->second; } -void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) { +void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) { if (node == nullptr || value_proto == nullptr) { MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; } MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); - SetValueInfoProto(node->Type(), node->Shape(), value_proto); -} - -void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::ValueInfoProto *const value_proto) { - onnx::TypeProto *type_proto = value_proto->mutable_type(); + const TypePtr &type = node->Type(); + const BaseShapePtr &shape = node->Shape(); if (type->isa() && shape->isa()) { auto tensor = type->cast(); auto elem_type = tensor->element(); const auto &dims = shape->cast()->shape(); - type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); + mind_ir::TensorProto *tensor_proto = value_proto->add_tensor(); + tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id())); if (dims.size() == 0) { MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + tensor_proto->add_dims(1); } else { for (const auto &dim : dims) { MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + tensor_proto->add_dims(dim); } } } else if (type->isa()) { auto tup_shape = shape->cast(); - type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); + value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); } else if (type->isa() || type->isa()) { - type_proto->set_denotation(type->type_name()); + value_proto->set_denotation(type->type_name()); } else { MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; } } -void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { +void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } attr_proto->set_ref_attr_name("tensor:value0"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); + mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); tensor_proto->set_name("value0"); auto data = value->cast(); tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); auto dtype = data->data_type(); auto shape = data->shape_c(); - tensor_proto->set_data_type(GetOnnxDataType(dtype)); + tensor_proto->set_data_type(GetMindirDataType(dtype)); for (const auto &dim : shape) { tensor_proto->add_dims(dim); } } void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::TensorProto *const tensor_proto) { + mind_ir::TensorProto *const tensor_proto) { if (!type->isa() || !shape->isa()) { MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); } auto tensor = type->cast(); const auto &dims = shape->cast()->shape(); - tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id())); + tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id())); for (const auto &dim : dims) { tensor_proto->add_dims(dim); } } -void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { +void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) { if (param == nullptr || tensor_proto == nullptr) { MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; } @@ -296,7 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten SetTensorProto(param->Type(), param->Shape(), tensor_proto); } -void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { +void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); bool is_only_return = true; for (const AnfNodePtr &node : nodes) { @@ -317,12 +319,12 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt } } -void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) { +void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { if (node->size() != 2) { MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + mind_ir::ValueInfoProto *output_proto = graph_proto->add_output(); std::string output_name = GetUniqueNodeName(node); output_proto->set_name(output_name); last_node_->set_output(0, output_name); @@ -349,7 +351,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { } void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::AttributeProto *const attr_proto, std::string *const seq_string) { + mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { if (type->isa() && seq_string != nullptr) { *seq_string += "Tuple["; auto elements = type->cast()->elements(); @@ -361,7 +363,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt } else if (type->isa() && shape->isa() && seq_string != nullptr) { string shape_name = "shape" + std::to_string(GetTupleIndex()); *seq_string += shape_name + ","; - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); tensor_proto->set_name(shape_name); SetTensorProto(type, shape, tensor_proto); } else if ((type->isa() || type->isa()) && seq_string != nullptr) { @@ -371,7 +373,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt } } -void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) { // Get shape of cnode // 1. need to get shape from tuple element // 2. save shape in TensorProto @@ -381,13 +383,13 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto auto shape = node->Shape(); ResetTupleIndex(); std::string seq_string = "shape:"; - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); SetShapeToNodeProto(type, shape, attr_proto, &seq_string); attr_proto->set_ref_attr_name(seq_string); MS_LOG(DEBUG) << "CNode shape: " << seq_string; } -void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { +void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { auto inputs_size = node->size(); if (inputs_size < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; @@ -403,7 +405,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g } // Build cnode - onnx::NodeProto *node_proto = graph_proto->add_node(); + mind_ir::NodeProto *node_proto = graph_proto->add_node(); std::string output_name = GetUniqueNodeName(node); node_proto->add_output(output_name); node_proto->set_name(output_name); @@ -421,7 +423,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g auto prim = GetValueNode(op); for (auto attr : prim->attrs()) { MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name(attr.first); SetValueToAttributeProto(attr.second, attr_proto); } @@ -430,11 +432,11 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g } } -std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) { +std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) { std::string node_name = GetUniqueNodeName(node); if (node->isa()) { // When node input is a ValueNode, need to create a Constant Node - onnx::NodeProto *node_proto = graph_proto->add_node(); + mind_ir::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(node_name); SetAttributeProto(node, node_proto); } @@ -478,44 +480,48 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { return node_name; } -void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) { +void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) { if (node == nullptr || node_proto == nullptr) { MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; } auto value = node->cast()->value(); node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); SetValueToAttributeProto(value, attr_proto); } -void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { +void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); + mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); if (value->isa()) { attr_proto->set_ref_attr_name("type:value0"); tensor_proto->set_name("value0"); auto int_value = value->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); } else if (value->isa()) { attr_proto->set_ref_attr_name("type:value0"); tensor_proto->set_name("value0"); auto float_value = value->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); + } else if (value->isa()) { + attr_proto->set_ref_attr_name("type:value0"); + tensor_proto->set_name("value0"); + tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL); } else if (value->isa()) { attr_proto->set_ref_attr_name("type:tensor0"); tensor_proto->set_name("tensor0"); auto elem_type = value->cast()->element(); if (elem_type->isa()) { auto int_value = elem_type->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); } else if (elem_type->isa()) { auto float_value = elem_type->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); } else { MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); } @@ -524,18 +530,18 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri } } -void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { +void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } if (value->isa() || value->isa()) { - SetScalarToAttributeProto(value, attr_proto); + SetScalarToAttributeProto_ir(value, attr_proto); } else if (value->isa() || value->isa()) { SetTypeToAttributeProto(value, attr_proto); - } else if (value->isa() || value->isa()) { + } else if (value->isa()) { ResetTupleIndex(); std::string seq_string = "scalar:"; - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); SetSequenceToAttributeProto(value->cast(), attr_proto, &seq_string); attr_proto->set_ref_attr_name(seq_string); MS_LOG(DEBUG) << "Attr string: " << seq_string; @@ -549,74 +555,102 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr } } -void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { - if (value == nullptr || attr_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; - } +void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { attr_proto->set_ref_attr_name("scalar:value0"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); - SetScalarToProto(value, tensor_proto, "value0"); + if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING); + attr_proto->set_s(GetValue(value)); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); + attr_proto->set_i(GetValue(value)); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64); + attr_proto->set_i(value->cast()->value()); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); + attr_proto->set_f(GetValue(value)); + } else if (value->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); + attr_proto->set_d(GetValue(value)); + } else { + MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); + } } -void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, - const std::string &value_name) { - if (value == nullptr || tensor_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; - } - tensor_proto->set_name(value_name); +void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); - tensor_proto->add_string_data(GetValue(value)); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING); + attr_proto->add_strings(GetValue(value)); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); - tensor_proto->add_int32_data(GetValue(value)); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); + attr_proto->add_ints(GetValue(value)); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); - tensor_proto->add_int32_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); - tensor_proto->add_int32_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); - tensor_proto->add_int32_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - tensor_proto->add_int64_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); - tensor_proto->add_int32_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); - tensor_proto->add_int32_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); - tensor_proto->add_uint64_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); - tensor_proto->add_uint64_data(value->cast()->value()); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64); + attr_proto->add_ints(value->cast()->value()); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); - tensor_proto->add_float_data(GetValue(value)); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); + attr_proto->add_floats(GetValue(value)); } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); - tensor_proto->add_double_data(GetValue(value)); + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); + attr_proto->add_doubles(GetValue(value)); } else { MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); } } -void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, +void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { string value_name = "value" + std::to_string(GetTupleIndex()); if (seq_string != nullptr) { *seq_string += value_name + ","; } - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); - SetScalarToProto(value, tensor_proto, value_name); + SetScalarToAttributeProto_irs(value, attr_proto); } -void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, + mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; @@ -625,6 +659,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, *seq_string += "Tuple["; const ValueTuplePtr &tuple_value = value->cast(); if (tuple_value->value().size() == 0) { + *seq_string += "],"; MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; return; } @@ -640,6 +675,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, *seq_string += "List["; const ValueListPtr &list_value = value->cast(); if (list_value->value().size() == 0) { + *seq_string += "],"; MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; return; } diff --git a/mindspore/ccsrc/transform/onnx/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc similarity index 100% rename from mindspore/ccsrc/transform/onnx/onnx_exporter.cc rename to mindspore/ccsrc/transform/express_ir/onnx_exporter.cc diff --git a/mindspore/ccsrc/transform/onnx/CMakeLists.txt b/mindspore/ccsrc/transform/onnx/CMakeLists.txt deleted file mode 100644 index 0d2f6c947b..0000000000 --- a/mindspore/ccsrc/transform/onnx/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) -add_library(_mindspore_transform_onnx_obj OBJECT ${_ONNX_SRC_FILES}) diff --git a/mindspore/ccsrc/utils/CMakeLists.txt b/mindspore/ccsrc/utils/CMakeLists.txt index 72f698a97e..71d68729b9 100644 --- a/mindspore/ccsrc/utils/CMakeLists.txt +++ b/mindspore/ccsrc/utils/CMakeLists.txt @@ -5,11 +5,5 @@ if (NOT ENABLE_GE) list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES}) endif () -file(GLOB_RECURSE _UTILS_LITE_SRC_FILES - ./load_onnx/anf_converter.cc - ./load_onnx/anf_model_parser.cc - ) -list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_LITE_SRC_FILES}) - set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS) add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST}) diff --git a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc deleted file mode 100644 index c0e6a13f0f..0000000000 --- a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "utils/load_onnx/anf_converter.h" - -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" - -#include "utils/load_onnx/anf_model_parser.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "proto/onnx.pb.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace lite { -const char WHITESPACE[] = "\t\n\v\f\r "; -const int FLAG_PREFIX_LEN = 2; - -void AnfConverter::Trim(std::string *input) { - if (input == nullptr) { - return; - } - if (input->empty()) { - return; - } - input->erase(0, input->find_first_not_of(WHITESPACE)); - input->erase(input->find_last_not_of(WHITESPACE) + 1); -} - -int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string fileType) { - if (modelFile.size() > fileType.size()) { - if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) { - return 0; - } else { - return 1; - } - } else { - return 1; - } -} - -bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { - std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); - - if (modelFile.size() > PATH_MAX) { - MS_LOG(DEBUG) << "file path " << modelFile << " is too long."; - return false; - } - char real_path[PATH_MAX + 1] = {0}; -#if defined(_WIN32) || defined(_WIN64) - if (nullptr == _fullpath(real_path, modelFile.c_str(), PATH_MAX)) { - MS_LOG(DEBUG) << modelFile << " does not exit."; - return false; - } -#else - if (nullptr == realpath(modelFile.c_str(), real_path)) { - MS_LOG(DEBUG) << modelFile << " does not exit."; - return false; - } -#endif - int fd = open(real_path, O_RDONLY); - if (fd < 0) { - MS_LOG(EXCEPTION) << "failed to open file"; - } - google::protobuf::io::FileInputStream input(fd); - google::protobuf::io::CodedInputStream code_input(&input); - code_input.SetTotalBytesLimit(INT_MAX, 536870912); - bool ret = onnx_model->ParseFromCodedStream(&code_input); - if (!ret) { - MS_LOG(ERROR) << "load onnx file failed"; - return false; - } - (void)close(fd); - MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; - return true; -} - -std::shared_ptr AnfConverter::RunAnfConverter(const std::string &file_path) { - std::string modelFile; - - std::string tmp = file_path; - Trim(&tmp); - const std::string flagItem(tmp); - - size_t pos = flagItem.find_first_of("="); - if (pos == std::string::npos) { - MS_LOG(ERROR) << "Trans data not support input format!"; - } else { - modelFile = flagItem.substr(pos + 1); - std::cout << "input protobuf file path is: " << modelFile << std::endl; - } - - if (ValidateFileStr(modelFile, ".pb") != 0) { - MS_LOG(EXCEPTION) << "INPUT ILLEGAL: modelFile must be *.pb"; - } - - onnx::ModelProto model_; - ReadOnnxFromBinary(modelFile, &model_); - MSANFModelParser model_parser; - FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); - return dstgraph_ptr; -} - -std::shared_ptr AnfConverter::RunAnfConverter(const char *buf, const size_t buf_size) { - Py_Initialize(); - MS_EXCEPTION_IF_NULL(buf); - std::string str((const char *)buf, buf_size); - onnx::ModelProto model_; - if (!model_.ParseFromString(str)) { - MS_LOG(EXCEPTION) << "Parse model from buffer fail!"; - } - MSANFModelParser model_parser; - FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); - return dstgraph_ptr; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc deleted file mode 100644 index 22892a7580..0000000000 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ /dev/null @@ -1,693 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/load_onnx/anf_model_parser.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "ir/tensor.h" -#include "ir/param_info.h" -#include "frontend/operator/ops.h" -#include "abstract/abstract_value.h" -#include "proto/onnx.pb.h" -#include "utils/log_adapter.h" -#include "utils/shape_utils.h" - -using std::string; - -namespace mindspore { -namespace lite { -static constexpr char kConstantValueNode[] = "Constant"; -static constexpr char kCNodeShapeAttr[] = "shape"; -static constexpr char kCNodeShape1Attr[] = "shape1"; -static constexpr char kCNodeShape2Attr[] = "shape2"; -enum ParseForm : int { - FORM_PARSE_TYPE = 0, - FORM_PARSE_SCALAR = 1, - FORM_PARSE_TENSOR = 2, -}; - -static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; - -static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, -}; - -template -std::shared_ptr ParserAttr(const std::string &str, const std::unordered_map &kv) { - std::stack rules; - std::stack

value; - int count = 0; - for (size_t i = 0; i < str.length(); i++) { - if (str[i] == '[') { - rules.push("["); - } else if (str[i] == ']') { - // rules - std::vector

vec; - while (rules.top() != "[") { - rules.pop(); - vec.push_back(value.top()); - value.pop(); - } - // pop "[" - rules.pop(); - // make tuple for names - std::string res = "dummy"; - // make tuple for values - reverse(vec.begin(), vec.end()); - auto vt = std::make_shared(vec); - if (rules.empty() && value.empty()) { - return vt; - } - rules.push(res); - value.push(vt); - } else if (str[i] == ',') { - continue; - } else { - count++; - if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { - auto value_name = str.substr(i - count + 1, count); - value.push(kv.at(value_name)); - rules.push(value_name); - count = 0; - } - } - } - return {}; -} - -std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, - const std::unordered_map &kv) { - std::string str = attr_name; - auto replace = [&](const string &orgStr, const string &newStr) { - std::string::size_type pos(0); - while ((pos = str.find(orgStr)) != std::string::npos) { - str.replace(pos, orgStr.length(), newStr); - } - return str; - }; - // remove "scalar:" - str = replace("scalar:", ""); - // remove "Tuple" - str = replace("Tuple", ""); - // remove "List" - str = replace("List", ""); - auto result = ParserAttr(str, kv); - if (!result) { - return {}; - } - return result; -} - -std::shared_ptr ParserAttrShape( - const std::string &attr_name, const std::unordered_map &kv) { - std::string str = attr_name; - auto replace = [&](const string &orgStr, const string &newStr) { - std::string::size_type pos(0); - while ((pos = str.find(orgStr)) != std::string::npos) { - str.replace(pos, orgStr.length(), newStr); - } - return str; - }; - // remove "scalar:" - str = replace("shape:", ""); - // remove "Tuple" - str = replace("Tuple", ""); - // remove "List" - str = replace("List", ""); - - auto result = ParserAttr(str, kv); - if (!result) { - return {}; - } - return result; -} - -#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ - ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ - auto value = static_cast(attr_tensor.type##_data(0)); \ - return MakeValue(value); \ - } - -PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) -PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) -PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) -PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) -PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) -PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) -PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) - -bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { - MS_EXCEPTION_IF_NULL(node); - if (!value_proto.has_type() || !value_proto.has_name()) { - MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; - return false; - } - node->set_name(value_proto.name()); - const auto &type_proto = value_proto.type(); - if (!type_proto.has_tensor_type()) { - MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; - return false; - } - const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); - if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { - MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; - return false; - } - const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); - ShapeVector shape; - for (int i = 0; i < tensor_shape.dim_size(); ++i) { - shape.push_back(tensor_shape.dim(i).dim_value()); - } - - if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; - return false; - } - - tensor::TensorPtr tensor_info = - std::make_shared(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); - MS_EXCEPTION_IF_NULL(tensor_info); - auto tensor_abstract = tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(tensor_abstract); - node->set_abstract(tensor_abstract); - - if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { - const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; - std::string initial_data = initialize_proto.raw_data(); - auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - MS_EXCEPTION_IF_NULL(tensor_data_buf); - auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; - } - - node->set_default_param(tensor_info); - } - anfnode_build_map_[value_proto.name()] = node; - return true; -} - -bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto) { - MS_EXCEPTION_IF_NULL(outputFuncGraph); - MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); - - for (int i = 0; i < importProto.initializer_size(); ++i) { - const onnx::TensorProto &initializer_proto = importProto.initializer(i); - if (!initializer_proto.has_name()) { - MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; - return false; - } - default_para_map_[initializer_proto.name()] = initializer_proto; - } - - MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); - for (int i = 0; i < importProto.input_size(); ++i) { - const onnx::ValueInfoProto &input_proto = importProto.input(i); - if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { - MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; - return false; - } - } - return true; -} - -bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { - MS_EXCEPTION_IF_NULL(prim); - const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; - return false; - } - prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - return true; -} - -ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - switch (attr_tensor_type) { - case onnx::TensorProto_DataType_STRING: { - return ParseAttrInScalar_string_string(attr_tensor); - } - case onnx::TensorProto_DataType_INT32: { - return ParseAttrInScalar_int32_int32(attr_tensor); - } - case onnx::TensorProto_DataType_INT64: { - return ParseAttrInScalar_int64_int64(attr_tensor); - } - case onnx::TensorProto_DataType_UINT64: { - return ParseAttrInScalar_uint64_uint64(attr_tensor); - } - case onnx::TensorProto_DataType_FLOAT: { - return ParseAttrInScalar_float_float(attr_tensor); - } - case onnx::TensorProto_DataType_DOUBLE: { - return ParseAttrInScalar_double_double(attr_tensor); - } - case onnx::TensorProto_DataType_BOOL: { - return ParseAttrInScalar_int32_bool(attr_tensor); - } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; - return {}; - } - return {}; -} - -bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { - MS_EXCEPTION_IF_NULL(prim); - MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; - return false; -} - -bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { - MS_EXCEPTION_IF_NULL(prim); - const std::string &attr_name = attr_proto.name(); - if (!attr_proto.has_ref_attr_name()) { - MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; - return false; - } - const std::string &ref_attr_name = attr_proto.ref_attr_name(); - string type; - std::size_t pos(0); - if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("scalar:").length() - 1); - } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("type:").length() - 1); - } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("tensor:").length() - 1); - } - std::unordered_map kv; - for (int i = 0; i < attr_proto.tensors_size(); i++) { - const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); - switch (kParseTypeSwitchMap[type]) { - case FORM_PARSE_TYPE: { - ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); - break; - } - case FORM_PARSE_SCALAR: { - auto res = ObtainCNodeAttrInScalarForm(attr_tensor); - kv.insert(std::pair(attr_tensor.name(), res)); - break; - } - case FORM_PARSE_TENSOR: { - ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); - break; - } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; - return false; - } - } - - if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { - if (kv.size() == 1) { - auto iter = kv.begin(); - prim->AddAttr(attr_name, iter->second); - } else { - auto res = ParserScalarAttrValue(ref_attr_name, kv); - prim->AddAttr(attr_name, res); - } - } - return true; -} -bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - ShapeVector shape; - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape.push_back(attr_tensor.dims(i)); - } - tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); - const std::string &tensor_buf = attr_tensor.raw_data(); - auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; - } - - auto new_value_node = NewValueNode(MakeValue(tensor_info)); - MS_EXCEPTION_IF_NULL(new_value_node); - auto tensor_abstract = tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(tensor_abstract); - new_value_node->set_abstract(tensor_abstract); - anfnode_build_map_[value_node_name] = new_value_node; - return true; -} - -bool MSANFModelParser::ObtainValueNodeInScalarForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - ValuePtr value_ptr = nullptr; - switch (attr_tensor_type) { - case onnx::TensorProto_DataType_INT32: { - std::vector add_data; - for (int i = 0; i < attr_tensor.int32_data_size(); ++i) { - add_data.push_back(attr_tensor.int32_data(i)); - } - if (add_data.size() == 1) { - value_ptr = MakeValue(add_data[0]); - } else if (!add_data.empty()) { - value_ptr = MakeValue>(add_data); - } - break; - } - case onnx::TensorProto_DataType_FLOAT: { - std::vector add_data; - for (int i = 0; i < attr_tensor.float_data_size(); ++i) { - add_data.push_back(attr_tensor.float_data(i)); - } - - if (add_data.size() == 1) { - value_ptr = MakeValue(add_data[0]); - } else if (!add_data.empty()) { - value_ptr = MakeValue>(add_data); - } - break; - } - case onnx::TensorProto_DataType_UNDEFINED: { - std::vector elems; - value_ptr = std::make_shared(elems); - break; - } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; - return false; - } - auto new_value_node = NewValueNode(value_ptr); - MS_EXCEPTION_IF_NULL(new_value_node); - new_value_node->set_abstract(value_ptr->ToAbstract()); - anfnode_build_map_[value_node_name] = new_value_node; - - return true; -} - -bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; - return false; - } - auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); - new_value_node->set_abstract(abs_type); - anfnode_build_map_[value_node_name] = new_value_node; - return true; -} - -bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name, - const onnx::AttributeProto &attr_proto) { - if (!attr_proto.has_ref_attr_name()) { - MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; - return false; - } - const std::string &ref_attr_name = attr_proto.ref_attr_name(); - string type; - std::size_t pos(0); - if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("scalar:").length() - 1); - } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("type:").length() - 1); - } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("tensor:").length() - 1); - } - std::unordered_map kv; - for (int i = 0; i < attr_proto.tensors_size(); i++) { - const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); - auto attr_name = attr_tensor.name(); - switch (kParseTypeSwitchMap[type]) { - case FORM_PARSE_TYPE: { - return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); - } - case FORM_PARSE_SCALAR: { - auto res = ObtainCNodeAttrInScalarForm(attr_tensor); - kv.insert(std::pair(attr_tensor.name(), res)); - break; - } - case FORM_PARSE_TENSOR: { - return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); - } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; - return false; - } - } - - ValueNodePtr new_value_node; - if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { - if (kv.size() == 1) { - auto iter = kv.begin(); - new_value_node = NewValueNode(iter->second); - new_value_node->set_abstract(iter->second->ToAbstract()); - } else { - auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); - new_value_node = NewValueNode(value_ptr); - new_value_node->set_abstract(value_ptr->ToAbstract()); - } - anfnode_build_map_[value_node_name] = new_value_node; - } - return true; -} - -bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { - const std::string &value_node_name = node_proto.output(0); - const onnx::AttributeProto &attr_proto = node_proto.attribute(0); - if (!attr_proto.has_ref_attr_name()) { - MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; - return false; - } - return GetAttrValueForValueNode(value_node_name, attr_proto); -} - -std::unordered_map MSANFModelParser::GetAbstractForCNode( - const onnx::AttributeProto &attr_proto) { - std::unordered_map kv; - for (int i = 0; i < attr_proto.tensors_size(); ++i) { - ShapeVector shape_vec; - const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); - for (int j = 0; j < attr_tensor.dims_size(); ++j) { - shape_vec.push_back(attr_tensor.dims(j)); - } - tensor::TensorPtr tensor_info = - std::make_shared(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); - MS_EXCEPTION_IF_NULL(tensor_info); - auto abstract = tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(abstract); - kv.insert(std::pair(attr_tensor.name(), abstract)); - } - return kv; -} - -CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::NodeProto &node_proto) { - MS_EXCEPTION_IF_NULL(outputFuncGraph); - if (!node_proto.has_op_type()) { - MS_LOG(ERROR) << "Get CNode op_type failed!"; - return nullptr; - } - const std::string &node_name = node_proto.output(0); - const std::string &fullname_with_scope = node_proto.domain(); - const std::string &node_type = node_proto.op_type(); - PrimitivePtr prim = std::make_shared(node_type); - MS_EXCEPTION_IF_NULL(prim); - prim->set_instance_name(node_type); - - std::unordered_map kv; - string shape_ref_attr_name; - for (int i = 0; i < node_proto.attribute_size(); ++i) { - const onnx::AttributeProto &attr_proto = node_proto.attribute(i); - if (attr_proto.ref_attr_name().find("shape:") != string::npos) { - shape_ref_attr_name = attr_proto.ref_attr_name(); - kv = GetAbstractForCNode(attr_proto); - continue; - } - if (!GetAttrValueForCNode(prim, attr_proto)) { - MS_LOG(ERROR) << "Get CNode attr failed!"; - return nullptr; - } - } - - std::vector inputs; - inputs.clear(); - inputs.push_back(NewValueNode(prim)); - for (int i = 0; i < node_proto.input_size(); ++i) { - const std::string &input_name = node_proto.input(i); - if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { - MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; - return nullptr; - } - inputs.push_back(anfnode_build_map_[input_name]); - } - CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (0 == kv.size()) { - AbstractBasePtrList elem; - for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { - elem.push_back(cnode_ptr->input(index)->abstract()); - } - cnode_ptr->set_abstract(std::make_shared(elem)); - } else if (1 == kv.size()) { - std::unordered_map::iterator iter = kv.begin(); - cnode_ptr->set_abstract(iter->second); - } else { - auto abstract = ParserAttrShape(shape_ref_attr_name, kv); - cnode_ptr->set_abstract(abstract); - } - cnode_ptr->set_fullname_with_scope(fullname_with_scope); - anfnode_build_map_[node_name] = cnode_ptr; - return cnode_ptr; -} - -bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(outputFuncGraph); - MS_EXCEPTION_IF_NULL(cnode_ptr); - std::vector inputs; - if (importProto.output_size() > 1) { - inputs.clear(); - inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - AbstractBasePtrList elem; - for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { - const onnx::ValueInfoProto &output_node = importProto.output(out_size); - const std::string &out_tuple = output_node.name(); - inputs.push_back(anfnode_build_map_[out_tuple]); - elem.push_back(anfnode_build_map_[out_tuple]->abstract()); - } - auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); - maketuple_ptr->set_abstract(std::make_shared(elem)); - inputs.clear(); - inputs.push_back(NewValueNode(prim::kPrimReturn)); - inputs.push_back(maketuple_ptr); - auto return_node = outputFuncGraph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(return_node); - outputFuncGraph->set_return(return_node); - MS_LOG(INFO) << "Construct funcgraph finined, all success."; - } else { - const onnx::ValueInfoProto &output_node = importProto.output(0); - const onnx::TypeProto &output_typeproto = output_node.type(); - ShapeVector output_shape; - for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { - output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); - } - inputs.clear(); - inputs.push_back(NewValueNode(prim::kPrimReturn)); - inputs.push_back(cnode_ptr); - auto return_node = outputFuncGraph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(return_node); - outputFuncGraph->set_return(return_node); - MS_LOG(INFO) << "Construct funcgraph finined, all success!"; - } - return true; -} - -bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { - MS_EXCEPTION_IF_NULL(outputFuncGraph); - MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); - CNodePtr cnode_ptr = nullptr; - for (int i = 0; i < importProto.node_size(); ++i) { - const onnx::NodeProto &node_proto = importProto.node(i); - const std::string &node_type = node_proto.op_type(); - if (node_type == kConstantValueNode) { - if (!BuildValueNodeForFuncGraph(node_proto)) { - MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; - return false; - } - continue; - } - cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); - if (cnode_ptr == nullptr) { - MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; - return false; - } - } - - BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); - return true; -} - -bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { - MS_EXCEPTION_IF_NULL(outputFuncGraph); - GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); - MS_EXCEPTION_IF_NULL(debug_info_ptr); - if (importProto.has_name()) { - debug_info_ptr->set_name(importProto.name()); - } else { - MS_LOG(ERROR) << "FuncGraph under converting has not name!"; - } - - if (!ImportParametersForGraph(outputFuncGraph, importProto)) { - return false; - } - return ImportNodesForGraph(outputFuncGraph, importProto); -} - -bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto) { - if (!model_proto.has_producer_name()) { - MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; - return false; - } - producer_name_ = model_proto.producer_name(); - MS_LOG(INFO) << "producer_name :" << producer_name_; - - if (!model_proto.has_model_version()) { - MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; - return false; - } - model_version_ = model_proto.model_version(); - MS_LOG(INFO) << "producer_version : " << model_version_; - - if (!model_proto.has_ir_version()) { - MS_LOG(ERROR) << "Parse model version from pb file failed!"; - return false; - } - ir_version_ = model_proto.ir_version(); - MS_LOG(INFO) << "ir_version :" << ir_version_; - return true; -} - -FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) { - FuncGraphPtr dstGraph = std::make_shared(); - MS_EXCEPTION_IF_NULL(dstGraph); - if (!MSANFParseModelConfigureInfo(model_proto)) { - MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; - } - const onnx::GraphProto &graphBuild = model_proto.graph(); - if (!BuildFuncGraph(dstGraph, graphBuild)) { - MS_LOG(ERROR) << "Build funcgraph failed!"; - return nullptr; - } - MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; - return dstGraph; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/mind_ir.proto b/mindspore/ccsrc/utils/mind_ir.proto new file mode 100644 index 0000000000..2c38198ab0 --- /dev/null +++ b/mindspore/ccsrc/utils/mind_ir.proto @@ -0,0 +1,119 @@ +syntax = "proto2"; +package mind_ir; + +message AttributeProto { + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + UINT8 = 2; + INT8 = 3; + UINT16 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + STRING = 8; + BOOL = 9; + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; + COMPLEX128 = 15; + BFLOAT16 = 16; + TENSOR = 17; + GRAPH = 18; + TENSORS = 19; + } + optional string name = 1; + optional float f = 2; + optional int64 i = 3; + optional double d = 4; + optional bytes s = 5; + optional TensorProto t = 6; + optional GraphProto g = 7; + repeated float floats = 8; + repeated double doubles = 9; + repeated int64 ints = 10; + repeated bytes strings = 11; + repeated TensorProto tensors = 12; + repeated GraphProto graphs = 13; + optional string doc_string = 14; + optional string ref_attr_name = 15; + optional AttributeType type = 16; +} + + +message ValueInfoProto { + optional string name = 1; + repeated TensorProto tensor = 2; + optional string doc_string = 3; + optional string denotation = 4; +} + + +message NodeProto { + repeated string input = 1; + repeated string output = 2; + optional string name = 3; + optional string op_type = 4; + repeated AttributeProto attribute = 5; + optional string doc_string = 6; + optional string domain = 7; +} + + +message ModelProto { + optional string ir_version = 1; + optional string producer_name = 2; + optional string producer_version = 3; + optional string domain = 4; + optional string model_version = 5; + optional string doc_string = 6; + optional GraphProto graph = 7; +} + + +message GraphProto { + repeated NodeProto node = 1; + optional string name = 2; + repeated TensorProto parameter = 3; + optional string doc_string = 4; + repeated ValueInfoProto input = 5; + repeated ValueInfoProto output = 6; +} + + +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; + COMPLEX128 = 15; + BFLOAT16 = 16; + FLOAT64 = 17; + } + repeated int64 dims = 1; + optional int32 data_type = 2; + repeated float float_data = 3; + repeated int32 int32_data = 4; + repeated bytes string_data = 5; + repeated int64 int64_data = 6; + optional string name = 7; + optional string doc_string = 8; + optional bytes raw_data = 9; + repeated double double_data = 10; + repeated uint64 uint64_data = 11; +} diff --git a/mindspore/core/CMakeLists.txt b/mindspore/core/CMakeLists.txt index 8e018df2e4..18c7141f28 100644 --- a/mindspore/core/CMakeLists.txt +++ b/mindspore/core/CMakeLists.txt @@ -15,6 +15,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "c_ops/*.cc" "ir/*.cc" "utils/*.cc" + "load_mindir/*.cc" ) set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE) add_library(mindspore_core STATIC ${CORE_SRC_LIST}) diff --git a/mindspore/core/c_ops/add.cc b/mindspore/core/c_ops/add.cc index 7646f1e47a..c2c2e50e84 100644 --- a/mindspore/core/c_ops/add.cc +++ b/mindspore/core/c_ops/add.cc @@ -50,4 +50,5 @@ AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const Primit InferShape(primitive, input_args)->shape()); } REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer); +REGISTER_PRIMITIVE_C(TensorAdd); } // namespace mindspore diff --git a/mindspore/core/c_ops/avg_pool.cc b/mindspore/core/c_ops/avg_pool.cc index d29df4198c..323c4ddc61 100644 --- a/mindspore/core/c_ops/avg_pool.cc +++ b/mindspore/core/c_ops/avg_pool.cc @@ -102,4 +102,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv InferShape(primitive, input_args)->shape()); } REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); +REGISTER_PRIMITIVE_C(AvgPool); } // namespace mindspore diff --git a/mindspore/core/c_ops/conv2d.cc b/mindspore/core/c_ops/conv2d.cc index 4ab9992bfc..76f637ffe8 100644 --- a/mindspore/core/c_ops/conv2d.cc +++ b/mindspore/core/c_ops/conv2d.cc @@ -193,4 +193,5 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive Conv2dInferShape(primitive, input_args)->shape()); } REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); +REGISTER_PRIMITIVE_C(Conv2D); } // namespace mindspore diff --git a/mindspore/core/c_ops/op_utils.h b/mindspore/core/c_ops/op_utils.h index 8955f2a61c..ad1d482d0b 100644 --- a/mindspore/core/c_ops/op_utils.h +++ b/mindspore/core/c_ops/op_utils.h @@ -33,7 +33,7 @@ constexpr auto kPad = "pad"; constexpr auto kPads = "pads"; constexpr auto kMode = "mode"; constexpr auto kGroup = "group"; -constexpr auto kOutputChannel = "output_channel"; +constexpr auto kOutputChannel = "out_channel"; constexpr auto kPadList = "pad_list"; constexpr auto kAxis = "axis"; diff --git a/mindspore/core/c_ops/primitive_c.cc b/mindspore/core/c_ops/primitive_c.cc index cc1533e8e1..690ebd121f 100644 --- a/mindspore/core/c_ops/primitive_c.cc +++ b/mindspore/core/c_ops/primitive_c.cc @@ -31,4 +31,13 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { auto infer_function = iter->second.impl_; return infer_function(nullptr, shared_from_base(), abstract_list); } + +OpPrimCRegister &OpPrimCRegister::GetInstance() { + static OpPrimCRegister instance; + return instance; +} + +std::map OpPrimCRegister::GetPrimCMap() { return op_primc_fns_; } +void OpPrimCRegister::SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn) { op_primc_fns_[name] = fn; } + } // namespace mindspore diff --git a/mindspore/core/c_ops/primitive_c.h b/mindspore/core/c_ops/primitive_c.h index f85d4559f4..662fddd007 100644 --- a/mindspore/core/c_ops/primitive_c.h +++ b/mindspore/core/c_ops/primitive_c.h @@ -18,6 +18,8 @@ #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ #include #include +#include +#include #include "ir/primitive.h" #include "abstract/primitive_infer_map.h" #include "ir/value.h" @@ -32,5 +34,33 @@ class PrimitiveC : public Primitive { protected: void InitIOName(const std::vector &inputs_name, const std::vector &outputs_name); }; + +using OpPrimCDefineFunc = std::function()>; +class OpPrimCRegister { + public: + ~OpPrimCRegister() {} + static OpPrimCRegister &GetInstance(); + std::map GetPrimCMap(); + void SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn); + + private: + OpPrimCRegister() {} + std::map op_primc_fns_; +}; + +class OpPrimCRegisterHelper { + public: + OpPrimCRegisterHelper(const std::string &name, const OpPrimCDefineFunc &fn) { + OpPrimCRegister::GetInstance().SetPrimCMap(name, fn); + } + ~OpPrimCRegisterHelper() = default; +}; + +#define REGISTER_PRIMITIVE_C(name) \ + std::shared_ptr GetDefaultPrimC##name() { \ + auto out = std::make_shared(); \ + return out; \ + } \ + OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name); } // namespace mindspore #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ diff --git a/mindspore/core/c_ops/relu6.cc b/mindspore/core/c_ops/relu6.cc index 90eb49f3e2..1d962cd258 100644 --- a/mindspore/core/c_ops/relu6.cc +++ b/mindspore/core/c_ops/relu6.cc @@ -49,4 +49,5 @@ AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitiveP Relu6InferShape(primitive, input_args)->shape()); } REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer); +REGISTER_PRIMITIVE_C(Relu6); } // namespace mindspore diff --git a/mindspore/core/c_ops/reshape.cc b/mindspore/core/c_ops/reshape.cc index 3646d46c2b..43028366b7 100644 --- a/mindspore/core/c_ops/reshape.cc +++ b/mindspore/core/c_ops/reshape.cc @@ -41,4 +41,5 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv } REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer); +REGISTER_PRIMITIVE_C(Reshape); } // namespace mindspore diff --git a/mindspore/core/c_ops/reshape.h b/mindspore/core/c_ops/reshape.h index 5850e5b313..1ddeb26182 100644 --- a/mindspore/core/c_ops/reshape.h +++ b/mindspore/core/c_ops/reshape.h @@ -36,7 +36,7 @@ class Reshape : public PrimitiveC { AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); -using PrimTensorAddPtr = std::shared_ptr; +using PrimReshapePtr = std::shared_ptr; } // namespace mindspore #endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_ diff --git a/mindspore/core/c_ops/softmax.cc b/mindspore/core/c_ops/softmax.cc index 5c6e865067..8482d07cbb 100644 --- a/mindspore/core/c_ops/softmax.cc +++ b/mindspore/core/c_ops/softmax.cc @@ -75,4 +75,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv } REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer); +REGISTER_PRIMITIVE_C(Softmax); } // namespace mindspore diff --git a/mindspore/core/c_ops/squeeze.cc b/mindspore/core/c_ops/squeeze.cc index bc8808c8c0..dfd31b00a7 100644 --- a/mindspore/core/c_ops/squeeze.cc +++ b/mindspore/core/c_ops/squeeze.cc @@ -76,4 +76,5 @@ AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const Primitiv } REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer); +REGISTER_PRIMITIVE_C(Squeeze); } // namespace mindspore diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc new file mode 100644 index 0000000000..6f16cc6be3 --- /dev/null +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -0,0 +1,854 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "load_mindir/anf_model_parser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/tensor.h" +#include "ir/param_info.h" +#include "c_ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/log_adapter.h" +#include "utils/shape_utils.h" + +using std::string; + +namespace mindspore { +static constexpr char kConstantValueNode[] = "Constant"; +static constexpr char kCNodeShapeAttr[] = "shape"; +static constexpr char kCNodeShape1Attr[] = "shape1"; +static constexpr char kCNodeShape2Attr[] = "shape2"; +enum ParseForm : int { + FORM_PARSE_TYPE = 0, + FORM_PARSE_SCALAR = 1, + FORM_PARSE_TENSOR = 2, + FORM_PARSE_NONE = 3, + FORM_PARSE_UNDEFINE = 4, +}; + +static std::map kParseTypeSwitchMap{{"type", FORM_PARSE_TYPE}, + {"scalar", FORM_PARSE_SCALAR}, + {"tensor", FORM_PARSE_TENSOR}, + {"none", FORM_PARSE_NONE}, + {"", FORM_PARSE_UNDEFINE}}; + +static std::unordered_map kDefaultValueSwitchMap{ + {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool}, + {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8}, + {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16}, + {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32}, + {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64}, + {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16}, + {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64}, + {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, + {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64}, + {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {mind_ir::TensorProto_DataType_STRING, kObjectTypeString}, +}; + +template +std::shared_ptr ParserAttr(const std::string &str, const std::unordered_map &kv) { + std::stack rules; + std::stack

value; + int count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector

vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + } + } + } + return {}; +} + +template +std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("scalar:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + auto result = ParserAttr(str, kv); + return result; +} + +std::shared_ptr ParserAttrShape( + const std::string &attr_name, const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("shape:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + + auto result = ParserAttr(str, kv); + return result; +} + +std::string ParseParameterName(const string &name) { + string delimiter = ":"; + size_t pos(0); + if ((pos = name.find(delimiter)) != string::npos) { + return name.substr(pos + 1, string::npos - (pos + 1)); + } + return name; +} + +std::string ParseCNodeName(const string &name) { + string delimiter = ":"; + size_t pos = name.find(delimiter); + size_t end_pos = name.find_last_of(delimiter); + + if (pos != string::npos && end_pos != string::npos && pos != end_pos) { + return name.substr(pos + 1, end_pos - (pos + 1)); + } + return name; +} + +#define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \ + auto value = static_cast(attr_proto.ints(index)); \ + return MakeValue(value); \ + } \ + ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) { \ + auto value = static_cast(attr_proto.i()); \ + return MakeValue(value); \ + } + +#define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \ + auto value = static_cast(attr_proto.type##s(index)); \ + return MakeValue(value); \ + } + +PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t) +PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool) + +PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double) +PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float) +PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string) + +ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) { + auto value = static_cast(attr_proto.s()); + return MakeValue(value); +} + +ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) { + auto value = static_cast(attr_proto.f()); + return MakeValue(value); +} + +ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) { + auto value = static_cast(attr_proto.d()); + return MakeValue(value); +} + +tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) { + ShapeVector shape; + for (int i = 0; i < tensor_proto.dims_size(); ++i) { + shape.push_back(tensor_proto.dims(i)); + } + + if (!tensor_proto.has_data_type()) { + MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!"; + return nullptr; + } + if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!"; + return nullptr; + } + + tensor::TensorPtr tensor_info = + std::make_shared(kDefaultValueSwitchMap[tensor_proto.data_type()], shape); + MS_EXCEPTION_IF_NULL(tensor_info); + return tensor_info; +} + +bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, + const mind_ir::TensorProto ¶meter_proto) { + MS_EXCEPTION_IF_NULL(node); + + if (!parameter_proto.has_name()) { + MS_LOG(ERROR) << "mind_ir TensorProto has no name!"; + return false; + } + string debug_info_name = ParseParameterName(parameter_proto.name()); + auto debug_info_ptr = std::make_shared(debug_info_name); + node->set_debug_info(debug_info_ptr); + node->set_name(parameter_proto.name()); + + tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto); + auto tensor_abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(tensor_abstract); + node->set_abstract(tensor_abstract); + + std::string initial_data = parameter_proto.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); + MS_EXCEPTION_IF_NULL(tensor_data_buf); + auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error for build parameter, errorno " << ret; + } + + node->set_default_param(tensor_info); + + anfnode_build_map_[parameter_proto.name()] = node; + return true; +} + +bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) { + MS_EXCEPTION_IF_NULL(node); + + if (!value_proto.has_name()) { + MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!"; + return false; + } + string debug_info_name = ParseParameterName(value_proto.name()); + auto debug_info_ptr = std::make_shared(debug_info_name); + node->set_debug_info(debug_info_ptr); + node->set_name(value_proto.name()); + + const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0); + + tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto); + auto tensor_abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(tensor_abstract); + node->set_abstract(tensor_abstract); + + anfnode_build_map_[value_proto.name()] = node; + return true; +} + +bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const mind_ir::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size(); + for (int i = 0; i < importProto.parameter_size(); ++i) { + const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i); + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) { + MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; + return false; + } + } + + MS_LOG(INFO) << "All inputs size is: " << importProto.input_size(); + for (int i = 0; i < importProto.input_size(); ++i) { + const mind_ir::ValueInfoProto &input_proto = importProto.input(i); + if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { + MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i; + return false; + } + } + return true; +} + +bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_proto.tensors(0).data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; + return false; + } + prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + return true; +} + +ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) { + const int attr_type = attr_proto.type(); + switch (attr_type) { + case mind_ir::AttributeProto_AttributeType_STRING: { + return ParseAttrInScalar_string_string(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_INT8: { + return ParseAttrInScalar_int8_t_int8_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_INT16: { + return ParseAttrInScalar_int16_t_int16_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_INT32: { + return ParseAttrInScalar_int32_t_int32_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_INT64: { + return ParseAttrInScalar_int64_t_int64_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_UINT8: { + return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_UINT16: { + return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_UINT32: { + return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_UINT64: { + return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_FLOAT: { + return ParseAttrInScalar_float_float(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_DOUBLE: { + return ParseAttrInScalar_double_double(attr_proto, index); + } + case mind_ir::AttributeProto_AttributeType_BOOL: { + return ParseAttrInScalar_int32_t_bool(attr_proto, index); + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type; + return {}; + } + return {}; +} + +void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, + std::unordered_map *multi_value_map) { + string name; + for (int i = 0; i < attr_proto.ints_size(); i++) { + auto res = ParseAttrInScalarForm(attr_proto, i); + name = "value" + std::to_string(i + 1); + multi_value_map->insert(std::pair(name, res)); + } + for (int i = 0; i < attr_proto.doubles_size(); i++) { + auto res = ParseAttrInScalarForm(attr_proto, i); + name = "value" + std::to_string(i + 1); + multi_value_map->insert(std::pair(name, res)); + } + for (int i = 0; i < attr_proto.floats_size(); i++) { + auto res = ParseAttrInScalarForm(attr_proto, i); + name = "value" + std::to_string(i + 1); + multi_value_map->insert(std::pair(name, res)); + } + for (int i = 0; i < attr_proto.strings_size(); i++) { + auto res = ParseAttrInScalarForm(attr_proto, i); + name = "value" + std::to_string(i + 1); + multi_value_map->insert(std::pair(name, res)); + } +} + +ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) { + const int attr_type = attr_proto.type(); + switch (attr_type) { + case mind_ir::AttributeProto_AttributeType_STRING: { + return ParseAttrInSingleScalar_string_string(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_INT8: { + return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_INT16: { + return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_INT32: { + return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_INT64: { + return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_UINT8: { + return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_UINT16: { + return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_UINT32: { + return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_UINT64: { + return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_FLOAT: { + return ParseAttrInSingleScalar_float_float(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_DOUBLE: { + return ParseAttrInSingleScalar_double_double(attr_proto); + } + case mind_ir::AttributeProto_AttributeType_BOOL: { + return ParseAttrInSingleScalar_int32_t_bool(attr_proto); + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type; + return {}; + } + return {}; +} + +bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, + const mind_ir::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0); + const int attr_tensor_type = attr_tensor.data_type(); + ShapeVector shape; + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + const std::string &tensor_buf = attr_tensor.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } + prim->AddAttr(attr_proto.name(), MakeValue(tensor_info)); + return true; +} + +bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type = ""; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map multi_value_map; + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + ObtainCNodeAttrInTypeForm(prim, attr_proto); + break; + } + case FORM_PARSE_SCALAR: { + std::size_t value_pos(0); + if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) { + auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto); + prim->AddAttr(attr_name, res); + break; + } + ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map); + break; + } + case FORM_PARSE_TENSOR: { + ObtainCNodeAttrInTensorForm(prim, attr_proto); + break; + } + default: + MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name; + return false; + } + + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) { + if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) { + auto value_tuple_ptr = ParserScalarAttrValue(ref_attr_name, multi_value_map); + prim->AddAttr(attr_name, value_tuple_ptr); + } else { + auto value_list_ptr = ParserScalarAttrValue(ref_attr_name, multi_value_map); + prim->AddAttr(attr_name, value_list_ptr); + } + } + return true; +} + +bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, + const mind_ir::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + ShapeVector shape; + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + const std::string &tensor_buf = attr_tensor.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } + + auto new_value_node = NewValueNode(MakeValue(tensor_info)); + MS_EXCEPTION_IF_NULL(new_value_node); + auto tensor_abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(tensor_abstract); + new_value_node->set_abstract(tensor_abstract); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name, + const mind_ir::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; + return false; + } + auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); + new_value_node->set_abstract(abs_type); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name, + const mind_ir::AttributeProto &attr_proto) { + auto new_value_node = NewValueNode(kNone); + MS_EXCEPTION_IF_NULL(new_value_node); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name, + const mind_ir::AttributeProto &attr_proto) { + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type = ""; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } else if (ref_attr_name == "none") { + type = ref_attr_name; + } + + ValueNodePtr new_value_node; + std::unordered_map multi_value_map; + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0)); + break; + } + case FORM_PARSE_SCALAR: { + std::size_t value_pos(0); + if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) { + auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto); + new_value_node = NewValueNode(res); + new_value_node->set_abstract(res->ToAbstract()); + anfnode_build_map_[value_node_name] = new_value_node; + break; + } + ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map); + break; + } + case FORM_PARSE_TENSOR: { + ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0)); + break; + } + case FORM_PARSE_NONE: { + ObtainValueNodeInNoneForm(value_node_name, attr_proto); + break; + } + default: + MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name; + return false; + } + + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) { + if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) { + auto value_tuple_ptr = ParserScalarAttrValue(ref_attr_name, multi_value_map); + new_value_node = NewValueNode(value_tuple_ptr); + new_value_node->set_abstract(value_tuple_ptr->ToAbstract()); + } else { + auto value_list_ptr = ParserScalarAttrValue(ref_attr_name, multi_value_map); + new_value_node = NewValueNode(value_list_ptr); + new_value_node->set_abstract(value_list_ptr->ToAbstract()); + } + anfnode_build_map_[value_node_name] = new_value_node; + } + return true; +} + +bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) { + const std::string &value_node_name = node_proto.output(0); + const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; + return false; + } + return GetAttrValueForValueNode(value_node_name, attr_proto); +} + +std::unordered_map MSANFModelParser::GetAbstractForCNode( + const mind_ir::AttributeProto &attr_proto) { + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); ++i) { + ShapeVector shape_vec; + const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i); + for (int j = 0; j < attr_tensor.dims_size(); ++j) { + shape_vec.push_back(attr_tensor.dims(j)); + } + tensor::TensorPtr tensor_info = + std::make_shared(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); + MS_EXCEPTION_IF_NULL(tensor_info); + auto abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(abstract); + kv.insert(std::pair(attr_tensor.name(), abstract)); + } + return kv; +} + +CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const mind_ir::NodeProto &node_proto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + if (!node_proto.has_op_type()) { + MS_LOG(ERROR) << "Get CNode op_type failed!"; + return nullptr; + } + const std::string &node_name = node_proto.output(0); + const std::string &fullname_with_scope = node_proto.domain(); + const std::string &node_type = node_proto.op_type(); + + std::shared_ptr prim; + auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap(); + if (op_primc_fns.find(node_type) != op_primc_fns.end()) { + prim = op_primc_fns[node_type](); + } else { + prim = std::make_shared(node_type); + prim->set_instance_name(node_type); + } + MS_EXCEPTION_IF_NULL(prim); + + std::unordered_map kv; + string shape_ref_attr_name; + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + shape_ref_attr_name = attr_proto.ref_attr_name(); + kv = GetAbstractForCNode(attr_proto); + continue; + } + + if (!GetAttrValueForCNode(prim, attr_proto)) { + MS_LOG(ERROR) << "Get CNode attr failed!"; + return nullptr; + } + } + + std::vector inputs; + inputs.clear(); + for (int i = 0; i < node_proto.input_size(); ++i) { + const std::string &input_name = node_proto.input(i); + if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; + return nullptr; + } + + inputs.push_back(anfnode_build_map_[input_name]); + } + + auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + + if (0 == kv.size()) { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (1 == kv.size()) { + std::unordered_map::iterator iter = kv.begin(); + cnode_ptr->set_abstract(iter->second); + } else { + auto abstract = ParserAttrShape(shape_ref_attr_name, kv); + cnode_ptr->set_abstract(abstract); + } + + string debug_info_name = ParseCNodeName(node_name); + auto debug_info_ptr = std::make_shared(debug_info_name); + cnode_ptr->set_debug_info(debug_info_ptr); + cnode_ptr->set_fullname_with_scope(fullname_with_scope); + + anfnode_build_map_[node_name] = cnode_ptr; + return cnode_ptr; +} + +bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_EXCEPTION_IF_NULL(cnode_ptr); + std::vector inputs; + if (importProto.output_size() > 1) { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + AbstractBasePtrList elem; + for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { + const mind_ir::ValueInfoProto &output_node = importProto.output(out_size); + const std::string &out_tuple = output_node.name(); + inputs.push_back(anfnode_build_map_[out_tuple]); + elem.push_back(anfnode_build_map_[out_tuple]->abstract()); + } + auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); + maketuple_ptr->set_abstract(std::make_shared(elem)); + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(maketuple_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success."; + } else { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(cnode_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success!"; + } + return true; +} + +bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const mind_ir::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); + CNodePtr cnode_ptr = nullptr; + for (int i = 0; i < importProto.node_size(); ++i) { + const mind_ir::NodeProto &node_proto = importProto.node(i); + const std::string &node_type = node_proto.op_type(); + if (node_type == kConstantValueNode) { + if (!BuildValueNodeForFuncGraph(node_proto)) { + MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; + return false; + } + continue; + } + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; + return false; + } + } + + BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); + return true; +} + +bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); + MS_EXCEPTION_IF_NULL(debug_info_ptr); + if (importProto.has_name()) { + debug_info_ptr->set_name(importProto.name()); + } else { + MS_LOG(ERROR) << "FuncGraph under converting has not name!"; + } + + if (!ImportParametersForGraph(outputFuncGraph, importProto)) { + MS_LOG(ERROR) << "import parameters for graph fail!"; + return false; + } + return ImportNodesForGraph(outputFuncGraph, importProto); +} + +bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto) { + if (!model_proto.has_producer_name()) { + MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; + return false; + } + producer_name_ = model_proto.producer_name(); + MS_LOG(INFO) << "producer_name :" << producer_name_; + + if (!model_proto.has_model_version()) { + MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; + return false; + } + model_version_ = model_proto.model_version(); + MS_LOG(INFO) << "producer_version : " << model_version_; + + if (!model_proto.has_ir_version()) { + MS_LOG(ERROR) << "Parse model version from pb file failed!"; + return false; + } + ir_version_ = model_proto.ir_version(); + MS_LOG(INFO) << "ir_version :" << ir_version_; + return true; +} + +FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) { + FuncGraphPtr dstGraph = std::make_shared(); + MS_EXCEPTION_IF_NULL(dstGraph); + if (!MSANFParseModelConfigureInfo(model_proto)) { + MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; + } + const mind_ir::GraphProto &graphBuild = model_proto.graph(); + if (!BuildFuncGraph(dstGraph, graphBuild)) { + MS_LOG(ERROR) << "Build funcgraph failed!"; + return nullptr; + } + MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; + return dstGraph; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h similarity index 50% rename from mindspore/ccsrc/utils/load_onnx/anf_model_parser.h rename to mindspore/core/load_mindir/anf_model_parser.h index 5dc0b17b35..fa03d0e82e 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h +++ b/mindspore/core/load_mindir/anf_model_parser.h @@ -14,63 +14,62 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H -#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H +#ifndef MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H +#define MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H #include #include #include #include "google/protobuf/io/zero_copy_stream_impl.h" #include "ir/func_graph.h" -#include "proto/onnx.pb.h" +#include "proto/mind_ir.pb.h" namespace mindspore { -namespace lite { using int32 = int32_t; using int64 = int64_t; using uint64 = uint64_t; class MSANFModelParser { public: - MSANFModelParser() : producer_name_(""), model_version_(0), ir_version_(0) {} + MSANFModelParser() : producer_name_(""), model_version_(""), ir_version_("") {} ~MSANFModelParser() = default; - FuncGraphPtr Parse(const onnx::ModelProto &model_proto); - bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto); + FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto); + bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto); std::string GetProducerName() { return producer_name_; } - int GetProducerVersion() { return model_version_; } - int GetIrVersion() { return ir_version_; } + std::string GetProducerVersion() { return model_version_; } + std::string GetIrVersion() { return ir_version_; } private: - bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); - bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); - bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); - bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); - CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); - bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); + bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto); + bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto); + tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr); - bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); - bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor); - ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor); - bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); - bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - - bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor); - bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); + void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, + std::unordered_map *multi_value_map); + ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index); + ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); + bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto); + bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor); + bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); + bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); std::unordered_map GetAbstractForCNode( - const onnx::AttributeProto &attr_proto); + const mind_ir::AttributeProto &attr_proto); std::string producer_name_; - int model_version_; - int ir_version_; + std::string model_version_; + std::string ir_version_; std::unordered_map anfnode_build_map_; - std::map default_para_map_; }; -} // namespace lite } // namespace mindspore -#endif // MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H +#endif // MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc new file mode 100644 index 0000000000..4e5cfc5446 --- /dev/null +++ b/mindspore/core/load_mindir/load_model.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "load_mindir/load_model.h" +#include +#include +#include + +#include "load_mindir/anf_model_parser.h" + +using std::string; +using std::vector; + +namespace mindspore { +std::shared_ptr> ReadProtoFile(const std::string &file) { + if (file.empty()) { + MS_LOG(ERROR) << "file is nullptr"; + return nullptr; + } + + char real_path[PATH_MAX] = {0}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file; + return nullptr; + } +#else + if (realpath(file.c_str(), real_path) == nullptr) { + MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file; + return nullptr; + } +#endif + + std::ifstream ifs(real_path); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << real_path << " is not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << real_path << "open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + std::shared_ptr> buf(new (std::nothrow) std::vector(size)); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf->data(), size); + ifs.close(); + + return buf; +} + +std::shared_ptr RunLoadMindIR(const std::string &file_name) { + auto graphBuf = ReadProtoFile(file_name); + if (graphBuf == nullptr) { + MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str(); + return nullptr; + } + + try { + auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size()); + return graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + return nullptr; + } +} + +std::shared_ptr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) { + MS_EXCEPTION_IF_NULL(buf); + std::string str((const char *)buf, buf_size); + mind_ir::ModelProto model_; + if (!model_.ParseFromString(str)) { + MS_LOG(ERROR) << "Parse model from buffer fail!"; + } + MSANFModelParser model_parser; + FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); + return dstgraph_ptr; +} + +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/load_onnx/anf_converter.h b/mindspore/core/load_mindir/load_model.h similarity index 53% rename from mindspore/ccsrc/utils/load_onnx/anf_converter.h rename to mindspore/core/load_mindir/load_model.h index 4f5fe3971f..f7681a5d49 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_converter.h +++ b/mindspore/core/load_mindir/load_model.h @@ -13,27 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef MINDSPORE_CORE_LOAD_MODEL_H +#define MINDSPORE_CORE_LOAD_MODEL_H -#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H -#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H +#include #include #include -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "proto/onnx.pb.h" + +#include "proto/mind_ir.pb.h" #include "ir/func_graph.h" namespace mindspore { -namespace lite { -class AnfConverter { - public: - static std::shared_ptr RunAnfConverter(const std::string &file_path); - static std::shared_ptr RunAnfConverter(const char *buf, const size_t buf_size); - - private: - static void Trim(std::string *input); - static int ValidateFileStr(const std::string &modelFile, std::string fileType); - static bool ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model); -}; -} // namespace lite +std::shared_ptr RunLoadMindIR(const std::string &file_name); +std::shared_ptr> ReadProtoFile(const std::string &file); +std::shared_ptr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size); } // namespace mindspore -#endif +#endif // MINDSPORE_CORE_LOAD_MODEL_H diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 00173c3fec..6953b135b6 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -116,7 +116,9 @@ endif () file(GLOB PROTO_FILE "" ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto - ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) + ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto + ${CCSRC_DIR}/utils/mind_ir.proto) + ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) add_library(proto_mid OBJECT ${PROTO_SRCS}) set(TFLITE_FBS_FILES diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index ced97b027c..10dbf0a7ae 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -23,9 +23,12 @@ from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype from mindspore.common import dtype as mstype from mindspore import log as logger from mindspore.common.api import _executor +from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model +from mindspore.train.anf_ir_pb2 import ModelProto as anf_model from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo + def _convert_type(types): """ Convert from numpy type to tensor type. @@ -203,3 +206,32 @@ def check_value_type(arg_name, arg_value, valid_types): if not is_valid: raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, ' f'bug got {type(arg_value).__name__}.') + + +def read_proto(file_name, proto_format="MINDIR"): + """ + Read protobuf file. + + Args: + file_name (str): File name. + proto_format (str): Proto format. + + Returns: + Object, proto object. + """ + + if proto_format == "MINDIR": + model = mindir_model() + elif model_format == "ANF": + model = anf_model() + else: + raise ValueError("Unsupported proto format.") + + try: + with open(file_name, "rb") as f: + pb_content = f.read() + model.ParseFromString(pb_content) + except BaseException as e: + logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) + raise ValueError(e.__str__()) + return model