From: @wangnan39 Reviewed-by: @kingxian,@guoqi1024 Signed-off-by: @kingxiantags/v1.1.0
| @@ -125,7 +125,8 @@ if (ENABLE_DUMP_PROTO) | |||||
| "utils/lineage.proto" | "utils/lineage.proto" | ||||
| "utils/checkpoint.proto" | "utils/checkpoint.proto" | ||||
| "utils/print.proto" | "utils/print.proto" | ||||
| "utils/node_strategy.proto" | |||||
| "utils/node_strategy.proto" | |||||
| "utils/mind_ir.proto" | |||||
| ) | ) | ||||
| ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | ||||
| @@ -156,7 +157,7 @@ endif() | |||||
| ## make sub objects | ## make sub objects | ||||
| set(SUB_COMP | set(SUB_COMP | ||||
| transform/graph_ir | transform/graph_ir | ||||
| transform/onnx | |||||
| transform/express_ir | |||||
| backend/optimizer | backend/optimizer | ||||
| backend/kernel_compiler | backend/kernel_compiler | ||||
| backend/session | backend/session | ||||
| @@ -344,13 +345,13 @@ if (ENABLE_MINDDATA) | |||||
| endif () | endif () | ||||
| # build inference | # build inference | ||||
| set(LOAD_ONNX_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc | |||||
| set(LOAD_MINDIR_SRC | |||||
| ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc | |||||
| ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc | |||||
| ) | ) | ||||
| add_library(inference SHARED | add_library(inference SHARED | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc | ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc | ||||
| ${LOAD_ONNX_SRC} | |||||
| ${LOAD_MINDIR_SRC} | |||||
| ) | ) | ||||
| set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) | set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) | ||||
| @@ -20,11 +20,11 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include "include/inference.h" | #include "include/inference.h" | ||||
| #include "utils/load_onnx/anf_converter.h" | |||||
| #include "backend/session/session_basic.h" | #include "backend/session/session_basic.h" | ||||
| #include "backend/session/session_factory.h" | #include "backend/session/session_factory.h" | ||||
| #include "backend/session/executor_manager.h" | #include "backend/session/executor_manager.h" | ||||
| #include "base/base_ref_utils.h" | #include "base/base_ref_utils.h" | ||||
| #include "load_mindir/load_model.h" | |||||
| #include "backend/kernel_compiler/oplib/oplib.h" | #include "backend/kernel_compiler/oplib/oplib.h" | ||||
| #include "utils/context/context_extends.h" | #include "utils/context/context_extends.h" | ||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| @@ -58,46 +58,9 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev | |||||
| MSInferSession::MSInferSession() = default; | MSInferSession::MSInferSession() = default; | ||||
| MSInferSession::~MSInferSession() = default; | MSInferSession::~MSInferSession() = default; | ||||
| std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &file) { | |||||
| if (file.empty()) { | |||||
| MS_LOG(ERROR) << "file is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| std::string realPath = file; | |||||
| std::ifstream ifs(realPath); | |||||
| if (!ifs.good()) { | |||||
| MS_LOG(ERROR) << "file: " << realPath << " is not exist"; | |||||
| return nullptr; | |||||
| } | |||||
| if (!ifs.is_open()) { | |||||
| MS_LOG(ERROR) << "file: " << realPath << "open failed"; | |||||
| return nullptr; | |||||
| } | |||||
| ifs.seekg(0, std::ios::end); | |||||
| size_t size = ifs.tellg(); | |||||
| std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size)); | |||||
| if (buf == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; | |||||
| ifs.close(); | |||||
| return nullptr; | |||||
| } | |||||
| ifs.seekg(0, std::ios::beg); | |||||
| ifs.read(buf->data(), size); | |||||
| ifs.close(); | |||||
| return buf; | |||||
| } | |||||
| Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) { | Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) { | ||||
| auto graphBuf = ReadFile(file_name); | |||||
| if (graphBuf == nullptr) { | |||||
| MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); | |||||
| return FAILED; | |||||
| } | |||||
| auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_); | |||||
| Py_Initialize(); | |||||
| auto graph = RunLoadMindIR(file_name); | |||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -213,6 +176,7 @@ Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &reques | |||||
| } | } | ||||
| inputs.push_back(input); | inputs.push_back(input); | ||||
| } | } | ||||
| auto ret = CheckModelInputs(model_id, inputs); | auto ret = CheckModelInputs(model_id, inputs); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed"; | MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed"; | ||||
| @@ -250,16 +214,6 @@ Status MSInferSession::FinalizeEnv() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) { | |||||
| try { | |||||
| auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); | |||||
| return anf_graph; | |||||
| } catch (std::exception &e) { | |||||
| MS_LOG(ERROR) << "Inference LoadModel failed"; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| void MSInferSession::RegAllOp() { | void MSInferSession::RegAllOp() { | ||||
| static std::mutex init_mutex; | static std::mutex init_mutex; | ||||
| static bool Initialized = false; | static bool Initialized = false; | ||||
| @@ -54,8 +54,6 @@ class MSInferSession : public InferSession { | |||||
| rtContext_t context_ = nullptr; | rtContext_t context_ = nullptr; | ||||
| #endif | #endif | ||||
| std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device); | |||||
| std::shared_ptr<std::vector<char>> ReadFile(const std::string &file); | |||||
| static void RegAllOp(); | static void RegAllOp(); | ||||
| string AjustTargetName(const std::string &device); | string AjustTargetName(const std::string &device); | ||||
| Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id); | Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id); | ||||
| @@ -1,7 +1,7 @@ | |||||
| # build mindspore_shared_lib | # build mindspore_shared_lib | ||||
| set(LOAD_ONNX_SRC | |||||
| ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_converter.cc | |||||
| ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc | |||||
| set(LOAD_MINDIR_SRC | |||||
| ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc | |||||
| ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc | |||||
| ) | ) | ||||
| file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") | file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") | ||||
| @@ -18,7 +18,7 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | |||||
| ${API_MS_INFER_SRC} | ${API_MS_INFER_SRC} | ||||
| ${API_ACL_SRC} | ${API_ACL_SRC} | ||||
| ${API_OPS_SRC} | ${API_OPS_SRC} | ||||
| ${LOAD_ONNX_SRC}) | |||||
| ${LOAD_MINDIR_SRC}) | |||||
| add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) | add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) | ||||
| set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") | set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") | ||||
| @@ -17,9 +17,9 @@ | |||||
| #include "cxx_api/model/acl/model_converter.h" | #include "cxx_api/model/acl/model_converter.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| #include "utils/load_onnx/anf_converter.h" | |||||
| #include "transform/graph_ir/convert.h" | #include "transform/graph_ir/convert.h" | ||||
| #include "transform/graph_ir/graph_runner.h" | #include "transform/graph_ir/graph_runner.h" | ||||
| #include "core/load_mindir/load_model.h" | |||||
| #include "mindspore/core/utils/ms_context.h" | #include "mindspore/core/utils/ms_context.h" | ||||
| #include "backend/kernel_compiler/oplib/oplib.h" | #include "backend/kernel_compiler/oplib/oplib.h" | ||||
| @@ -79,8 +79,7 @@ bool CreateSessionAndGraphRunner() { | |||||
| std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) { | std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) { | ||||
| try { | try { | ||||
| auto anf_graph = | |||||
| lite::AnfConverter::RunAnfConverter(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||||
| auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||||
| return anf_graph; | return anf_graph; | ||||
| } catch (std::exception &e) { | } catch (std::exception &e) { | ||||
| MS_LOG(ERROR) << "Load MindIR failed."; | MS_LOG(ERROR) << "Load MindIR failed."; | ||||
| @@ -364,6 +363,7 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { | |||||
| Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { | Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { | ||||
| RegAllOp(); | RegAllOp(); | ||||
| Py_Initialize(); | |||||
| auto func_graph = ConvertMindIrToFuncGraph(model_data); | auto func_graph = ConvertMindIrToFuncGraph(model_data); | ||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; | MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include "utils/load_onnx/anf_converter.h" | |||||
| #include "load_mindir/load_model.h" | |||||
| #include "backend/session/session_basic.h" | #include "backend/session/session_basic.h" | ||||
| #include "backend/session/session_factory.h" | #include "backend/session/session_factory.h" | ||||
| #include "backend/session/executor_manager.h" | #include "backend/session/executor_manager.h" | ||||
| @@ -117,9 +117,9 @@ Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::m | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::shared_ptr<FuncGraph> anf_graph; | std::shared_ptr<FuncGraph> anf_graph; | ||||
| Py_Initialize(); | |||||
| try { | try { | ||||
| anf_graph = | |||||
| lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||||
| anf_graph = ConvertStreamToFuncGraph(static_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||||
| } catch (std::exception &e) { | } catch (std::exception &e) { | ||||
| MS_LOG(ERROR) << "Inference LoadModel failed"; | MS_LOG(ERROR) << "Inference LoadModel failed"; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -290,9 +290,10 @@ Status MsModel::FinalizeEnv() { | |||||
| } | } | ||||
| std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { | std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { | ||||
| Py_Initialize(); | |||||
| MS_EXCEPTION_IF_NULL(model_buf); | MS_EXCEPTION_IF_NULL(model_buf); | ||||
| try { | try { | ||||
| auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); | |||||
| auto anf_graph = ConvertStreamToFuncGraph(model_buf, size); | |||||
| return anf_graph; | return anf_graph; | ||||
| } catch (std::exception &e) { | } catch (std::exception &e) { | ||||
| MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what(); | MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what(); | ||||
| @@ -74,6 +74,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") | .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") | ||||
| .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), | .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), | ||||
| py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") | py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") | ||||
| .def("convert_funcgraph_to_mindir", &ExecutorPy::ConvertFuncGraphToMindIR, py::arg("graph"), | |||||
| "Convert FuncGraph to MindIR proto.") | |||||
| .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), | .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), | ||||
| py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") | py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") | ||||
| .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"), | .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"), | ||||
| @@ -108,6 +110,7 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); | (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); | ||||
| (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); | (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); | ||||
| (py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load MindIR as Graph."); | |||||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | ||||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | ||||
| @@ -45,6 +45,7 @@ | |||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| #include "pipeline/pynative/pynative_execute.h" | #include "pipeline/pynative/pynative_execute.h" | ||||
| #include "frontend/optimizer/py_pass_manager.h" | #include "frontend/optimizer/py_pass_manager.h" | ||||
| #include "load_mindir/load_model.h" | |||||
| #include "pybind_api/pybind_patch.h" | #include "pybind_api/pybind_patch.h" | ||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| #include "utils/info.h" | #include "utils/info.h" | ||||
| @@ -103,6 +104,16 @@ void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) { | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| py::bytes ExecutorPy::ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr) { | |||||
| std::string proto_str = GetBinaryProtoString(fg_ptr); | |||||
| if (proto_str.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Graph proto is empty."; | |||||
| } | |||||
| return proto_str; | |||||
| } | |||||
| FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::RunLoadMindIR(file_name); } | |||||
| py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) { | py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) { | ||||
| MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); | MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); | ||||
| abstract::AbstractBasePtrList args_spec; | abstract::AbstractBasePtrList args_spec; | ||||
| @@ -82,6 +82,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> { | |||||
| ResourcePtr GetResource(const std::string &phase); | ResourcePtr GetResource(const std::string &phase); | ||||
| FuncGraphPtr GetFuncGraph(const std::string &phase); | FuncGraphPtr GetFuncGraph(const std::string &phase); | ||||
| py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); | py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); | ||||
| py::bytes ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr); | |||||
| compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); | compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); | ||||
| bool HasCompiled(const std::string &phase) const; | bool HasCompiled(const std::string &phase) const; | ||||
| @@ -138,6 +139,7 @@ void ClearResAtexit(); | |||||
| void ReleaseGeTsd(); | void ReleaseGeTsd(); | ||||
| void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); | void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); | ||||
| FuncGraphPtr LoadMindIR(const std::string &file_name); | |||||
| // init and exec dataset sub graph | // init and exec dataset sub graph | ||||
| bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, | bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, | ||||
| @@ -0,0 +1,3 @@ | |||||
| file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||||
| set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) | |||||
| add_library(_mindspore_transform_express_ir_obj OBJECT ${_EXPORTER_IR_SRC_FILES}) | |||||
| @@ -25,33 +25,40 @@ | |||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "base/core_ops.h" | #include "base/core_ops.h" | ||||
| #include "proto/onnx.pb.h" | |||||
| #include "proto/mind_ir.pb.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| using FloatPtr = std::shared_ptr<Float>; | using FloatPtr = std::shared_ptr<Float>; | ||||
| using IntPtr = std::shared_ptr<Int>; | using IntPtr = std::shared_ptr<Int>; | ||||
| // anf type to onnx type map | |||||
| static std::unordered_map<int, onnx::TensorProto_DataType> g_data_type_map = { | |||||
| {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, | |||||
| {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, | |||||
| {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, | |||||
| {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, | |||||
| {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, | |||||
| {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, | |||||
| {kObjectTypeString, onnx::TensorProto_DataType_STRING}, | |||||
| // anf type to mindir type map | |||||
| static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = { | |||||
| {kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL}, | |||||
| {kNumberTypeInt8, mind_ir::TensorProto_DataType_INT8}, | |||||
| {kNumberTypeInt16, mind_ir::TensorProto_DataType_INT16}, | |||||
| {kNumberTypeInt32, mind_ir::TensorProto_DataType_INT32}, | |||||
| {kNumberTypeInt64, mind_ir::TensorProto_DataType_INT64}, | |||||
| {kNumberTypeUInt8, mind_ir::TensorProto_DataType_UINT8}, | |||||
| {kNumberTypeUInt16, mind_ir::TensorProto_DataType_UINT16}, | |||||
| {kNumberTypeUInt32, mind_ir::TensorProto_DataType_UINT32}, | |||||
| {kNumberTypeUInt64, mind_ir::TensorProto_DataType_UINT64}, | |||||
| {kNumberTypeFloat16, mind_ir::TensorProto_DataType_FLOAT16}, | |||||
| {kNumberTypeFloat32, mind_ir::TensorProto_DataType_FLOAT}, | |||||
| {kNumberTypeFloat64, mind_ir::TensorProto_DataType_DOUBLE}, | |||||
| {kObjectTypeString, mind_ir::TensorProto_DataType_STRING}, | |||||
| }; | }; | ||||
| static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_int_map = { | |||||
| {8, onnx::TensorProto_DataType_INT8}, | |||||
| {16, onnx::TensorProto_DataType_INT16}, | |||||
| {32, onnx::TensorProto_DataType_INT32}, | |||||
| {64, onnx::TensorProto_DataType_INT64}, | |||||
| static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_map = { | |||||
| {8, mind_ir::TensorProto_DataType_INT8}, | |||||
| {16, mind_ir::TensorProto_DataType_INT16}, | |||||
| {32, mind_ir::TensorProto_DataType_INT32}, | |||||
| {64, mind_ir::TensorProto_DataType_INT64}, | |||||
| }; | }; | ||||
| static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_float_map = { | |||||
| {16, onnx::TensorProto_DataType_FLOAT16}, | |||||
| {32, onnx::TensorProto_DataType_FLOAT}, | |||||
| static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = { | |||||
| {16, mind_ir::TensorProto_DataType_FLOAT16}, | |||||
| {32, mind_ir::TensorProto_DataType_FLOAT}, | |||||
| {64, mind_ir::TensorProto_DataType_FLOAT64}, | |||||
| }; | }; | ||||
| // Can build different builder according to format | // Can build different builder according to format | ||||
| @@ -77,34 +84,34 @@ class IrExportBuilder { | |||||
| void BuildModel(const FuncGraphPtr &func_graph); | void BuildModel(const FuncGraphPtr &func_graph); | ||||
| private: | private: | ||||
| void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); | |||||
| void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); | |||||
| void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); | |||||
| void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto); | |||||
| void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto); | |||||
| std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto); | |||||
| void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto); | |||||
| void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto); | |||||
| void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); | |||||
| void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); | |||||
| void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); | |||||
| void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); | |||||
| void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, | |||||
| void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); | |||||
| void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); | |||||
| void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); | |||||
| void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); | |||||
| void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); | |||||
| std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto); | |||||
| void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); | |||||
| void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto); | |||||
| void SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); | |||||
| void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto); | |||||
| void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto); | |||||
| void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto); | |||||
| void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string); | std::string *const seq_string); | ||||
| void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||||
| void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||||
| void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||||
| void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||||
| void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); | |||||
| void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||||
| void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||||
| void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||||
| void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||||
| void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||||
| void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string); | std::string *const seq_string); | ||||
| void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string); | std::string *const seq_string); | ||||
| onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | |||||
| onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); | |||||
| onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits); | |||||
| mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id); | |||||
| mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits); | |||||
| mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits); | |||||
| std::string GetNodeName(const AnfNodePtr &node); | std::string GetNodeName(const AnfNodePtr &node); | ||||
| std::string GetUniqueNodeName(const AnfNodePtr &node); | std::string GetUniqueNodeName(const AnfNodePtr &node); | ||||
| std::string GetOpTypeName(const AnfNodePtr &node); | std::string GetOpTypeName(const AnfNodePtr &node); | ||||
| @@ -114,8 +121,8 @@ class IrExportBuilder { | |||||
| void ResetTupleIndex() { shape_index_ = 0; } | void ResetTupleIndex() { shape_index_ = 0; } | ||||
| private: | private: | ||||
| onnx::ModelProto model_; | |||||
| onnx::NodeProto *last_node_{nullptr}; | |||||
| mind_ir::ModelProto model_; | |||||
| mind_ir::NodeProto *last_node_{nullptr}; | |||||
| std::list<FuncGraphPtr> todo_; | std::list<FuncGraphPtr> todo_; | ||||
| std::map<AnfNodePtr, size_t> node_index_map_; | std::map<AnfNodePtr, size_t> node_index_map_; | ||||
| size_t node_index_{0}; | size_t node_index_{0}; | ||||
| @@ -144,13 +151,13 @@ std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| void IrExportBuilder::BuildModelInfo() { | void IrExportBuilder::BuildModelInfo() { | ||||
| model_.set_ir_version(onnx::IR_VERSION_2019_1_22); | |||||
| model_.set_ir_version("0.1.0"); | |||||
| model_.set_producer_name("MindSpore"); | model_.set_producer_name("MindSpore"); | ||||
| model_.set_model_version(1); | |||||
| model_.set_model_version("1.1.0"); | |||||
| } | } | ||||
| void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { | void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { | ||||
| onnx::GraphProto *graph_proto = model_.mutable_graph(); | |||||
| mind_ir::GraphProto *graph_proto = model_.mutable_graph(); | |||||
| graph_proto->set_name(func_graph->ToString()); | graph_proto->set_name(func_graph->ToString()); | ||||
| ResetNodeIndex(); | ResetNodeIndex(); | ||||
| todo_.clear(); | todo_.clear(); | ||||
| @@ -162,7 +169,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||||
| void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { | |||||
| // Export parameters | // Export parameters | ||||
| // 1. parameters should be mapped to ValueInfoProto | // 1. parameters should be mapped to ValueInfoProto | ||||
| // 2. parameters with default value should be mapped to Initializer | // 2. parameters with default value should be mapped to Initializer | ||||
| @@ -172,33 +179,31 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::Graph | |||||
| BuildNodes(func_graph, graph_proto); | BuildNodes(func_graph, graph_proto); | ||||
| } | } | ||||
| void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||||
| void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { | |||||
| for (auto &item : func_graph->parameters()) { | for (auto &item : func_graph->parameters()) { | ||||
| auto param = item->cast<ParameterPtr>(); | auto param = item->cast<ParameterPtr>(); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; | MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; | ||||
| } | } | ||||
| onnx::ValueInfoProto *input_proto = graph_proto->add_input(); | |||||
| std::string param_name = GetUniqueNodeName(param); | std::string param_name = GetUniqueNodeName(param); | ||||
| input_proto->set_name(param_name); | |||||
| SetValueInfoProto(param, input_proto); | |||||
| if (!param->has_default()) { | |||||
| if (param->has_default()) { | |||||
| MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; | MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; | ||||
| continue; | |||||
| } | |||||
| // Using ONNX initializer to set parameter's default value | |||||
| onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); | |||||
| initializer_proto->set_name(param_name); | |||||
| SetParamToTensorProto(param, initializer_proto); | |||||
| auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()); | |||||
| if (tensor) { | |||||
| initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); | |||||
| mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter(); | |||||
| parameter_proto->set_name(param_name); | |||||
| SetParamToTensorProto(param, parameter_proto); | |||||
| auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()); | |||||
| if (tensor) { | |||||
| parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); | |||||
| } | |||||
| } else { | |||||
| mind_ir::ValueInfoProto *input_proto = graph_proto->add_input(); | |||||
| input_proto->set_name(param_name); | |||||
| SetValueInfoProto(param, input_proto); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { | |||||
| mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) { | |||||
| auto iter = g_data_type_map.find(type_id); | auto iter = g_data_type_map.find(type_id); | ||||
| if (iter == g_data_type_map.end()) { | if (iter == g_data_type_map.end()) { | ||||
| MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; | MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; | ||||
| @@ -206,7 +211,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { | |||||
| mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) { | |||||
| auto iter = g_data_bits_int_map.find(bits); | auto iter = g_data_bits_int_map.find(bits); | ||||
| if (iter == g_data_bits_int_map.end()) { | if (iter == g_data_bits_int_map.end()) { | ||||
| MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; | MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; | ||||
| @@ -214,7 +219,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { | |||||
| mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) { | |||||
| auto iter = g_data_bits_float_map.find(bits); | auto iter = g_data_bits_float_map.find(bits); | ||||
| if (iter == g_data_bits_float_map.end()) { | if (iter == g_data_bits_float_map.end()) { | ||||
| MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; | MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; | ||||
| @@ -222,73 +227,70 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) { | |||||
| void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) { | |||||
| if (node == nullptr || value_proto == nullptr) { | if (node == nullptr || value_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; | MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); | MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); | ||||
| SetValueInfoProto(node->Type(), node->Shape(), value_proto); | |||||
| } | |||||
| void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, | |||||
| onnx::ValueInfoProto *const value_proto) { | |||||
| onnx::TypeProto *type_proto = value_proto->mutable_type(); | |||||
| const TypePtr &type = node->Type(); | |||||
| const BaseShapePtr &shape = node->Shape(); | |||||
| if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) { | if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) { | ||||
| auto tensor = type->cast<TensorTypePtr>(); | auto tensor = type->cast<TensorTypePtr>(); | ||||
| auto elem_type = tensor->element(); | auto elem_type = tensor->element(); | ||||
| const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | ||||
| type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); | |||||
| mind_ir::TensorProto *tensor_proto = value_proto->add_tensor(); | |||||
| tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id())); | |||||
| if (dims.size() == 0) { | if (dims.size() == 0) { | ||||
| MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; | MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; | ||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); | |||||
| tensor_proto->add_dims(1); | |||||
| } else { | } else { | ||||
| for (const auto &dim : dims) { | for (const auto &dim : dims) { | ||||
| MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; | MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; | ||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | |||||
| tensor_proto->add_dims(dim); | |||||
| } | } | ||||
| } | } | ||||
| } else if (type->isa<Tuple>()) { | } else if (type->isa<Tuple>()) { | ||||
| auto tup_shape = shape->cast<abstract::TupleShapePtr>(); | auto tup_shape = shape->cast<abstract::TupleShapePtr>(); | ||||
| type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); | |||||
| value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); | |||||
| } else if (type->isa<Number>() || type->isa<String>()) { | } else if (type->isa<Number>() || type->isa<String>()) { | ||||
| type_proto->set_denotation(type->type_name()); | |||||
| value_proto->set_denotation(type->type_name()); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; | MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; | ||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { | |||||
| void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| attr_proto->set_ref_attr_name("tensor:value0"); | attr_proto->set_ref_attr_name("tensor:value0"); | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); | |||||
| mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| tensor_proto->set_name("value0"); | tensor_proto->set_name("value0"); | ||||
| auto data = value->cast<tensor::TensorPtr>(); | auto data = value->cast<tensor::TensorPtr>(); | ||||
| tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | ||||
| auto dtype = data->data_type(); | auto dtype = data->data_type(); | ||||
| auto shape = data->shape_c(); | auto shape = data->shape_c(); | ||||
| tensor_proto->set_data_type(GetOnnxDataType(dtype)); | |||||
| tensor_proto->set_data_type(GetMindirDataType(dtype)); | |||||
| for (const auto &dim : shape) { | for (const auto &dim : shape) { | ||||
| tensor_proto->add_dims(dim); | tensor_proto->add_dims(dim); | ||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, | void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, | ||||
| onnx::TensorProto *const tensor_proto) { | |||||
| mind_ir::TensorProto *const tensor_proto) { | |||||
| if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | ||||
| MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); | MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); | ||||
| } | } | ||||
| auto tensor = type->cast<TensorTypePtr>(); | auto tensor = type->cast<TensorTypePtr>(); | ||||
| const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | ||||
| tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id())); | |||||
| tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id())); | |||||
| for (const auto &dim : dims) { | for (const auto &dim : dims) { | ||||
| tensor_proto->add_dims(dim); | tensor_proto->add_dims(dim); | ||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { | |||||
| void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) { | |||||
| if (param == nullptr || tensor_proto == nullptr) { | if (param == nullptr || tensor_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; | MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; | ||||
| } | } | ||||
| @@ -296,7 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten | |||||
| SetTensorProto(param->Type(), param->Shape(), tensor_proto); | SetTensorProto(param->Type(), param->Shape(), tensor_proto); | ||||
| } | } | ||||
| void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||||
| void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { | |||||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | ||||
| bool is_only_return = true; | bool is_only_return = true; | ||||
| for (const AnfNodePtr &node : nodes) { | for (const AnfNodePtr &node : nodes) { | ||||
| @@ -317,12 +319,12 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt | |||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) { | |||||
| void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { | |||||
| if (node->size() != 2) { | if (node->size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | ||||
| } | } | ||||
| AnfNodePtr arg = node->input(1); | AnfNodePtr arg = node->input(1); | ||||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||||
| mind_ir::ValueInfoProto *output_proto = graph_proto->add_output(); | |||||
| std::string output_name = GetUniqueNodeName(node); | std::string output_name = GetUniqueNodeName(node); | ||||
| output_proto->set_name(output_name); | output_proto->set_name(output_name); | ||||
| last_node_->set_output(0, output_name); | last_node_->set_output(0, output_name); | ||||
| @@ -349,7 +351,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { | |||||
| } | } | ||||
| void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, | void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, | ||||
| onnx::AttributeProto *const attr_proto, std::string *const seq_string) { | |||||
| mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { | |||||
| if (type->isa<Tuple>() && seq_string != nullptr) { | if (type->isa<Tuple>() && seq_string != nullptr) { | ||||
| *seq_string += "Tuple["; | *seq_string += "Tuple["; | ||||
| auto elements = type->cast<TuplePtr>()->elements(); | auto elements = type->cast<TuplePtr>()->elements(); | ||||
| @@ -361,7 +363,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt | |||||
| } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) { | } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) { | ||||
| string shape_name = "shape" + std::to_string(GetTupleIndex()); | string shape_name = "shape" + std::to_string(GetTupleIndex()); | ||||
| *seq_string += shape_name + ","; | *seq_string += shape_name + ","; | ||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| tensor_proto->set_name(shape_name); | tensor_proto->set_name(shape_name); | ||||
| SetTensorProto(type, shape, tensor_proto); | SetTensorProto(type, shape, tensor_proto); | ||||
| } else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) { | } else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) { | ||||
| @@ -371,7 +373,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt | |||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { | |||||
| void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) { | |||||
| // Get shape of cnode | // Get shape of cnode | ||||
| // 1. need to get shape from tuple element | // 1. need to get shape from tuple element | ||||
| // 2. save shape in TensorProto | // 2. save shape in TensorProto | ||||
| @@ -381,13 +383,13 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto | |||||
| auto shape = node->Shape(); | auto shape = node->Shape(); | ||||
| ResetTupleIndex(); | ResetTupleIndex(); | ||||
| std::string seq_string = "shape:"; | std::string seq_string = "shape:"; | ||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| SetShapeToNodeProto(type, shape, attr_proto, &seq_string); | SetShapeToNodeProto(type, shape, attr_proto, &seq_string); | ||||
| attr_proto->set_ref_attr_name(seq_string); | attr_proto->set_ref_attr_name(seq_string); | ||||
| MS_LOG(DEBUG) << "CNode shape: " << seq_string; | MS_LOG(DEBUG) << "CNode shape: " << seq_string; | ||||
| } | } | ||||
| void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { | |||||
| void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { | |||||
| auto inputs_size = node->size(); | auto inputs_size = node->size(); | ||||
| if (inputs_size < 1) { | if (inputs_size < 1) { | ||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | ||||
| @@ -403,7 +405,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g | |||||
| } | } | ||||
| // Build cnode | // Build cnode | ||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| mind_ir::NodeProto *node_proto = graph_proto->add_node(); | |||||
| std::string output_name = GetUniqueNodeName(node); | std::string output_name = GetUniqueNodeName(node); | ||||
| node_proto->add_output(output_name); | node_proto->add_output(output_name); | ||||
| node_proto->set_name(output_name); | node_proto->set_name(output_name); | ||||
| @@ -421,7 +423,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g | |||||
| auto prim = GetValueNode<PrimitivePtr>(op); | auto prim = GetValueNode<PrimitivePtr>(op); | ||||
| for (auto attr : prim->attrs()) { | for (auto attr : prim->attrs()) { | ||||
| MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); | MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); | ||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name(attr.first); | attr_proto->set_name(attr.first); | ||||
| SetValueToAttributeProto(attr.second, attr_proto); | SetValueToAttributeProto(attr.second, attr_proto); | ||||
| } | } | ||||
| @@ -430,11 +432,11 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g | |||||
| } | } | ||||
| } | } | ||||
| std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) { | |||||
| std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) { | |||||
| std::string node_name = GetUniqueNodeName(node); | std::string node_name = GetUniqueNodeName(node); | ||||
| if (node->isa<ValueNode>()) { | if (node->isa<ValueNode>()) { | ||||
| // When node input is a ValueNode, need to create a Constant Node | // When node input is a ValueNode, need to create a Constant Node | ||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| mind_ir::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->add_output(node_name); | node_proto->add_output(node_name); | ||||
| SetAttributeProto(node, node_proto); | SetAttributeProto(node, node_proto); | ||||
| } | } | ||||
| @@ -478,44 +480,48 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { | |||||
| return node_name; | return node_name; | ||||
| } | } | ||||
| void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) { | |||||
| void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) { | |||||
| if (node == nullptr || node_proto == nullptr) { | if (node == nullptr || node_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; | MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; | ||||
| } | } | ||||
| auto value = node->cast<ValueNodePtr>()->value(); | auto value = node->cast<ValueNodePtr>()->value(); | ||||
| node_proto->set_op_type("Constant"); | node_proto->set_op_type("Constant"); | ||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name("value"); | attr_proto->set_name("value"); | ||||
| MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); | MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); | ||||
| SetValueToAttributeProto(value, attr_proto); | SetValueToAttributeProto(value, attr_proto); | ||||
| } | } | ||||
| void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { | |||||
| void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); | |||||
| mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| if (value->isa<Int>()) { | if (value->isa<Int>()) { | ||||
| attr_proto->set_ref_attr_name("type:value0"); | attr_proto->set_ref_attr_name("type:value0"); | ||||
| tensor_proto->set_name("value0"); | tensor_proto->set_name("value0"); | ||||
| auto int_value = value->cast<IntPtr>(); | auto int_value = value->cast<IntPtr>(); | ||||
| tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); | |||||
| tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); | |||||
| } else if (value->isa<Float>()) { | } else if (value->isa<Float>()) { | ||||
| attr_proto->set_ref_attr_name("type:value0"); | attr_proto->set_ref_attr_name("type:value0"); | ||||
| tensor_proto->set_name("value0"); | tensor_proto->set_name("value0"); | ||||
| auto float_value = value->cast<FloatPtr>(); | auto float_value = value->cast<FloatPtr>(); | ||||
| tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); | |||||
| tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); | |||||
| } else if (value->isa<Bool>()) { | |||||
| attr_proto->set_ref_attr_name("type:value0"); | |||||
| tensor_proto->set_name("value0"); | |||||
| tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL); | |||||
| } else if (value->isa<TensorType>()) { | } else if (value->isa<TensorType>()) { | ||||
| attr_proto->set_ref_attr_name("type:tensor0"); | attr_proto->set_ref_attr_name("type:tensor0"); | ||||
| tensor_proto->set_name("tensor0"); | tensor_proto->set_name("tensor0"); | ||||
| auto elem_type = value->cast<TensorTypePtr>()->element(); | auto elem_type = value->cast<TensorTypePtr>()->element(); | ||||
| if (elem_type->isa<Int>()) { | if (elem_type->isa<Int>()) { | ||||
| auto int_value = elem_type->cast<IntPtr>(); | auto int_value = elem_type->cast<IntPtr>(); | ||||
| tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); | |||||
| tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); | |||||
| } else if (elem_type->isa<Float>()) { | } else if (elem_type->isa<Float>()) { | ||||
| auto float_value = elem_type->cast<FloatPtr>(); | auto float_value = elem_type->cast<FloatPtr>(); | ||||
| tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); | |||||
| tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); | MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); | ||||
| } | } | ||||
| @@ -524,18 +530,18 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri | |||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { | |||||
| void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| if (value->isa<StringImm>() || value->isa<Scalar>()) { | if (value->isa<StringImm>() || value->isa<Scalar>()) { | ||||
| SetScalarToAttributeProto(value, attr_proto); | |||||
| SetScalarToAttributeProto_ir(value, attr_proto); | |||||
| } else if (value->isa<Number>() || value->isa<TensorType>()) { | } else if (value->isa<Number>() || value->isa<TensorType>()) { | ||||
| SetTypeToAttributeProto(value, attr_proto); | SetTypeToAttributeProto(value, attr_proto); | ||||
| } else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) { | |||||
| } else if (value->isa<ValueSequeue>()) { | |||||
| ResetTupleIndex(); | ResetTupleIndex(); | ||||
| std::string seq_string = "scalar:"; | std::string seq_string = "scalar:"; | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); | |||||
| SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string); | SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string); | ||||
| attr_proto->set_ref_attr_name(seq_string); | attr_proto->set_ref_attr_name(seq_string); | ||||
| MS_LOG(DEBUG) << "Attr string: " << seq_string; | MS_LOG(DEBUG) << "Attr string: " << seq_string; | ||||
| @@ -549,74 +555,102 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr | |||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { | |||||
| if (value == nullptr || attr_proto == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | |||||
| } | |||||
| void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { | |||||
| attr_proto->set_ref_attr_name("scalar:value0"); | attr_proto->set_ref_attr_name("scalar:value0"); | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| SetScalarToProto(value, tensor_proto, "value0"); | |||||
| if (value->isa<StringImm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING); | |||||
| attr_proto->set_s(GetValue<std::string>(value)); | |||||
| } else if (value->isa<BoolImm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); | |||||
| attr_proto->set_i(GetValue<bool>(value)); | |||||
| } else if (value->isa<Int8Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8); | |||||
| attr_proto->set_i(value->cast<Int8ImmPtr>()->value()); | |||||
| } else if (value->isa<Int16Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16); | |||||
| attr_proto->set_i(value->cast<Int16ImmPtr>()->value()); | |||||
| } else if (value->isa<Int32Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32); | |||||
| attr_proto->set_i(value->cast<Int32ImmPtr>()->value()); | |||||
| } else if (value->isa<Int64Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64); | |||||
| attr_proto->set_i(value->cast<Int64ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt8Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8); | |||||
| attr_proto->set_i(value->cast<UInt8ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt16Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16); | |||||
| attr_proto->set_i(value->cast<UInt16ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt32Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32); | |||||
| attr_proto->set_i(value->cast<UInt32ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt64Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64); | |||||
| attr_proto->set_i(value->cast<UInt64ImmPtr>()->value()); | |||||
| } else if (value->isa<FP32Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); | |||||
| attr_proto->set_f(GetValue<float>(value)); | |||||
| } else if (value->isa<FP64Imm>()) { | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); | |||||
| attr_proto->set_d(GetValue<double>(value)); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); | |||||
| } | |||||
| } | } | ||||
| void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, | |||||
| const std::string &value_name) { | |||||
| if (value == nullptr || tensor_proto == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; | |||||
| } | |||||
| tensor_proto->set_name(value_name); | |||||
| void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { | |||||
| if (value->isa<StringImm>()) { | if (value->isa<StringImm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); | |||||
| tensor_proto->add_string_data(GetValue<std::string>(value)); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING); | |||||
| attr_proto->add_strings(GetValue<std::string>(value)); | |||||
| } else if (value->isa<BoolImm>()) { | } else if (value->isa<BoolImm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); | |||||
| tensor_proto->add_int32_data(GetValue<bool>(value)); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); | |||||
| attr_proto->add_ints(GetValue<bool>(value)); | |||||
| } else if (value->isa<Int8Imm>()) { | } else if (value->isa<Int8Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); | |||||
| tensor_proto->add_int32_data(value->cast<Int8ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8); | |||||
| attr_proto->add_ints(value->cast<Int8ImmPtr>()->value()); | |||||
| } else if (value->isa<Int16Imm>()) { | } else if (value->isa<Int16Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); | |||||
| tensor_proto->add_int32_data(value->cast<Int16ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16); | |||||
| attr_proto->add_ints(value->cast<Int16ImmPtr>()->value()); | |||||
| } else if (value->isa<Int32Imm>()) { | } else if (value->isa<Int32Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); | |||||
| tensor_proto->add_int32_data(value->cast<Int32ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32); | |||||
| attr_proto->add_ints(value->cast<Int32ImmPtr>()->value()); | |||||
| } else if (value->isa<Int64Imm>()) { | } else if (value->isa<Int64Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); | |||||
| tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64); | |||||
| attr_proto->add_ints(value->cast<Int64ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt8Imm>()) { | } else if (value->isa<UInt8Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); | |||||
| tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8); | |||||
| attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt16Imm>()) { | } else if (value->isa<UInt16Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); | |||||
| tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16); | |||||
| attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt32Imm>()) { | } else if (value->isa<UInt32Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); | |||||
| tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32); | |||||
| attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt64Imm>()) { | } else if (value->isa<UInt64Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); | |||||
| tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value()); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64); | |||||
| attr_proto->add_ints(value->cast<UInt64ImmPtr>()->value()); | |||||
| } else if (value->isa<FP32Imm>()) { | } else if (value->isa<FP32Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); | |||||
| tensor_proto->add_float_data(GetValue<float>(value)); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); | |||||
| attr_proto->add_floats(GetValue<float>(value)); | |||||
| } else if (value->isa<FP64Imm>()) { | } else if (value->isa<FP64Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); | |||||
| tensor_proto->add_double_data(GetValue<double>(value)); | |||||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); | |||||
| attr_proto->add_doubles(GetValue<double>(value)); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); | MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); | ||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string) { | std::string *const seq_string) { | ||||
| string value_name = "value" + std::to_string(GetTupleIndex()); | string value_name = "value" + std::to_string(GetTupleIndex()); | ||||
| if (seq_string != nullptr) { | if (seq_string != nullptr) { | ||||
| *seq_string += value_name + ","; | *seq_string += value_name + ","; | ||||
| } | } | ||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| SetScalarToProto(value, tensor_proto, value_name); | |||||
| SetScalarToAttributeProto_irs(value, attr_proto); | |||||
| } | } | ||||
| void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, | |||||
| mind_ir::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string) { | std::string *const seq_string) { | ||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; | ||||
| @@ -625,6 +659,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, | |||||
| *seq_string += "Tuple["; | *seq_string += "Tuple["; | ||||
| const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>(); | const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>(); | ||||
| if (tuple_value->value().size() == 0) { | if (tuple_value->value().size() == 0) { | ||||
| *seq_string += "],"; | |||||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; | MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -640,6 +675,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, | |||||
| *seq_string += "List["; | *seq_string += "List["; | ||||
| const ValueListPtr &list_value = value->cast<ValueListPtr>(); | const ValueListPtr &list_value = value->cast<ValueListPtr>(); | ||||
| if (list_value->value().size() == 0) { | if (list_value->value().size() == 0) { | ||||
| *seq_string += "],"; | |||||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; | MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -1,3 +0,0 @@ | |||||
| file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||||
| set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) | |||||
| add_library(_mindspore_transform_onnx_obj OBJECT ${_ONNX_SRC_FILES}) | |||||
| @@ -5,11 +5,5 @@ if (NOT ENABLE_GE) | |||||
| list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES}) | list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES}) | ||||
| endif () | endif () | ||||
| file(GLOB_RECURSE _UTILS_LITE_SRC_FILES | |||||
| ./load_onnx/anf_converter.cc | |||||
| ./load_onnx/anf_model_parser.cc | |||||
| ) | |||||
| list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_LITE_SRC_FILES}) | |||||
| set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS) | set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS) | ||||
| add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST}) | add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST}) | ||||
| @@ -1,134 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/load_onnx/anf_converter.h" | |||||
| #include <fcntl.h> | |||||
| #include <fstream> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "utils/load_onnx/anf_model_parser.h" | |||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| #include "proto/onnx.pb.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| const char WHITESPACE[] = "\t\n\v\f\r "; | |||||
| const int FLAG_PREFIX_LEN = 2; | |||||
| void AnfConverter::Trim(std::string *input) { | |||||
| if (input == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (input->empty()) { | |||||
| return; | |||||
| } | |||||
| input->erase(0, input->find_first_not_of(WHITESPACE)); | |||||
| input->erase(input->find_last_not_of(WHITESPACE) + 1); | |||||
| } | |||||
| int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string fileType) { | |||||
| if (modelFile.size() > fileType.size()) { | |||||
| if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) { | |||||
| return 0; | |||||
| } else { | |||||
| return 1; | |||||
| } | |||||
| } else { | |||||
| return 1; | |||||
| } | |||||
| } | |||||
| bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { | |||||
| std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | |||||
| if (modelFile.size() > PATH_MAX) { | |||||
| MS_LOG(DEBUG) << "file path " << modelFile << " is too long."; | |||||
| return false; | |||||
| } | |||||
| char real_path[PATH_MAX + 1] = {0}; | |||||
| #if defined(_WIN32) || defined(_WIN64) | |||||
| if (nullptr == _fullpath(real_path, modelFile.c_str(), PATH_MAX)) { | |||||
| MS_LOG(DEBUG) << modelFile << " does not exit."; | |||||
| return false; | |||||
| } | |||||
| #else | |||||
| if (nullptr == realpath(modelFile.c_str(), real_path)) { | |||||
| MS_LOG(DEBUG) << modelFile << " does not exit."; | |||||
| return false; | |||||
| } | |||||
| #endif | |||||
| int fd = open(real_path, O_RDONLY); | |||||
| if (fd < 0) { | |||||
| MS_LOG(EXCEPTION) << "failed to open file"; | |||||
| } | |||||
| google::protobuf::io::FileInputStream input(fd); | |||||
| google::protobuf::io::CodedInputStream code_input(&input); | |||||
| code_input.SetTotalBytesLimit(INT_MAX, 536870912); | |||||
| bool ret = onnx_model->ParseFromCodedStream(&code_input); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "load onnx file failed"; | |||||
| return false; | |||||
| } | |||||
| (void)close(fd); | |||||
| MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; | |||||
| return true; | |||||
| } | |||||
| std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file_path) { | |||||
| std::string modelFile; | |||||
| std::string tmp = file_path; | |||||
| Trim(&tmp); | |||||
| const std::string flagItem(tmp); | |||||
| size_t pos = flagItem.find_first_of("="); | |||||
| if (pos == std::string::npos) { | |||||
| MS_LOG(ERROR) << "Trans data not support input format!"; | |||||
| } else { | |||||
| modelFile = flagItem.substr(pos + 1); | |||||
| std::cout << "input protobuf file path is: " << modelFile << std::endl; | |||||
| } | |||||
| if (ValidateFileStr(modelFile, ".pb") != 0) { | |||||
| MS_LOG(EXCEPTION) << "INPUT ILLEGAL: modelFile must be *.pb"; | |||||
| } | |||||
| onnx::ModelProto model_; | |||||
| ReadOnnxFromBinary(modelFile, &model_); | |||||
| MSANFModelParser model_parser; | |||||
| FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); | |||||
| return dstgraph_ptr; | |||||
| } | |||||
| std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const size_t buf_size) { | |||||
| Py_Initialize(); | |||||
| MS_EXCEPTION_IF_NULL(buf); | |||||
| std::string str((const char *)buf, buf_size); | |||||
| onnx::ModelProto model_; | |||||
| if (!model_.ParseFromString(str)) { | |||||
| MS_LOG(EXCEPTION) << "Parse model from buffer fail!"; | |||||
| } | |||||
| MSANFModelParser model_parser; | |||||
| FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); | |||||
| return dstgraph_ptr; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,693 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/load_onnx/anf_model_parser.h" | |||||
| #include <functional> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <stack> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "ir/param_info.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "proto/onnx.pb.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "utils/shape_utils.h" | |||||
| using std::string; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| static constexpr char kConstantValueNode[] = "Constant"; | |||||
| static constexpr char kCNodeShapeAttr[] = "shape"; | |||||
| static constexpr char kCNodeShape1Attr[] = "shape1"; | |||||
| static constexpr char kCNodeShape2Attr[] = "shape2"; | |||||
| enum ParseForm : int { | |||||
| FORM_PARSE_TYPE = 0, | |||||
| FORM_PARSE_SCALAR = 1, | |||||
| FORM_PARSE_TENSOR = 2, | |||||
| }; | |||||
| static std::map<std::string, ParseForm> kParseTypeSwitchMap{ | |||||
| {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; | |||||
| static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | |||||
| {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||||
| {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||||
| {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||||
| {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||||
| {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||||
| {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||||
| {onnx::TensorProto_DataType_STRING, kObjectTypeString}, | |||||
| }; | |||||
| template <typename T, typename P> | |||||
| std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) { | |||||
| std::stack<std::string> rules; | |||||
| std::stack<P> value; | |||||
| int count = 0; | |||||
| for (size_t i = 0; i < str.length(); i++) { | |||||
| if (str[i] == '[') { | |||||
| rules.push("["); | |||||
| } else if (str[i] == ']') { | |||||
| // rules | |||||
| std::vector<P> vec; | |||||
| while (rules.top() != "[") { | |||||
| rules.pop(); | |||||
| vec.push_back(value.top()); | |||||
| value.pop(); | |||||
| } | |||||
| // pop "[" | |||||
| rules.pop(); | |||||
| // make tuple for names | |||||
| std::string res = "dummy"; | |||||
| // make tuple for values | |||||
| reverse(vec.begin(), vec.end()); | |||||
| auto vt = std::make_shared<T>(vec); | |||||
| if (rules.empty() && value.empty()) { | |||||
| return vt; | |||||
| } | |||||
| rules.push(res); | |||||
| value.push(vt); | |||||
| } else if (str[i] == ',') { | |||||
| continue; | |||||
| } else { | |||||
| count++; | |||||
| if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { | |||||
| auto value_name = str.substr(i - count + 1, count); | |||||
| value.push(kv.at(value_name)); | |||||
| rules.push(value_name); | |||||
| count = 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name, | |||||
| const std::unordered_map<string, ValuePtr> &kv) { | |||||
| std::string str = attr_name; | |||||
| auto replace = [&](const string &orgStr, const string &newStr) { | |||||
| std::string::size_type pos(0); | |||||
| while ((pos = str.find(orgStr)) != std::string::npos) { | |||||
| str.replace(pos, orgStr.length(), newStr); | |||||
| } | |||||
| return str; | |||||
| }; | |||||
| // remove "scalar:" | |||||
| str = replace("scalar:", ""); | |||||
| // remove "Tuple" | |||||
| str = replace("Tuple", ""); | |||||
| // remove "List" | |||||
| str = replace("List", ""); | |||||
| auto result = ParserAttr<ValueTuple>(str, kv); | |||||
| if (!result) { | |||||
| return {}; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| std::shared_ptr<abstract::AbstractTuple> ParserAttrShape( | |||||
| const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) { | |||||
| std::string str = attr_name; | |||||
| auto replace = [&](const string &orgStr, const string &newStr) { | |||||
| std::string::size_type pos(0); | |||||
| while ((pos = str.find(orgStr)) != std::string::npos) { | |||||
| str.replace(pos, orgStr.length(), newStr); | |||||
| } | |||||
| return str; | |||||
| }; | |||||
| // remove "scalar:" | |||||
| str = replace("shape:", ""); | |||||
| // remove "Tuple" | |||||
| str = replace("Tuple", ""); | |||||
| // remove "List" | |||||
| str = replace("List", ""); | |||||
| auto result = ParserAttr<abstract::AbstractTuple>(str, kv); | |||||
| if (!result) { | |||||
| return {}; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | |||||
| bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!value_proto.has_type() || !value_proto.has_name()) { | |||||
| MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | |||||
| return false; | |||||
| } | |||||
| node->set_name(value_proto.name()); | |||||
| const auto &type_proto = value_proto.type(); | |||||
| if (!type_proto.has_tensor_type()) { | |||||
| MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; | |||||
| return false; | |||||
| } | |||||
| const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); | |||||
| if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { | |||||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; | |||||
| return false; | |||||
| } | |||||
| const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); | |||||
| ShapeVector shape; | |||||
| for (int i = 0; i < tensor_shape.dim_size(); ++i) { | |||||
| shape.push_back(tensor_shape.dim(i).dim_value()); | |||||
| } | |||||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | |||||
| return false; | |||||
| } | |||||
| tensor::TensorPtr tensor_info = | |||||
| std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||||
| MS_EXCEPTION_IF_NULL(tensor_info); | |||||
| auto tensor_abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_abstract); | |||||
| node->set_abstract(tensor_abstract); | |||||
| if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | |||||
| const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; | |||||
| std::string initial_data = initialize_proto.raw_data(); | |||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data_buf); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||||
| } | |||||
| node->set_default_param(tensor_info); | |||||
| } | |||||
| anfnode_build_map_[value_proto.name()] = node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); | |||||
| for (int i = 0; i < importProto.initializer_size(); ++i) { | |||||
| const onnx::TensorProto &initializer_proto = importProto.initializer(i); | |||||
| if (!initializer_proto.has_name()) { | |||||
| MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; | |||||
| return false; | |||||
| } | |||||
| default_para_map_[initializer_proto.name()] = initializer_proto; | |||||
| } | |||||
| MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | |||||
| for (int i = 0; i < importProto.input_size(); ++i) { | |||||
| const onnx::ValueInfoProto &input_proto = importProto.input(i); | |||||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { | |||||
| MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; | |||||
| return false; | |||||
| } | |||||
| prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| return true; | |||||
| } | |||||
| ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| switch (attr_tensor_type) { | |||||
| case onnx::TensorProto_DataType_STRING: { | |||||
| return ParseAttrInScalar_string_string(attr_tensor); | |||||
| } | |||||
| case onnx::TensorProto_DataType_INT32: { | |||||
| return ParseAttrInScalar_int32_int32(attr_tensor); | |||||
| } | |||||
| case onnx::TensorProto_DataType_INT64: { | |||||
| return ParseAttrInScalar_int64_int64(attr_tensor); | |||||
| } | |||||
| case onnx::TensorProto_DataType_UINT64: { | |||||
| return ParseAttrInScalar_uint64_uint64(attr_tensor); | |||||
| } | |||||
| case onnx::TensorProto_DataType_FLOAT: { | |||||
| return ParseAttrInScalar_float_float(attr_tensor); | |||||
| } | |||||
| case onnx::TensorProto_DataType_DOUBLE: { | |||||
| return ParseAttrInScalar_double_double(attr_tensor); | |||||
| } | |||||
| case onnx::TensorProto_DataType_BOOL: { | |||||
| return ParseAttrInScalar_int32_bool(attr_tensor); | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||||
| return {}; | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; | |||||
| return false; | |||||
| } | |||||
| bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| const std::string &attr_name = attr_proto.name(); | |||||
| if (!attr_proto.has_ref_attr_name()) { | |||||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| const std::string &ref_attr_name = attr_proto.ref_attr_name(); | |||||
| string type; | |||||
| std::size_t pos(0); | |||||
| if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("scalar:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("type:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("tensor:").length() - 1); | |||||
| } | |||||
| std::unordered_map<std::string, ValuePtr> kv; | |||||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | |||||
| const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); | |||||
| switch (kParseTypeSwitchMap[type]) { | |||||
| case FORM_PARSE_TYPE: { | |||||
| ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_SCALAR: { | |||||
| auto res = ObtainCNodeAttrInScalarForm(attr_tensor); | |||||
| kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res)); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_TENSOR: { | |||||
| ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { | |||||
| if (kv.size() == 1) { | |||||
| auto iter = kv.begin(); | |||||
| prim->AddAttr(attr_name, iter->second); | |||||
| } else { | |||||
| auto res = ParserScalarAttrValue(ref_attr_name, kv); | |||||
| prim->AddAttr(attr_name, res); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| ShapeVector shape; | |||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||||
| shape.push_back(attr_tensor.dims(i)); | |||||
| } | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||||
| } | |||||
| auto new_value_node = NewValueNode(MakeValue(tensor_info)); | |||||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||||
| auto tensor_abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_abstract); | |||||
| new_value_node->set_abstract(tensor_abstract); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainValueNodeInScalarForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| ValuePtr value_ptr = nullptr; | |||||
| switch (attr_tensor_type) { | |||||
| case onnx::TensorProto_DataType_INT32: { | |||||
| std::vector<int64_t> add_data; | |||||
| for (int i = 0; i < attr_tensor.int32_data_size(); ++i) { | |||||
| add_data.push_back(attr_tensor.int32_data(i)); | |||||
| } | |||||
| if (add_data.size() == 1) { | |||||
| value_ptr = MakeValue(add_data[0]); | |||||
| } else if (!add_data.empty()) { | |||||
| value_ptr = MakeValue<std::vector<int64_t>>(add_data); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case onnx::TensorProto_DataType_FLOAT: { | |||||
| std::vector<float> add_data; | |||||
| for (int i = 0; i < attr_tensor.float_data_size(); ++i) { | |||||
| add_data.push_back(attr_tensor.float_data(i)); | |||||
| } | |||||
| if (add_data.size() == 1) { | |||||
| value_ptr = MakeValue(add_data[0]); | |||||
| } else if (!add_data.empty()) { | |||||
| value_ptr = MakeValue<std::vector<float>>(add_data); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case onnx::TensorProto_DataType_UNDEFINED: { | |||||
| std::vector<ValuePtr> elems; | |||||
| value_ptr = std::make_shared<ValueTuple>(elems); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||||
| return false; | |||||
| } | |||||
| auto new_value_node = NewValueNode(value_ptr); | |||||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||||
| new_value_node->set_abstract(value_ptr->ToAbstract()); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||||
| return false; | |||||
| } | |||||
| auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||||
| new_value_node->set_abstract(abs_type); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name, | |||||
| const onnx::AttributeProto &attr_proto) { | |||||
| if (!attr_proto.has_ref_attr_name()) { | |||||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| const std::string &ref_attr_name = attr_proto.ref_attr_name(); | |||||
| string type; | |||||
| std::size_t pos(0); | |||||
| if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("scalar:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("type:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("tensor:").length() - 1); | |||||
| } | |||||
| std::unordered_map<std::string, ValuePtr> kv; | |||||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | |||||
| const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); | |||||
| auto attr_name = attr_tensor.name(); | |||||
| switch (kParseTypeSwitchMap[type]) { | |||||
| case FORM_PARSE_TYPE: { | |||||
| return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); | |||||
| } | |||||
| case FORM_PARSE_SCALAR: { | |||||
| auto res = ObtainCNodeAttrInScalarForm(attr_tensor); | |||||
| kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res)); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_TENSOR: { | |||||
| return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| ValueNodePtr new_value_node; | |||||
| if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { | |||||
| if (kv.size() == 1) { | |||||
| auto iter = kv.begin(); | |||||
| new_value_node = NewValueNode(iter->second); | |||||
| new_value_node->set_abstract(iter->second->ToAbstract()); | |||||
| } else { | |||||
| auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); | |||||
| new_value_node = NewValueNode(value_ptr); | |||||
| new_value_node->set_abstract(value_ptr->ToAbstract()); | |||||
| } | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||||
| const std::string &value_node_name = node_proto.output(0); | |||||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | |||||
| if (!attr_proto.has_ref_attr_name()) { | |||||
| MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| return GetAttrValueForValueNode(value_node_name, attr_proto); | |||||
| } | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode( | |||||
| const onnx::AttributeProto &attr_proto) { | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> kv; | |||||
| for (int i = 0; i < attr_proto.tensors_size(); ++i) { | |||||
| ShapeVector shape_vec; | |||||
| const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); | |||||
| for (int j = 0; j < attr_tensor.dims_size(); ++j) { | |||||
| shape_vec.push_back(attr_tensor.dims(j)); | |||||
| } | |||||
| tensor::TensorPtr tensor_info = | |||||
| std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); | |||||
| MS_EXCEPTION_IF_NULL(tensor_info); | |||||
| auto abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(abstract); | |||||
| kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract)); | |||||
| } | |||||
| return kv; | |||||
| } | |||||
| CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| if (!node_proto.has_op_type()) { | |||||
| MS_LOG(ERROR) << "Get CNode op_type failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| const std::string &node_name = node_proto.output(0); | |||||
| const std::string &fullname_with_scope = node_proto.domain(); | |||||
| const std::string &node_type = node_proto.op_type(); | |||||
| PrimitivePtr prim = std::make_shared<Primitive>(node_type); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| prim->set_instance_name(node_type); | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> kv; | |||||
| string shape_ref_attr_name; | |||||
| for (int i = 0; i < node_proto.attribute_size(); ++i) { | |||||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(i); | |||||
| if (attr_proto.ref_attr_name().find("shape:") != string::npos) { | |||||
| shape_ref_attr_name = attr_proto.ref_attr_name(); | |||||
| kv = GetAbstractForCNode(attr_proto); | |||||
| continue; | |||||
| } | |||||
| if (!GetAttrValueForCNode(prim, attr_proto)) { | |||||
| MS_LOG(ERROR) << "Get CNode attr failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim)); | |||||
| for (int i = 0; i < node_proto.input_size(); ++i) { | |||||
| const std::string &input_name = node_proto.input(i); | |||||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | |||||
| MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; | |||||
| return nullptr; | |||||
| } | |||||
| inputs.push_back(anfnode_build_map_[input_name]); | |||||
| } | |||||
| CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| if (0 == kv.size()) { | |||||
| AbstractBasePtrList elem; | |||||
| for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { | |||||
| elem.push_back(cnode_ptr->input(index)->abstract()); | |||||
| } | |||||
| cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| } else if (1 == kv.size()) { | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin(); | |||||
| cnode_ptr->set_abstract(iter->second); | |||||
| } else { | |||||
| auto abstract = ParserAttrShape(shape_ref_attr_name, kv); | |||||
| cnode_ptr->set_abstract(abstract); | |||||
| } | |||||
| cnode_ptr->set_fullname_with_scope(fullname_with_scope); | |||||
| anfnode_build_map_[node_name] = cnode_ptr; | |||||
| return cnode_ptr; | |||||
| } | |||||
| bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const CNodePtr &cnode_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| if (importProto.output_size() > 1) { | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||||
| AbstractBasePtrList elem; | |||||
| for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { | |||||
| const onnx::ValueInfoProto &output_node = importProto.output(out_size); | |||||
| const std::string &out_tuple = output_node.name(); | |||||
| inputs.push_back(anfnode_build_map_[out_tuple]); | |||||
| elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | |||||
| } | |||||
| auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | |||||
| maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | |||||
| inputs.push_back(maketuple_ptr); | |||||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(return_node); | |||||
| outputFuncGraph->set_return(return_node); | |||||
| MS_LOG(INFO) << "Construct funcgraph finined, all success."; | |||||
| } else { | |||||
| const onnx::ValueInfoProto &output_node = importProto.output(0); | |||||
| const onnx::TypeProto &output_typeproto = output_node.type(); | |||||
| ShapeVector output_shape; | |||||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { | |||||
| output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); | |||||
| } | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | |||||
| inputs.push_back(cnode_ptr); | |||||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(return_node); | |||||
| outputFuncGraph->set_return(return_node); | |||||
| MS_LOG(INFO) << "Construct funcgraph finined, all success!"; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | |||||
| CNodePtr cnode_ptr = nullptr; | |||||
| for (int i = 0; i < importProto.node_size(); ++i) { | |||||
| const onnx::NodeProto &node_proto = importProto.node(i); | |||||
| const std::string &node_type = node_proto.op_type(); | |||||
| if (node_type == kConstantValueNode) { | |||||
| if (!BuildValueNodeForFuncGraph(node_proto)) { | |||||
| MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; | |||||
| return false; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); | |||||
| if (cnode_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | |||||
| MS_EXCEPTION_IF_NULL(debug_info_ptr); | |||||
| if (importProto.has_name()) { | |||||
| debug_info_ptr->set_name(importProto.name()); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "FuncGraph under converting has not name!"; | |||||
| } | |||||
| if (!ImportParametersForGraph(outputFuncGraph, importProto)) { | |||||
| return false; | |||||
| } | |||||
| return ImportNodesForGraph(outputFuncGraph, importProto); | |||||
| } | |||||
| bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||||
| if (!model_proto.has_producer_name()) { | |||||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | |||||
| return false; | |||||
| } | |||||
| producer_name_ = model_proto.producer_name(); | |||||
| MS_LOG(INFO) << "producer_name :" << producer_name_; | |||||
| if (!model_proto.has_model_version()) { | |||||
| MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; | |||||
| return false; | |||||
| } | |||||
| model_version_ = model_proto.model_version(); | |||||
| MS_LOG(INFO) << "producer_version : " << model_version_; | |||||
| if (!model_proto.has_ir_version()) { | |||||
| MS_LOG(ERROR) << "Parse model version from pb file failed!"; | |||||
| return false; | |||||
| } | |||||
| ir_version_ = model_proto.ir_version(); | |||||
| MS_LOG(INFO) << "ir_version :" << ir_version_; | |||||
| return true; | |||||
| } | |||||
| FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) { | |||||
| FuncGraphPtr dstGraph = std::make_shared<FuncGraph>(); | |||||
| MS_EXCEPTION_IF_NULL(dstGraph); | |||||
| if (!MSANFParseModelConfigureInfo(model_proto)) { | |||||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | |||||
| } | |||||
| const onnx::GraphProto &graphBuild = model_proto.graph(); | |||||
| if (!BuildFuncGraph(dstGraph, graphBuild)) { | |||||
| MS_LOG(ERROR) << "Build funcgraph failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; | |||||
| return dstGraph; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,119 @@ | |||||
| syntax = "proto2"; | |||||
| package mind_ir; | |||||
| message AttributeProto { | |||||
| enum AttributeType { | |||||
| UNDEFINED = 0; | |||||
| FLOAT = 1; | |||||
| UINT8 = 2; | |||||
| INT8 = 3; | |||||
| UINT16 = 4; | |||||
| INT16 = 5; | |||||
| INT32 = 6; | |||||
| INT64 = 7; | |||||
| STRING = 8; | |||||
| BOOL = 9; | |||||
| FLOAT16 = 10; | |||||
| DOUBLE = 11; | |||||
| UINT32 = 12; | |||||
| UINT64 = 13; | |||||
| COMPLEX64 = 14; | |||||
| COMPLEX128 = 15; | |||||
| BFLOAT16 = 16; | |||||
| TENSOR = 17; | |||||
| GRAPH = 18; | |||||
| TENSORS = 19; | |||||
| } | |||||
| optional string name = 1; | |||||
| optional float f = 2; | |||||
| optional int64 i = 3; | |||||
| optional double d = 4; | |||||
| optional bytes s = 5; | |||||
| optional TensorProto t = 6; | |||||
| optional GraphProto g = 7; | |||||
| repeated float floats = 8; | |||||
| repeated double doubles = 9; | |||||
| repeated int64 ints = 10; | |||||
| repeated bytes strings = 11; | |||||
| repeated TensorProto tensors = 12; | |||||
| repeated GraphProto graphs = 13; | |||||
| optional string doc_string = 14; | |||||
| optional string ref_attr_name = 15; | |||||
| optional AttributeType type = 16; | |||||
| } | |||||
| message ValueInfoProto { | |||||
| optional string name = 1; | |||||
| repeated TensorProto tensor = 2; | |||||
| optional string doc_string = 3; | |||||
| optional string denotation = 4; | |||||
| } | |||||
| message NodeProto { | |||||
| repeated string input = 1; | |||||
| repeated string output = 2; | |||||
| optional string name = 3; | |||||
| optional string op_type = 4; | |||||
| repeated AttributeProto attribute = 5; | |||||
| optional string doc_string = 6; | |||||
| optional string domain = 7; | |||||
| } | |||||
| message ModelProto { | |||||
| optional string ir_version = 1; | |||||
| optional string producer_name = 2; | |||||
| optional string producer_version = 3; | |||||
| optional string domain = 4; | |||||
| optional string model_version = 5; | |||||
| optional string doc_string = 6; | |||||
| optional GraphProto graph = 7; | |||||
| } | |||||
| message GraphProto { | |||||
| repeated NodeProto node = 1; | |||||
| optional string name = 2; | |||||
| repeated TensorProto parameter = 3; | |||||
| optional string doc_string = 4; | |||||
| repeated ValueInfoProto input = 5; | |||||
| repeated ValueInfoProto output = 6; | |||||
| } | |||||
| message TensorProto { | |||||
| enum DataType { | |||||
| UNDEFINED = 0; | |||||
| // Basic types. | |||||
| FLOAT = 1; // float | |||||
| UINT8 = 2; // uint8_t | |||||
| INT8 = 3; // int8_t | |||||
| UINT16 = 4; // uint16_t | |||||
| INT16 = 5; // int16_t | |||||
| INT32 = 6; // int32_t | |||||
| INT64 = 7; // int64_t | |||||
| STRING = 8; // string | |||||
| BOOL = 9; // bool | |||||
| FLOAT16 = 10; | |||||
| DOUBLE = 11; | |||||
| UINT32 = 12; | |||||
| UINT64 = 13; | |||||
| COMPLEX64 = 14; | |||||
| COMPLEX128 = 15; | |||||
| BFLOAT16 = 16; | |||||
| FLOAT64 = 17; | |||||
| } | |||||
| repeated int64 dims = 1; | |||||
| optional int32 data_type = 2; | |||||
| repeated float float_data = 3; | |||||
| repeated int32 int32_data = 4; | |||||
| repeated bytes string_data = 5; | |||||
| repeated int64 int64_data = 6; | |||||
| optional string name = 7; | |||||
| optional string doc_string = 8; | |||||
| optional bytes raw_data = 9; | |||||
| repeated double double_data = 10; | |||||
| repeated uint64 uint64_data = 11; | |||||
| } | |||||
| @@ -15,6 +15,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| "c_ops/*.cc" | "c_ops/*.cc" | ||||
| "ir/*.cc" | "ir/*.cc" | ||||
| "utils/*.cc" | "utils/*.cc" | ||||
| "load_mindir/*.cc" | |||||
| ) | ) | ||||
| set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE) | set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE) | ||||
| add_library(mindspore_core STATIC ${CORE_SRC_LIST}) | add_library(mindspore_core STATIC ${CORE_SRC_LIST}) | ||||
| @@ -50,4 +50,5 @@ AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer); | ||||
| REGISTER_PRIMITIVE_C(TensorAdd); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -102,4 +102,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); | ||||
| REGISTER_PRIMITIVE_C(AvgPool); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -193,4 +193,5 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| Conv2dInferShape(primitive, input_args)->shape()); | Conv2dInferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); | ||||
| REGISTER_PRIMITIVE_C(Conv2D); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,7 +33,7 @@ constexpr auto kPad = "pad"; | |||||
| constexpr auto kPads = "pads"; | constexpr auto kPads = "pads"; | ||||
| constexpr auto kMode = "mode"; | constexpr auto kMode = "mode"; | ||||
| constexpr auto kGroup = "group"; | constexpr auto kGroup = "group"; | ||||
| constexpr auto kOutputChannel = "output_channel"; | |||||
| constexpr auto kOutputChannel = "out_channel"; | |||||
| constexpr auto kPadList = "pad_list"; | constexpr auto kPadList = "pad_list"; | ||||
| constexpr auto kAxis = "axis"; | constexpr auto kAxis = "axis"; | ||||
| @@ -31,4 +31,13 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { | |||||
| auto infer_function = iter->second.impl_; | auto infer_function = iter->second.impl_; | ||||
| return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list); | return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list); | ||||
| } | } | ||||
| OpPrimCRegister &OpPrimCRegister::GetInstance() { | |||||
| static OpPrimCRegister instance; | |||||
| return instance; | |||||
| } | |||||
| std::map<std::string, OpPrimCDefineFunc> OpPrimCRegister::GetPrimCMap() { return op_primc_fns_; } | |||||
| void OpPrimCRegister::SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn) { op_primc_fns_[name] = fn; } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,6 +18,8 @@ | |||||
| #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include <memory> | |||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| @@ -32,5 +34,33 @@ class PrimitiveC : public Primitive { | |||||
| protected: | protected: | ||||
| void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name); | void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name); | ||||
| }; | }; | ||||
| using OpPrimCDefineFunc = std::function<std::shared_ptr<PrimitiveC>()>; | |||||
| class OpPrimCRegister { | |||||
| public: | |||||
| ~OpPrimCRegister() {} | |||||
| static OpPrimCRegister &GetInstance(); | |||||
| std::map<std::string, OpPrimCDefineFunc> GetPrimCMap(); | |||||
| void SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn); | |||||
| private: | |||||
| OpPrimCRegister() {} | |||||
| std::map<std::string, OpPrimCDefineFunc> op_primc_fns_; | |||||
| }; | |||||
| class OpPrimCRegisterHelper { | |||||
| public: | |||||
| OpPrimCRegisterHelper(const std::string &name, const OpPrimCDefineFunc &fn) { | |||||
| OpPrimCRegister::GetInstance().SetPrimCMap(name, fn); | |||||
| } | |||||
| ~OpPrimCRegisterHelper() = default; | |||||
| }; | |||||
| #define REGISTER_PRIMITIVE_C(name) \ | |||||
| std::shared_ptr<PrimitiveC> GetDefaultPrimC##name() { \ | |||||
| auto out = std::make_shared<name>(); \ | |||||
| return out; \ | |||||
| } \ | |||||
| OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | ||||
| @@ -49,4 +49,5 @@ AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| Relu6InferShape(primitive, input_args)->shape()); | Relu6InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer); | REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer); | ||||
| REGISTER_PRIMITIVE_C(Relu6); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,4 +41,5 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer); | ||||
| REGISTER_PRIMITIVE_C(Reshape); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,7 +36,7 @@ class Reshape : public PrimitiveC { | |||||
| AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimTensorAddPtr = std::shared_ptr<Reshape>; | |||||
| using PrimReshapePtr = std::shared_ptr<Reshape>; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_ | #endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_ | ||||
| @@ -75,4 +75,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer); | ||||
| REGISTER_PRIMITIVE_C(Softmax); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -76,4 +76,5 @@ AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer); | ||||
| REGISTER_PRIMITIVE_C(Squeeze); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,854 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "load_mindir/anf_model_parser.h" | |||||
| #include <functional> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <stack> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include "ir/tensor.h" | |||||
| #include "ir/param_info.h" | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "utils/shape_utils.h" | |||||
| using std::string; | |||||
| namespace mindspore { | |||||
| static constexpr char kConstantValueNode[] = "Constant"; | |||||
| static constexpr char kCNodeShapeAttr[] = "shape"; | |||||
| static constexpr char kCNodeShape1Attr[] = "shape1"; | |||||
| static constexpr char kCNodeShape2Attr[] = "shape2"; | |||||
| enum ParseForm : int { | |||||
| FORM_PARSE_TYPE = 0, | |||||
| FORM_PARSE_SCALAR = 1, | |||||
| FORM_PARSE_TENSOR = 2, | |||||
| FORM_PARSE_NONE = 3, | |||||
| FORM_PARSE_UNDEFINE = 4, | |||||
| }; | |||||
| static std::map<std::string, ParseForm> kParseTypeSwitchMap{{"type", FORM_PARSE_TYPE}, | |||||
| {"scalar", FORM_PARSE_SCALAR}, | |||||
| {"tensor", FORM_PARSE_TENSOR}, | |||||
| {"none", FORM_PARSE_NONE}, | |||||
| {"", FORM_PARSE_UNDEFINE}}; | |||||
| static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | |||||
| {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool}, | |||||
| {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||||
| {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16}, | |||||
| {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||||
| {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64}, | |||||
| {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||||
| {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16}, | |||||
| {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||||
| {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64}, | |||||
| {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||||
| {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, | |||||
| {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64}, | |||||
| {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||||
| {mind_ir::TensorProto_DataType_STRING, kObjectTypeString}, | |||||
| }; | |||||
| template <typename T, typename P> | |||||
| std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) { | |||||
| std::stack<std::string> rules; | |||||
| std::stack<P> value; | |||||
| int count = 0; | |||||
| for (size_t i = 0; i < str.length(); i++) { | |||||
| if (str[i] == '[') { | |||||
| rules.push("["); | |||||
| } else if (str[i] == ']') { | |||||
| // rules | |||||
| std::vector<P> vec; | |||||
| while (rules.top() != "[") { | |||||
| rules.pop(); | |||||
| vec.push_back(value.top()); | |||||
| value.pop(); | |||||
| } | |||||
| // pop "[" | |||||
| rules.pop(); | |||||
| // make tuple for names | |||||
| std::string res = "dummy"; | |||||
| // make tuple for values | |||||
| reverse(vec.begin(), vec.end()); | |||||
| auto vt = std::make_shared<T>(vec); | |||||
| if (rules.empty() && value.empty()) { | |||||
| return vt; | |||||
| } | |||||
| rules.push(res); | |||||
| value.push(vt); | |||||
| } else if (str[i] == ',') { | |||||
| continue; | |||||
| } else { | |||||
| count++; | |||||
| if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { | |||||
| auto value_name = str.substr(i - count + 1, count); | |||||
| value.push(kv.at(value_name)); | |||||
| rules.push(value_name); | |||||
| count = 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| template <typename T> | |||||
| std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map<string, ValuePtr> &kv) { | |||||
| std::string str = attr_name; | |||||
| auto replace = [&](const string &orgStr, const string &newStr) { | |||||
| std::string::size_type pos(0); | |||||
| while ((pos = str.find(orgStr)) != std::string::npos) { | |||||
| str.replace(pos, orgStr.length(), newStr); | |||||
| } | |||||
| return str; | |||||
| }; | |||||
| // remove "scalar:" | |||||
| str = replace("scalar:", ""); | |||||
| // remove "Tuple" | |||||
| str = replace("Tuple", ""); | |||||
| // remove "List" | |||||
| str = replace("List", ""); | |||||
| auto result = ParserAttr<T>(str, kv); | |||||
| return result; | |||||
| } | |||||
| std::shared_ptr<abstract::AbstractTuple> ParserAttrShape( | |||||
| const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) { | |||||
| std::string str = attr_name; | |||||
| auto replace = [&](const string &orgStr, const string &newStr) { | |||||
| std::string::size_type pos(0); | |||||
| while ((pos = str.find(orgStr)) != std::string::npos) { | |||||
| str.replace(pos, orgStr.length(), newStr); | |||||
| } | |||||
| return str; | |||||
| }; | |||||
| // remove "scalar:" | |||||
| str = replace("shape:", ""); | |||||
| // remove "Tuple" | |||||
| str = replace("Tuple", ""); | |||||
| // remove "List" | |||||
| str = replace("List", ""); | |||||
| auto result = ParserAttr<abstract::AbstractTuple>(str, kv); | |||||
| return result; | |||||
| } | |||||
| std::string ParseParameterName(const string &name) { | |||||
| string delimiter = ":"; | |||||
| size_t pos(0); | |||||
| if ((pos = name.find(delimiter)) != string::npos) { | |||||
| return name.substr(pos + 1, string::npos - (pos + 1)); | |||||
| } | |||||
| return name; | |||||
| } | |||||
| std::string ParseCNodeName(const string &name) { | |||||
| string delimiter = ":"; | |||||
| size_t pos = name.find(delimiter); | |||||
| size_t end_pos = name.find_last_of(delimiter); | |||||
| if (pos != string::npos && end_pos != string::npos && pos != end_pos) { | |||||
| return name.substr(pos + 1, end_pos - (pos + 1)); | |||||
| } | |||||
| return name; | |||||
| } | |||||
| #define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \ | |||||
| auto value = static_cast<valuetype>(attr_proto.ints(index)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } \ | |||||
| ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) { \ | |||||
| auto value = static_cast<valuetype>(attr_proto.i()); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } | |||||
| #define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \ | |||||
| auto value = static_cast<valuetype>(attr_proto.type##s(index)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t) | |||||
| PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool) | |||||
| PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double) | |||||
| PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float) | |||||
| PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string) | |||||
| ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) { | |||||
| auto value = static_cast<string>(attr_proto.s()); | |||||
| return MakeValue<string>(value); | |||||
| } | |||||
| ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) { | |||||
| auto value = static_cast<float>(attr_proto.f()); | |||||
| return MakeValue<float>(value); | |||||
| } | |||||
| ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) { | |||||
| auto value = static_cast<double>(attr_proto.d()); | |||||
| return MakeValue<double>(value); | |||||
| } | |||||
| tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) { | |||||
| ShapeVector shape; | |||||
| for (int i = 0; i < tensor_proto.dims_size(); ++i) { | |||||
| shape.push_back(tensor_proto.dims(i)); | |||||
| } | |||||
| if (!tensor_proto.has_data_type()) { | |||||
| MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!"; | |||||
| return nullptr; | |||||
| } | |||||
| if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!"; | |||||
| return nullptr; | |||||
| } | |||||
| tensor::TensorPtr tensor_info = | |||||
| std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape); | |||||
| MS_EXCEPTION_IF_NULL(tensor_info); | |||||
| return tensor_info; | |||||
| } | |||||
| bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const mind_ir::TensorProto ¶meter_proto) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!parameter_proto.has_name()) { | |||||
| MS_LOG(ERROR) << "mind_ir TensorProto has no name!"; | |||||
| return false; | |||||
| } | |||||
| string debug_info_name = ParseParameterName(parameter_proto.name()); | |||||
| auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name); | |||||
| node->set_debug_info(debug_info_ptr); | |||||
| node->set_name(parameter_proto.name()); | |||||
| tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto); | |||||
| auto tensor_abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_abstract); | |||||
| node->set_abstract(tensor_abstract); | |||||
| std::string initial_data = parameter_proto.raw_data(); | |||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data_buf); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error for build parameter, errorno " << ret; | |||||
| } | |||||
| node->set_default_param(tensor_info); | |||||
| anfnode_build_map_[parameter_proto.name()] = node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!value_proto.has_name()) { | |||||
| MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!"; | |||||
| return false; | |||||
| } | |||||
| string debug_info_name = ParseParameterName(value_proto.name()); | |||||
| auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name); | |||||
| node->set_debug_info(debug_info_ptr); | |||||
| node->set_name(value_proto.name()); | |||||
| const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0); | |||||
| tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto); | |||||
| auto tensor_abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_abstract); | |||||
| node->set_abstract(tensor_abstract); | |||||
| anfnode_build_map_[value_proto.name()] = node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const mind_ir::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size(); | |||||
| for (int i = 0; i < importProto.parameter_size(); ++i) { | |||||
| const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i); | |||||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) { | |||||
| MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "All inputs size is: " << importProto.input_size(); | |||||
| for (int i = 0; i < importProto.input_size(); ++i) { | |||||
| const mind_ir::ValueInfoProto &input_proto = importProto.input(i); | |||||
| if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { | |||||
| MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| const int attr_tensor_type = attr_proto.tensors(0).data_type(); | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; | |||||
| return false; | |||||
| } | |||||
| prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| return true; | |||||
| } | |||||
| ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) { | |||||
| const int attr_type = attr_proto.type(); | |||||
| switch (attr_type) { | |||||
| case mind_ir::AttributeProto_AttributeType_STRING: { | |||||
| return ParseAttrInScalar_string_string(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT8: { | |||||
| return ParseAttrInScalar_int8_t_int8_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT16: { | |||||
| return ParseAttrInScalar_int16_t_int16_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT32: { | |||||
| return ParseAttrInScalar_int32_t_int32_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT64: { | |||||
| return ParseAttrInScalar_int64_t_int64_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT8: { | |||||
| return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT16: { | |||||
| return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT32: { | |||||
| return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT64: { | |||||
| return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_FLOAT: { | |||||
| return ParseAttrInScalar_float_float(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_DOUBLE: { | |||||
| return ParseAttrInScalar_double_double(attr_proto, index); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_BOOL: { | |||||
| return ParseAttrInScalar_int32_t_bool(attr_proto, index); | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type; | |||||
| return {}; | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, | |||||
| std::unordered_map<std::string, ValuePtr> *multi_value_map) { | |||||
| string name; | |||||
| for (int i = 0; i < attr_proto.ints_size(); i++) { | |||||
| auto res = ParseAttrInScalarForm(attr_proto, i); | |||||
| name = "value" + std::to_string(i + 1); | |||||
| multi_value_map->insert(std::pair<string, ValuePtr>(name, res)); | |||||
| } | |||||
| for (int i = 0; i < attr_proto.doubles_size(); i++) { | |||||
| auto res = ParseAttrInScalarForm(attr_proto, i); | |||||
| name = "value" + std::to_string(i + 1); | |||||
| multi_value_map->insert(std::pair<string, ValuePtr>(name, res)); | |||||
| } | |||||
| for (int i = 0; i < attr_proto.floats_size(); i++) { | |||||
| auto res = ParseAttrInScalarForm(attr_proto, i); | |||||
| name = "value" + std::to_string(i + 1); | |||||
| multi_value_map->insert(std::pair<string, ValuePtr>(name, res)); | |||||
| } | |||||
| for (int i = 0; i < attr_proto.strings_size(); i++) { | |||||
| auto res = ParseAttrInScalarForm(attr_proto, i); | |||||
| name = "value" + std::to_string(i + 1); | |||||
| multi_value_map->insert(std::pair<string, ValuePtr>(name, res)); | |||||
| } | |||||
| } | |||||
| ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) { | |||||
| const int attr_type = attr_proto.type(); | |||||
| switch (attr_type) { | |||||
| case mind_ir::AttributeProto_AttributeType_STRING: { | |||||
| return ParseAttrInSingleScalar_string_string(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT8: { | |||||
| return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT16: { | |||||
| return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT32: { | |||||
| return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_INT64: { | |||||
| return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT8: { | |||||
| return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT16: { | |||||
| return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT32: { | |||||
| return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_UINT64: { | |||||
| return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_FLOAT: { | |||||
| return ParseAttrInSingleScalar_float_float(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_DOUBLE: { | |||||
| return ParseAttrInSingleScalar_double_double(attr_proto); | |||||
| } | |||||
| case mind_ir::AttributeProto_AttributeType_BOOL: { | |||||
| return ParseAttrInSingleScalar_int32_t_bool(attr_proto); | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type; | |||||
| return {}; | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, | |||||
| const mind_ir::AttributeProto &attr_proto) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0); | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| ShapeVector shape; | |||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||||
| shape.push_back(attr_tensor.dims(i)); | |||||
| } | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||||
| } | |||||
| prim->AddAttr(attr_proto.name(), MakeValue(tensor_info)); | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| const std::string &attr_name = attr_proto.name(); | |||||
| if (!attr_proto.has_ref_attr_name()) { | |||||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| const std::string &ref_attr_name = attr_proto.ref_attr_name(); | |||||
| string type = ""; | |||||
| std::size_t pos(0); | |||||
| if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("scalar:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("type:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("tensor:").length() - 1); | |||||
| } | |||||
| std::unordered_map<std::string, ValuePtr> multi_value_map; | |||||
| switch (kParseTypeSwitchMap[type]) { | |||||
| case FORM_PARSE_TYPE: { | |||||
| ObtainCNodeAttrInTypeForm(prim, attr_proto); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_SCALAR: { | |||||
| std::size_t value_pos(0); | |||||
| if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) { | |||||
| auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto); | |||||
| prim->AddAttr(attr_name, res); | |||||
| break; | |||||
| } | |||||
| ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_TENSOR: { | |||||
| ObtainCNodeAttrInTensorForm(prim, attr_proto); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name; | |||||
| return false; | |||||
| } | |||||
| if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) { | |||||
| if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) { | |||||
| auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map); | |||||
| prim->AddAttr(attr_name, value_tuple_ptr); | |||||
| } else { | |||||
| auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map); | |||||
| prim->AddAttr(attr_name, value_list_ptr); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||||
| const mind_ir::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| ShapeVector shape; | |||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||||
| shape.push_back(attr_tensor.dims(i)); | |||||
| } | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||||
| } | |||||
| auto new_value_node = NewValueNode(MakeValue(tensor_info)); | |||||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||||
| auto tensor_abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_abstract); | |||||
| new_value_node->set_abstract(tensor_abstract); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||||
| const mind_ir::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||||
| return false; | |||||
| } | |||||
| auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||||
| new_value_node->set_abstract(abs_type); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name, | |||||
| const mind_ir::AttributeProto &attr_proto) { | |||||
| auto new_value_node = NewValueNode(kNone); | |||||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name, | |||||
| const mind_ir::AttributeProto &attr_proto) { | |||||
| if (!attr_proto.has_ref_attr_name()) { | |||||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| const std::string &ref_attr_name = attr_proto.ref_attr_name(); | |||||
| string type = ""; | |||||
| std::size_t pos(0); | |||||
| if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("scalar:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("type:").length() - 1); | |||||
| } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { | |||||
| type = ref_attr_name.substr(pos, string("tensor:").length() - 1); | |||||
| } else if (ref_attr_name == "none") { | |||||
| type = ref_attr_name; | |||||
| } | |||||
| ValueNodePtr new_value_node; | |||||
| std::unordered_map<std::string, ValuePtr> multi_value_map; | |||||
| switch (kParseTypeSwitchMap[type]) { | |||||
| case FORM_PARSE_TYPE: { | |||||
| ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0)); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_SCALAR: { | |||||
| std::size_t value_pos(0); | |||||
| if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) { | |||||
| auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto); | |||||
| new_value_node = NewValueNode(res); | |||||
| new_value_node->set_abstract(res->ToAbstract()); | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| break; | |||||
| } | |||||
| ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_TENSOR: { | |||||
| ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0)); | |||||
| break; | |||||
| } | |||||
| case FORM_PARSE_NONE: { | |||||
| ObtainValueNodeInNoneForm(value_node_name, attr_proto); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name; | |||||
| return false; | |||||
| } | |||||
| if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) { | |||||
| if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) { | |||||
| auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map); | |||||
| new_value_node = NewValueNode(value_tuple_ptr); | |||||
| new_value_node->set_abstract(value_tuple_ptr->ToAbstract()); | |||||
| } else { | |||||
| auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map); | |||||
| new_value_node = NewValueNode(value_list_ptr); | |||||
| new_value_node->set_abstract(value_list_ptr->ToAbstract()); | |||||
| } | |||||
| anfnode_build_map_[value_node_name] = new_value_node; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) { | |||||
| const std::string &value_node_name = node_proto.output(0); | |||||
| const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0); | |||||
| if (!attr_proto.has_ref_attr_name()) { | |||||
| MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; | |||||
| return false; | |||||
| } | |||||
| return GetAttrValueForValueNode(value_node_name, attr_proto); | |||||
| } | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode( | |||||
| const mind_ir::AttributeProto &attr_proto) { | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> kv; | |||||
| for (int i = 0; i < attr_proto.tensors_size(); ++i) { | |||||
| ShapeVector shape_vec; | |||||
| const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i); | |||||
| for (int j = 0; j < attr_tensor.dims_size(); ++j) { | |||||
| shape_vec.push_back(attr_tensor.dims(j)); | |||||
| } | |||||
| tensor::TensorPtr tensor_info = | |||||
| std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); | |||||
| MS_EXCEPTION_IF_NULL(tensor_info); | |||||
| auto abstract = tensor_info->ToAbstract(); | |||||
| MS_EXCEPTION_IF_NULL(abstract); | |||||
| kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract)); | |||||
| } | |||||
| return kv; | |||||
| } | |||||
| CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const mind_ir::NodeProto &node_proto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| if (!node_proto.has_op_type()) { | |||||
| MS_LOG(ERROR) << "Get CNode op_type failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| const std::string &node_name = node_proto.output(0); | |||||
| const std::string &fullname_with_scope = node_proto.domain(); | |||||
| const std::string &node_type = node_proto.op_type(); | |||||
| std::shared_ptr<Primitive> prim; | |||||
| auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap(); | |||||
| if (op_primc_fns.find(node_type) != op_primc_fns.end()) { | |||||
| prim = op_primc_fns[node_type](); | |||||
| } else { | |||||
| prim = std::make_shared<Primitive>(node_type); | |||||
| prim->set_instance_name(node_type); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> kv; | |||||
| string shape_ref_attr_name; | |||||
| for (int i = 0; i < node_proto.attribute_size(); ++i) { | |||||
| const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); | |||||
| if (attr_proto.ref_attr_name().find("shape:") != string::npos) { | |||||
| shape_ref_attr_name = attr_proto.ref_attr_name(); | |||||
| kv = GetAbstractForCNode(attr_proto); | |||||
| continue; | |||||
| } | |||||
| if (!GetAttrValueForCNode(prim, attr_proto)) { | |||||
| MS_LOG(ERROR) << "Get CNode attr failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| for (int i = 0; i < node_proto.input_size(); ++i) { | |||||
| const std::string &input_name = node_proto.input(i); | |||||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | |||||
| MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; | |||||
| return nullptr; | |||||
| } | |||||
| inputs.push_back(anfnode_build_map_[input_name]); | |||||
| } | |||||
| auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs); | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| if (0 == kv.size()) { | |||||
| AbstractBasePtrList elem; | |||||
| for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { | |||||
| elem.push_back(cnode_ptr->input(index)->abstract()); | |||||
| } | |||||
| cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| } else if (1 == kv.size()) { | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin(); | |||||
| cnode_ptr->set_abstract(iter->second); | |||||
| } else { | |||||
| auto abstract = ParserAttrShape(shape_ref_attr_name, kv); | |||||
| cnode_ptr->set_abstract(abstract); | |||||
| } | |||||
| string debug_info_name = ParseCNodeName(node_name); | |||||
| auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name); | |||||
| cnode_ptr->set_debug_info(debug_info_ptr); | |||||
| cnode_ptr->set_fullname_with_scope(fullname_with_scope); | |||||
| anfnode_build_map_[node_name] = cnode_ptr; | |||||
| return cnode_ptr; | |||||
| } | |||||
| bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| if (importProto.output_size() > 1) { | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||||
| AbstractBasePtrList elem; | |||||
| for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { | |||||
| const mind_ir::ValueInfoProto &output_node = importProto.output(out_size); | |||||
| const std::string &out_tuple = output_node.name(); | |||||
| inputs.push_back(anfnode_build_map_[out_tuple]); | |||||
| elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | |||||
| } | |||||
| auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | |||||
| maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | |||||
| inputs.push_back(maketuple_ptr); | |||||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(return_node); | |||||
| outputFuncGraph->set_return(return_node); | |||||
| MS_LOG(INFO) << "Construct funcgraph finined, all success."; | |||||
| } else { | |||||
| inputs.clear(); | |||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | |||||
| inputs.push_back(cnode_ptr); | |||||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(return_node); | |||||
| outputFuncGraph->set_return(return_node); | |||||
| MS_LOG(INFO) << "Construct funcgraph finined, all success!"; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const mind_ir::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | |||||
| CNodePtr cnode_ptr = nullptr; | |||||
| for (int i = 0; i < importProto.node_size(); ++i) { | |||||
| const mind_ir::NodeProto &node_proto = importProto.node(i); | |||||
| const std::string &node_type = node_proto.op_type(); | |||||
| if (node_type == kConstantValueNode) { | |||||
| if (!BuildValueNodeForFuncGraph(node_proto)) { | |||||
| MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; | |||||
| return false; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); | |||||
| if (cnode_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); | |||||
| return true; | |||||
| } | |||||
| bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||||
| GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | |||||
| MS_EXCEPTION_IF_NULL(debug_info_ptr); | |||||
| if (importProto.has_name()) { | |||||
| debug_info_ptr->set_name(importProto.name()); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "FuncGraph under converting has not name!"; | |||||
| } | |||||
| if (!ImportParametersForGraph(outputFuncGraph, importProto)) { | |||||
| MS_LOG(ERROR) << "import parameters for graph fail!"; | |||||
| return false; | |||||
| } | |||||
| return ImportNodesForGraph(outputFuncGraph, importProto); | |||||
| } | |||||
| bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto) { | |||||
| if (!model_proto.has_producer_name()) { | |||||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | |||||
| return false; | |||||
| } | |||||
| producer_name_ = model_proto.producer_name(); | |||||
| MS_LOG(INFO) << "producer_name :" << producer_name_; | |||||
| if (!model_proto.has_model_version()) { | |||||
| MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; | |||||
| return false; | |||||
| } | |||||
| model_version_ = model_proto.model_version(); | |||||
| MS_LOG(INFO) << "producer_version : " << model_version_; | |||||
| if (!model_proto.has_ir_version()) { | |||||
| MS_LOG(ERROR) << "Parse model version from pb file failed!"; | |||||
| return false; | |||||
| } | |||||
| ir_version_ = model_proto.ir_version(); | |||||
| MS_LOG(INFO) << "ir_version :" << ir_version_; | |||||
| return true; | |||||
| } | |||||
| FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) { | |||||
| FuncGraphPtr dstGraph = std::make_shared<FuncGraph>(); | |||||
| MS_EXCEPTION_IF_NULL(dstGraph); | |||||
| if (!MSANFParseModelConfigureInfo(model_proto)) { | |||||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | |||||
| } | |||||
| const mind_ir::GraphProto &graphBuild = model_proto.graph(); | |||||
| if (!BuildFuncGraph(dstGraph, graphBuild)) { | |||||
| MS_LOG(ERROR) << "Build funcgraph failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; | |||||
| return dstGraph; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -14,63 +14,62 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H | |||||
| #define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H | |||||
| #ifndef MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H | |||||
| #define MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H | |||||
| #include <string> | #include <string> | ||||
| #include <map> | #include <map> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | #include "google/protobuf/io/zero_copy_stream_impl.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "proto/onnx.pb.h" | |||||
| #include "proto/mind_ir.pb.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | |||||
| using int32 = int32_t; | using int32 = int32_t; | ||||
| using int64 = int64_t; | using int64 = int64_t; | ||||
| using uint64 = uint64_t; | using uint64 = uint64_t; | ||||
| class MSANFModelParser { | class MSANFModelParser { | ||||
| public: | public: | ||||
| MSANFModelParser() : producer_name_(""), model_version_(0), ir_version_(0) {} | |||||
| MSANFModelParser() : producer_name_(""), model_version_(""), ir_version_("") {} | |||||
| ~MSANFModelParser() = default; | ~MSANFModelParser() = default; | ||||
| FuncGraphPtr Parse(const onnx::ModelProto &model_proto); | |||||
| bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto); | |||||
| FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto); | |||||
| bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto); | |||||
| std::string GetProducerName() { return producer_name_; } | std::string GetProducerName() { return producer_name_; } | ||||
| int GetProducerVersion() { return model_version_; } | |||||
| int GetIrVersion() { return ir_version_; } | |||||
| std::string GetProducerVersion() { return model_version_; } | |||||
| std::string GetIrVersion() { return ir_version_; } | |||||
| private: | private: | ||||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); | |||||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); | |||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); | |||||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); | |||||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto); | |||||
| bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto); | |||||
| tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto); | |||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto); | |||||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto, | |||||
| const CNodePtr &cnode_ptr); | const CNodePtr &cnode_ptr); | ||||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); | |||||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); | |||||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor); | |||||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); | |||||
| void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, | |||||
| std::unordered_map<std::string, ValuePtr> *multi_value_map); | |||||
| ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index); | |||||
| ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); | |||||
| bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto); | |||||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor); | |||||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); | |||||
| bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode( | std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode( | ||||
| const onnx::AttributeProto &attr_proto); | |||||
| const mind_ir::AttributeProto &attr_proto); | |||||
| std::string producer_name_; | std::string producer_name_; | ||||
| int model_version_; | |||||
| int ir_version_; | |||||
| std::string model_version_; | |||||
| std::string ir_version_; | |||||
| std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; | std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; | ||||
| std::map<std::string, onnx::TensorProto> default_para_map_; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H | |||||
| #endif // MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H | |||||
| @@ -0,0 +1,102 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "load_mindir/load_model.h" | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include <fstream> | |||||
| #include "load_mindir/anf_model_parser.h" | |||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace mindspore { | |||||
| std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) { | |||||
| if (file.empty()) { | |||||
| MS_LOG(ERROR) << "file is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0}; | |||||
| #if defined(_WIN32) || defined(_WIN64) | |||||
| if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) { | |||||
| MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file; | |||||
| return nullptr; | |||||
| } | |||||
| #else | |||||
| if (realpath(file.c_str(), real_path) == nullptr) { | |||||
| MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file; | |||||
| return nullptr; | |||||
| } | |||||
| #endif | |||||
| std::ifstream ifs(real_path); | |||||
| if (!ifs.good()) { | |||||
| MS_LOG(ERROR) << "file: " << real_path << " is not exist"; | |||||
| return nullptr; | |||||
| } | |||||
| if (!ifs.is_open()) { | |||||
| MS_LOG(ERROR) << "file: " << real_path << "open failed"; | |||||
| return nullptr; | |||||
| } | |||||
| ifs.seekg(0, std::ios::end); | |||||
| size_t size = ifs.tellg(); | |||||
| std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size)); | |||||
| if (buf == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; | |||||
| ifs.close(); | |||||
| return nullptr; | |||||
| } | |||||
| ifs.seekg(0, std::ios::beg); | |||||
| ifs.read(buf->data(), size); | |||||
| ifs.close(); | |||||
| return buf; | |||||
| } | |||||
| std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name) { | |||||
| auto graphBuf = ReadProtoFile(file_name); | |||||
| if (graphBuf == nullptr) { | |||||
| MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str(); | |||||
| return nullptr; | |||||
| } | |||||
| try { | |||||
| auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size()); | |||||
| return graph; | |||||
| } catch (std::exception &e) { | |||||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) { | |||||
| MS_EXCEPTION_IF_NULL(buf); | |||||
| std::string str((const char *)buf, buf_size); | |||||
| mind_ir::ModelProto model_; | |||||
| if (!model_.ParseFromString(str)) { | |||||
| MS_LOG(ERROR) << "Parse model from buffer fail!"; | |||||
| } | |||||
| MSANFModelParser model_parser; | |||||
| FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); | |||||
| return dstgraph_ptr; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -13,27 +13,19 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_LOAD_MODEL_H | |||||
| #define MINDSPORE_CORE_LOAD_MODEL_H | |||||
| #ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H | |||||
| #define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H | |||||
| #include <vector> | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| #include "proto/onnx.pb.h" | |||||
| #include "proto/mind_ir.pb.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | |||||
| class AnfConverter { | |||||
| public: | |||||
| static std::shared_ptr<FuncGraph> RunAnfConverter(const std::string &file_path); | |||||
| static std::shared_ptr<FuncGraph> RunAnfConverter(const char *buf, const size_t buf_size); | |||||
| private: | |||||
| static void Trim(std::string *input); | |||||
| static int ValidateFileStr(const std::string &modelFile, std::string fileType); | |||||
| static bool ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model); | |||||
| }; | |||||
| } // namespace lite | |||||
| std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name); | |||||
| std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file); | |||||
| std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif | |||||
| #endif // MINDSPORE_CORE_LOAD_MODEL_H | |||||
| @@ -116,7 +116,9 @@ endif () | |||||
| file(GLOB PROTO_FILE "" | file(GLOB PROTO_FILE "" | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto | ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto | |||||
| ${CCSRC_DIR}/utils/mind_ir.proto) | |||||
| ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | ||||
| add_library(proto_mid OBJECT ${PROTO_SRCS}) | add_library(proto_mid OBJECT ${PROTO_SRCS}) | ||||
| set(TFLITE_FBS_FILES | set(TFLITE_FBS_FILES | ||||
| @@ -23,9 +23,12 @@ from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | |||||
| from mindspore.train.anf_ir_pb2 import ModelProto as anf_model | |||||
| from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | ||||
| def _convert_type(types): | def _convert_type(types): | ||||
| """ | """ | ||||
| Convert from numpy type to tensor type. | Convert from numpy type to tensor type. | ||||
| @@ -203,3 +206,32 @@ def check_value_type(arg_name, arg_value, valid_types): | |||||
| if not is_valid: | if not is_valid: | ||||
| raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, ' | raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, ' | ||||
| f'bug got {type(arg_value).__name__}.') | f'bug got {type(arg_value).__name__}.') | ||||
| def read_proto(file_name, proto_format="MINDIR"): | |||||
| """ | |||||
| Read protobuf file. | |||||
| Args: | |||||
| file_name (str): File name. | |||||
| proto_format (str): Proto format. | |||||
| Returns: | |||||
| Object, proto object. | |||||
| """ | |||||
| if proto_format == "MINDIR": | |||||
| model = mindir_model() | |||||
| elif model_format == "ANF": | |||||
| model = anf_model() | |||||
| else: | |||||
| raise ValueError("Unsupported proto format.") | |||||
| try: | |||||
| with open(file_name, "rb") as f: | |||||
| pb_content = f.read() | |||||
| model.ParseFromString(pb_content) | |||||
| except BaseException as e: | |||||
| logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) | |||||
| raise ValueError(e.__str__()) | |||||
| return model | |||||