| @@ -230,6 +230,7 @@ if(BUILD_CONVERTER) | |||||
| ${TEST_LITE_SRC} | ${TEST_LITE_SRC} | ||||
| ${TEST_CASE_TFLITE_PARSERS_SRC} | ${TEST_CASE_TFLITE_PARSERS_SRC} | ||||
| ${TOP_DIR}/mindspore/core/utils/flags.cc | ${TOP_DIR}/mindspore/core/utils/flags.cc | ||||
| ${LITE_DIR}/tools/common/protobuf_utils.cc | |||||
| ${LITE_DIR}/tools/converter/optimizer.cc | ${LITE_DIR}/tools/converter/optimizer.cc | ||||
| ${LITE_DIR}/tools/converter/anf_transform.cc | ${LITE_DIR}/tools/converter/anf_transform.cc | ||||
| ${LITE_DIR}/tools/converter/graphdef_transform.cc | ${LITE_DIR}/tools/converter/graphdef_transform.cc | ||||
| @@ -27,7 +27,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| @@ -37,6 +36,7 @@ | |||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "tools/common/protobuf_utils.h" | |||||
| using string = std::string; | using string = std::string; | ||||
| using int32 = int32_t; | using int32 = int32_t; | ||||
| @@ -651,31 +651,11 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||||
| } | } | ||||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | ||||
| std::unique_ptr<char[]> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | |||||
| #ifdef _WIN32 | |||||
| if (_fullpath(onnx_file.get(), model_path.c_str(), 1024) == nullptr) { | |||||
| MS_LOG(ERROR) << "open file failed."; | |||||
| return nullptr; | |||||
| } | |||||
| #else | |||||
| if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { | |||||
| MS_LOG(ERROR) << "open file failed."; | |||||
| return nullptr; | |||||
| } | |||||
| #endif | |||||
| int fd = open(onnx_file.get(), O_RDONLY); | |||||
| google::protobuf::io::FileInputStream input(fd); | |||||
| google::protobuf::io::CodedInputStream code_input(&input); | |||||
| code_input.SetTotalBytesLimit(INT_MAX, 536870912); | |||||
| auto onnx_model = new onnx::ModelProto; | auto onnx_model = new onnx::ModelProto; | ||||
| bool ret = onnx_model->ParseFromCodedStream(&code_input); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "load onnx file failed"; | |||||
| delete onnx_model; | |||||
| if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { | |||||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| (void)close(fd); | |||||
| MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; | |||||
| return onnx_model; | return onnx_model; | ||||
| } | } | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" | |||||
| #include "tools/common/protobuf_utils.h" | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <string> | #include <string> | ||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | #include "google/protobuf/io/zero_copy_stream_impl.h" | ||||
| @@ -37,15 +37,14 @@ bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded | |||||
| return proto->ParseFromCodedStream(coded_stream); | return proto->ParseFromCodedStream(coded_stream); | ||||
| } | } | ||||
| STATUS ReadProtoFromText(const char *file, | |||||
| google::protobuf::Message *message) { | |||||
| STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message) { | |||||
| if (file == nullptr || message == nullptr) { | if (file == nullptr || message == nullptr) { | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::string realPath = RealPath(file); | std::string realPath = RealPath(file); | ||||
| if (realPath.empty()) { | if (realPath.empty()) { | ||||
| MS_LOG(ERROR) << "Proto file path " << file <<" is not valid"; | |||||
| MS_LOG(ERROR) << "Proto file path " << file << " is not valid"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -67,8 +66,7 @@ STATUS ReadProtoFromText(const char *file, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS ReadProtoFromBinaryFile(const char *file, | |||||
| google::protobuf::Message *message) { | |||||
| STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message) { | |||||
| if (file == nullptr || message == nullptr) { | if (file == nullptr || message == nullptr) { | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -100,4 +98,3 @@ STATUS ReadProtoFromBinaryFile(const char *file, | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,13 +29,10 @@ namespace lite { | |||||
| bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, | bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, | ||||
| google::protobuf::Message *proto); | google::protobuf::Message *proto); | ||||
| STATUS ReadProtoFromText(const char *file, | |||||
| google::protobuf::Message *message); | |||||
| STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message); | |||||
| STATUS ReadProtoFromBinaryFile(const char *file, | |||||
| google::protobuf::Message *message); | |||||
| STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ | ||||
| @@ -94,6 +94,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc | ||||
| @@ -15,7 +15,6 @@ add_library(caffe_parser_mid OBJECT | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc | ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h" | #include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h" | ||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -14,14 +14,14 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" | |||||
| #include "tools/converter/parser/caffe/caffe_model_parser.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <utility> | #include <utility> | ||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" | |||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" | |||||
| #include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h" | |||||
| #include "tools/converter/parser/caffe/caffe_node_parser_registry.h" | |||||
| #include "tools/converter/parser/caffe/caffe_inspector.h" | |||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "tools/common/protobuf_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -31,9 +31,8 @@ CaffeModelParser::~CaffeModelParser() {} | |||||
| const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; | const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; | ||||
| schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, | |||||
| const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { | if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -89,8 +88,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, | |||||
| return metaGraph.release(); | return metaGraph.release(); | ||||
| } | } | ||||
| STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, | |||||
| schema::CNodeT *op, | |||||
| STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, | |||||
| TensorCache *tensorCache) { | TensorCache *tensorCache) { | ||||
| for (int i = 0; i < layer.bottom_size(); i++) { | for (int i = 0; i < layer.bottom_size(); i++) { | ||||
| int index = tensorCache->FindTensor(layer.bottom(i)); | int index = tensorCache->FindTensor(layer.bottom(i)); | ||||
| @@ -104,8 +102,7 @@ STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, | |||||
| schema::CNodeT *op, | |||||
| STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, | |||||
| TensorCache *tensorCache) { | TensorCache *tensorCache) { | ||||
| for (int i = 0; i < layer.top_size(); i++) { | for (int i = 0; i < layer.top_size(); i++) { | ||||
| std::unique_ptr<schema::TensorT> msTensor = std::make_unique<schema::TensorT>(); | std::unique_ptr<schema::TensorT> msTensor = std::make_unique<schema::TensorT>(); | ||||
| @@ -114,8 +111,7 @@ STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &weightVec, | |||||
| schema::CNodeT *op, | |||||
| STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &weightVec, schema::CNodeT *op, | |||||
| TensorCache *tensorCache) { | TensorCache *tensorCache) { | ||||
| for (auto iter : weightVec) { | for (auto iter : weightVec) { | ||||
| op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST)); | op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST)); | ||||
| @@ -123,8 +119,7 @@ STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &w | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, | |||||
| schema::MetaGraphT *subGraphDef) { | |||||
| STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef) { | |||||
| std::vector<schema::TensorT *> tensors = tensorCache.GetCachedTensor(); | std::vector<schema::TensorT *> tensors = tensorCache.GetCachedTensor(); | ||||
| for (auto iter : tensors) { | for (auto iter : tensors) { | ||||
| std::unique_ptr<schema::TensorT> temp(iter); | std::unique_ptr<schema::TensorT> temp(iter); | ||||
| @@ -133,8 +128,7 @@ STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, | |||||
| TensorCache *tensorCache, | |||||
| STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, TensorCache *tensorCache, | |||||
| schema::MetaGraphT *subGraphDef) { | schema::MetaGraphT *subGraphDef) { | ||||
| CaffeInspector caffeInspector; | CaffeInspector caffeInspector; | ||||
| caffeInspector.InspectModel(proto); | caffeInspector.InspectModel(proto); | ||||
| @@ -160,10 +154,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, | |||||
| const caffe::NetParameter &weight, | |||||
| TensorCache *tensorCache, | |||||
| schema::MetaGraphT *subGraphDef) { | |||||
| STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, | |||||
| TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) { | |||||
| for (int i = 0; i < proto.layer_size(); i++) { | for (int i = 0; i < proto.layer_size(); i++) { | ||||
| auto layer = proto.layer(i); | auto layer = proto.layer(i); | ||||
| @@ -235,8 +227,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, | |||||
| TensorCache *tensorCache) { | |||||
| STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { | |||||
| for (int i = 0; i < proto.input_size(); i++) { | for (int i = 0; i < proto.input_size(); i++) { | ||||
| if (proto.input_dim_size() <= 0) { | if (proto.input_dim_size() <= 0) { | ||||
| continue; | continue; | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "tools/common/protobuf_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -54,36 +55,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo | |||||
| return dims; | return dims; | ||||
| } | } | ||||
| STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, | |||||
| google::protobuf::Message *onnx_model) { | |||||
| std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | |||||
| #ifdef _WIN32 | |||||
| if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) { | |||||
| MS_LOG(ERROR) << "get realpath " << modelFile << " fail"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| #else | |||||
| if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) { | |||||
| MS_LOG(ERROR) << "get realpath " << modelFile << " fail"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| #endif | |||||
| int fd = open(onnx_file.get(), O_RDONLY); | |||||
| google::protobuf::io::FileInputStream input(fd); | |||||
| google::protobuf::io::CodedInputStream code_input(&input); | |||||
| code_input.SetTotalBytesLimit(INT_MAX, 536870912); | |||||
| bool ret = onnx_model->ParseFromCodedStream(&code_input); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "load onnx file failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| (void)close(fd); | |||||
| onnx_file.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, | |||||
| TensorCache *tensor_cache) { | |||||
| STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { | |||||
| MS_LOG(DEBUG) << "set onnx constant tensors"; | MS_LOG(DEBUG) << "set onnx constant tensors"; | ||||
| for (const auto &onnx_const_value : onnx_graph.initializer()) { | for (const auto &onnx_const_value : onnx_graph.initializer()) { | ||||
| int index; | int index; | ||||
| @@ -119,11 +91,8 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, | |||||
| const std::string &name, | |||||
| const TensorType &type, | |||||
| TensorCache *tensor_cache, | |||||
| int *index) { | |||||
| STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, | |||||
| TensorCache *tensor_cache, int *index) { | |||||
| auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | ||||
| if (data_type == kTypeUnknown) { | if (data_type == kTypeUnknown) { | ||||
| MS_LOG(ERROR) << "not support onnx data type " | MS_LOG(ERROR) << "not support onnx data type " | ||||
| @@ -143,11 +112,8 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, | |||||
| const std::string &name, | |||||
| const TensorType &type, | |||||
| TensorCache *tensor_cache, | |||||
| int *index) { | |||||
| STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, | |||||
| TensorCache *tensor_cache, int *index) { | |||||
| auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type())); | auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type())); | ||||
| if (data_type == kTypeUnknown) { | if (data_type == kTypeUnknown) { | ||||
| MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type()); | MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type()); | ||||
| @@ -174,8 +140,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, | |||||
| schema::MetaGraphT *graph, | |||||
| STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, | |||||
| TensorCache *tensor_cache) { | TensorCache *tensor_cache) { | ||||
| for (const auto &input_value : onnx_graph.input()) { | for (const auto &input_value : onnx_graph.input()) { | ||||
| auto ret = tensor_cache->FindTensor(input_value.name()); | auto ret = tensor_cache->FindTensor(input_value.name()); | ||||
| @@ -192,8 +157,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, | |||||
| schema::MetaGraphT *graph, | |||||
| STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, | |||||
| TensorCache *tensor_cache) { | TensorCache *tensor_cache) { | ||||
| for (const auto &output_value : onnx_graph.output()) { | for (const auto &output_value : onnx_graph.output()) { | ||||
| int index; | int index; | ||||
| @@ -207,10 +171,8 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| schema::MetaGraphT *graph, | |||||
| TensorCache *tensor_cache) { | |||||
| void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::MetaGraphT *graph, TensorCache *tensor_cache) { | |||||
| std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>(); | std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>(); | ||||
| dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); | dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); | ||||
| ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); | ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); | ||||
| @@ -231,8 +193,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, | |||||
| graph->nodes.emplace_back(std::move(dst_op_2)); | graph->nodes.emplace_back(std::move(dst_op_2)); | ||||
| } | } | ||||
| STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, | |||||
| TensorCache *tensor_cache) { | |||||
| STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { | |||||
| // convert GivenTensorFill node to a weight/bias tensor | // convert GivenTensorFill node to a weight/bias tensor | ||||
| auto ret = tensor_cache->FindTensor(onnx_node.output(0)); | auto ret = tensor_cache->FindTensor(onnx_node.output(0)); | ||||
| if (ret < 0) { | if (ret < 0) { | ||||
| @@ -284,10 +245,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, | |||||
| schema::TensorT *dst_tensor, | |||||
| STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, | |||||
| TensorCache *tensor_cache) { | TensorCache *tensor_cache) { | ||||
| // change op_type() to name(), that is unique | // change op_type() to name(), that is unique | ||||
| dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); | dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); | ||||
| @@ -319,11 +278,8 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, | |||||
| schema::TensorT *dst_tensor, | |||||
| TensorCache *tensor_cache) { | |||||
| void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { | |||||
| MS_ASSERT(dst_op != nullptr); | MS_ASSERT(dst_op != nullptr); | ||||
| MS_ASSERT(tensor_cache != nullptr); | MS_ASSERT(tensor_cache != nullptr); | ||||
| std::vector<string> quant_node_name; | std::vector<string> quant_node_name; | ||||
| @@ -380,10 +336,8 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, | |||||
| } | } | ||||
| } | } | ||||
| STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| const string &onnx_op_type, | |||||
| schema::CNodeT *dst_op) { | |||||
| STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| const string &onnx_op_type, schema::CNodeT *dst_op) { | |||||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); | auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); | ||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; | MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; | ||||
| @@ -392,10 +346,8 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, | |||||
| return node_parser->Parse(onnx_graph, onnx_node, dst_op); | return node_parser->Parse(onnx_graph, onnx_node, dst_op); | ||||
| } | } | ||||
| STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, | |||||
| schema::CNodeT *dst_op, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| TensorCache *tensor_cache) { | |||||
| STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||||
| const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { | |||||
| for (const auto &onnx_node_input : node_inputs) { | for (const auto &onnx_node_input : node_inputs) { | ||||
| auto index = tensor_cache->FindTensor(onnx_node_input); | auto index = tensor_cache->FindTensor(onnx_node_input); | ||||
| if (index < 0) { | if (index < 0) { | ||||
| @@ -408,8 +360,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, | |||||
| schema::CNodeT *dst_op, | |||||
| STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, | |||||
| TensorCache *tensor_cache) { | TensorCache *tensor_cache) { | ||||
| for (const auto &onnx_node_output : node_outputs) { | for (const auto &onnx_node_output : node_outputs) { | ||||
| auto index = tensor_cache->FindTensor(onnx_node_output); | auto index = tensor_cache->FindTensor(onnx_node_output); | ||||
| @@ -424,8 +375,7 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, | |||||
| schema::TensorT *tensor) { | |||||
| STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { | |||||
| size_t data_count = 1; | size_t data_count = 1; | ||||
| std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); | std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); | ||||
| size_t data_size = 0; | size_t data_size = 0; | ||||
| @@ -484,8 +434,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, | |||||
| schema::MetaGraphT *graphDef) { | |||||
| STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { | |||||
| std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor(); | std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor(); | ||||
| for (auto iter : tensors) { | for (auto iter : tensors) { | ||||
| std::unique_ptr<schema::TensorT> temp(iter); | std::unique_ptr<schema::TensorT> temp(iter); | ||||
| @@ -507,17 +456,16 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) | |||||
| } | } | ||||
| } | } | ||||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, | |||||
| const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { | if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { | ||||
| MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; | MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto dst_graph = std::make_unique<schema::MetaGraphT>(); | |||||
| onnx::ModelProto onnx_model; | onnx::ModelProto onnx_model; | ||||
| if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { | |||||
| MS_LOG(ERROR) << "read onnx model fail"; | |||||
| if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) { | |||||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| const onnx::GraphProto &onnx_graph = onnx_model.graph(); | const onnx::GraphProto &onnx_graph = onnx_model.graph(); | ||||
| @@ -531,6 +479,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, | |||||
| MS_LOG(ERROR) << "SetGraphConstTensor failed"; | MS_LOG(ERROR) << "SetGraphConstTensor failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto dst_graph = std::make_unique<schema::MetaGraphT>(); | |||||
| // init onnx model graph input tensor | // init onnx model graph input tensor | ||||
| if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { | if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { | ||||
| MS_LOG(ERROR) << "SetGraphInputTensor failed"; | MS_LOG(ERROR) << "SetGraphInputTensor failed"; | ||||
| @@ -41,78 +41,47 @@ class OnnxModelParser : public ModelParser { | |||||
| virtual ~OnnxModelParser(); | virtual ~OnnxModelParser(); | ||||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | ||||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||||
| private: | private: | ||||
| TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | ||||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | ||||
| STATUS ReadOnnxModelFromBinary(const std::string &modelFile, | |||||
| google::protobuf::Message *model_proto); | |||||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, | |||||
| schema::MetaGraphT *graph, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, | |||||
| schema::MetaGraphT *graph, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS AddValueInfo(const onnx::ValueInfoProto &proto, | |||||
| const std::string &name, | |||||
| const TensorType &type, | |||||
| TensorCache *tensor_cache, | |||||
| int *index); | |||||
| STATUS AddTensorProto(const onnx::TensorProto &proto, | |||||
| const std::string &name, | |||||
| const TensorType &type, | |||||
| TensorCache *tensor_cache, | |||||
| int *index); | |||||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, | |||||
| schema::TensorT *dst_tensor, | |||||
| TensorCache *tensor_cache); | |||||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| schema::MetaGraphT *graph, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| const string &onnx_op_type, | |||||
| schema::CNodeT *dst_op); | |||||
| void SetOpQuantParams(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, | |||||
| schema::TensorT *dst_tensor, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, | |||||
| schema::CNodeT *dst_op, | |||||
| const onnx::NodeProto &onnx_node, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, | |||||
| schema::CNodeT *dst_op, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, | |||||
| schema::TensorT *tensor); | |||||
| STATUS SetAllTensors(const TensorCache &tensor_cache, | |||||
| schema::MetaGraphT *graphDef); | |||||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); | |||||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); | |||||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); | |||||
| STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, | |||||
| TensorCache *tensor_cache, int *index); | |||||
| STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, | |||||
| TensorCache *tensor_cache, int *index); | |||||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache); | |||||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::MetaGraphT *graph, TensorCache *tensor_cache); | |||||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); | |||||
| STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| const string &onnx_op_type, schema::CNodeT *dst_op); | |||||
| void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, | |||||
| schema::TensorT *dst_tensor, TensorCache *tensor_cache); | |||||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||||
| const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); | |||||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); | |||||
| STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); | |||||
| STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); | |||||
| void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); | void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); | ||||