| @@ -24,7 +24,7 @@ | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "tools/converter/return_code.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfExporter { | class AnfExporter { | ||||
| @@ -47,7 +47,7 @@ class AnfExporter { | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | ||||
| void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | ||||
| int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | ||||
| schema::CNodeT *return_node); | |||||
| schema::CNodeT *return_node); | |||||
| bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | ||||
| int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | ||||
| const std::shared_ptr<PrimitiveC> primitive, const std::unique_ptr<schema::CNodeT> &dst_node); | const std::shared_ptr<PrimitiveC> primitive, const std::unique_ptr<schema::CNodeT> &dst_node); | ||||
| @@ -202,7 +202,7 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | ||||
| int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | ||||
| const onnx::ValueInfoProto &value_proto) { | |||||
| const onnx::ValueInfoProto &value_proto) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| @@ -273,7 +273,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node | |||||
| } | } | ||||
| int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | ||||
| const onnx::GraphProto &importProto) { | |||||
| const onnx::GraphProto &importProto) { | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| @@ -557,6 +557,7 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt | |||||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | ||||
| const onnx::NodeProto &node_proto, | const onnx::NodeProto &node_proto, | ||||
| const schema::QuantType &quantType) { | const schema::QuantType &quantType) { | ||||
| static bool interrupt = false; | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "output funcgraph is nullptr"; | MS_LOG(ERROR) << "output funcgraph is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -600,13 +601,17 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||||
| inputs.push_back(anfnode_build_map_[input_name]); | inputs.push_back(anfnode_build_map_[input_name]); | ||||
| } | } | ||||
| auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType); | auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType); | ||||
| if (primitivec_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name(); | |||||
| if (primitivec_ptr == nullptr || interrupt) { | |||||
| interrupt = true; | |||||
| if (primitivec_ptr == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(prim->name()); | |||||
| } | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); | inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); | ||||
| CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | ||||
| if (cnode_ptr == nullptr) { | if (cnode_ptr == nullptr) { | ||||
| interrupt = true; | |||||
| MS_LOG(ERROR) << "funcgraph new cnode failed"; | MS_LOG(ERROR) << "funcgraph new cnode failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -700,40 +705,43 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| } | } | ||||
| int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | ||||
| const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | MS_LOG(ERROR) << "funcgraph is nullptr"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | ||||
| CNodePtr cnode_ptr = nullptr; | CNodePtr cnode_ptr = nullptr; | ||||
| int status = RET_OK; | |||||
| for (int i = 0; i < importProto.node_size(); ++i) { | for (int i = 0; i < importProto.node_size(); ++i) { | ||||
| const onnx::NodeProto &node_proto = importProto.node(i); | const onnx::NodeProto &node_proto = importProto.node(i); | ||||
| const std::string &node_type = node_proto.op_type(); | const std::string &node_type = node_proto.op_type(); | ||||
| if (node_type == kConstantValueNode) { | if (node_type == kConstantValueNode) { | ||||
| if (!BuildValueNodeForFuncGraph(node_proto)) { | |||||
| if (status == RET_OK && !BuildValueNodeForFuncGraph(node_proto)) { | |||||
| MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; | MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; | ||||
| return RET_ERROR; | |||||
| status = RET_ERROR; | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); | cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); | ||||
| if (cnode_ptr == nullptr) { | if (cnode_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | ||||
| return RET_NULL_PTR; | |||||
| status = (status == RET_OK ? RET_NULL_PTR : status); | |||||
| } | } | ||||
| } | } | ||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | ||||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | ||||
| return RET_ERROR; | |||||
| status = RET_ERROR; | |||||
| } | } | ||||
| return RET_OK; | |||||
| return status; | |||||
| } | } | ||||
| int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | ||||
| const schema::QuantType &quantType) { | |||||
| const schema::QuantType &quantType) { | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "fundgraph is nullptr"; | MS_LOG(ERROR) << "fundgraph is nullptr"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "tools/converter/converter_context.h" | |||||
| #include "tools/anf_importer/anf_importer.h" | #include "tools/anf_importer/anf_importer.h" | ||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| @@ -47,10 +48,10 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||||
| int AddReturnCNode() override { return RET_ERROR; }; | int AddReturnCNode() override { return RET_ERROR; }; | ||||
| int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | ||||
| int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | ||||
| const schema::QuantType &quantType); | |||||
| const schema::QuantType &quantType); | |||||
| int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | ||||
| int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | ||||
| const schema::QuantType &quantType); | |||||
| const schema::QuantType &quantType); | |||||
| int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | ||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | ||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| @@ -15,6 +15,8 @@ | |||||
| */ | */ | ||||
| #include "tools/common/storage.h" | #include "tools/common/storage.h" | ||||
| #include <sys/stat.h> | |||||
| #include <unistd.h> | |||||
| #include "flatbuffers/flatbuffers.h" | #include "flatbuffers/flatbuffers.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| @@ -31,7 +33,10 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath | |||||
| MS_LOG(ERROR) << "GetBufferPointer nullptr"; | MS_LOG(ERROR) << "GetBufferPointer nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (access((outputPath + ".ms").c_str(), F_OK) == 0) { | |||||
| MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed"; | |||||
| chmod((outputPath + ".ms").c_str(), S_IWUSR); | |||||
| } | |||||
| std::ofstream output(outputPath + ".ms", std::ofstream::binary); | std::ofstream output(outputPath + ".ms", std::ofstream::binary); | ||||
| if (!output.is_open()) { | if (!output.is_open()) { | ||||
| MS_LOG(ERROR) << "Can not open output file: " << outputPath << ".ms"; | MS_LOG(ERROR) << "Can not open output file: " << outputPath << ".ms"; | ||||
| @@ -40,6 +45,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath | |||||
| output.write((const char *)content, size); | output.write((const char *)content, size); | ||||
| output.close(); | output.close(); | ||||
| chmod((outputPath + ".ms").c_str(), S_IRUSR); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "tools/converter/converter_flags.h" | #include "tools/converter/converter_flags.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "tools/converter/quantizer/quantizer.h" | #include "tools/converter/quantizer/quantizer.h" | ||||
| #include "tools/converter/return_code.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -152,6 +152,7 @@ int RunConverter(int argc, const char **argv) { | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| NoSupportOp::GetInstance()->PrintOps(); | |||||
| status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); | status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); | ||||
| if (fb_graph == nullptr) { | if (fb_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert model return nullptr"; | MS_LOG(ERROR) << "Convert model return nullptr"; | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "tools/anf_importer/anf_importer.h" | #include "tools/anf_importer/anf_importer.h" | ||||
| #include "tools/converter/converter_flags.h" | #include "tools/converter/converter_flags.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "tools/converter/return_code.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -17,13 +17,16 @@ | |||||
| #ifndef LITE_RETURN_CODE_H | #ifndef LITE_RETURN_CODE_H | ||||
| #define LITE_RETURN_CODE_H | #define LITE_RETURN_CODE_H | ||||
| #include <string> | |||||
| #include <set> | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class ReturnCode { | class ReturnCode { | ||||
| public: | public: | ||||
| ~ReturnCode() {} | |||||
| ~ReturnCode() = default; | |||||
| static ReturnCode *GetSingleReturnCode() { | static ReturnCode *GetSingleReturnCode() { | ||||
| static ReturnCode returnCode; | static ReturnCode returnCode; | ||||
| return &returnCode; | return &returnCode; | ||||
| @@ -33,15 +36,31 @@ class ReturnCode { | |||||
| statusCode = status; | statusCode = status; | ||||
| } | } | ||||
| } | } | ||||
| STATUS GetReturnCode() const { | |||||
| return statusCode; | |||||
| } | |||||
| STATUS GetReturnCode() const { return statusCode; } | |||||
| private: | private: | ||||
| ReturnCode() { statusCode = RET_OK; } | ReturnCode() { statusCode = RET_OK; } | ||||
| int statusCode; | int statusCode; | ||||
| }; | }; | ||||
| class NoSupportOp { | |||||
| public: | |||||
| ~NoSupportOp() = default; | |||||
| static NoSupportOp *GetInstance() { | |||||
| static NoSupportOp noSupportOp; | |||||
| return &noSupportOp; | |||||
| } | |||||
| void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); } | |||||
| void PrintOps() const { | |||||
| for (auto &op_name : noSupportOps) { | |||||
| MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported"; | |||||
| } | |||||
| } | |||||
| private: | |||||
| NoSupportOp() { noSupportOps.clear(); } | |||||
| std::set<std::string> noSupportOps; | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_RETURN_CODE_H | #endif // LITE_RETURN_CODE_H | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "tools/anf_importer/import_from_meta_graphT.h" | #include "tools/anf_importer/import_from_meta_graphT.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "tools/converter/return_code.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| using namespace schema; | using namespace schema; | ||||
| @@ -40,7 +40,7 @@ class ModelParser { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto func_graph = this->Fb2Anf(meta_graph); | auto func_graph = this->Fb2Anf(meta_graph); | ||||
| delete(meta_graph); | |||||
| delete (meta_graph); | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| @@ -84,6 +84,9 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseLayer failed " << status; | MS_LOG(ERROR) << "ParseLayer failed " << status; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| for (auto &tensor : tensorCache.GetCachedTensor()) { | |||||
| delete tensor; | |||||
| } | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -179,6 +182,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T | |||||
| STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, | STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, | ||||
| TensorCache *tensorCache, schema::MetaGraphT *subGraphDef, | TensorCache *tensorCache, schema::MetaGraphT *subGraphDef, | ||||
| const QuantType &quantType) { | const QuantType &quantType) { | ||||
| static bool interrupt = false; | |||||
| int status = RET_OK; | |||||
| for (int i = 0; i < proto.layer_size(); i++) { | for (int i = 0; i < proto.layer_size(); i++) { | ||||
| auto layer = proto.layer(i); | auto layer = proto.layer(i); | ||||
| @@ -222,38 +227,46 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto status = SetOpInputIdx(layer, op.get(), tensorCache); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; | |||||
| return status; | |||||
| } | |||||
| auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); | auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); | ||||
| if (nodeParser == nullptr) { | |||||
| MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name(); | |||||
| return RET_NULL_PTR; | |||||
| if (nodeParser == nullptr || interrupt) { | |||||
| interrupt = true; | |||||
| if (nodeParser == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(layer.type()); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| } | |||||
| continue; | |||||
| } | } | ||||
| std::vector<schema::TensorT *> weightVec; | std::vector<schema::TensorT *> weightVec; | ||||
| status = nodeParser->Parse(layer, layerP, op.get(), &weightVec); | |||||
| if (status != RET_OK) { | |||||
| auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); | |||||
| if (status_node != RET_OK) { | |||||
| interrupt = true; | |||||
| MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; | MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; | ||||
| return status; | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| continue; | |||||
| } | } | ||||
| status_node = SetOpInputIdx(layer, op.get(), tensorCache); | |||||
| if (status_node != RET_OK) { | |||||
| MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; | |||||
| status = (status == RET_OK ? status_node : status); | |||||
| } | |||||
| SetWeightTensor(weightVec, op.get(), tensorCache); | SetWeightTensor(weightVec, op.get(), tensorCache); | ||||
| status = SetOpOutputIdx(layer, op.get(), tensorCache); | |||||
| if (status != RET_OK) { | |||||
| status_node = SetOpOutputIdx(layer, op.get(), tensorCache); | |||||
| if (status_node != RET_OK) { | |||||
| interrupt = true; | |||||
| MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; | MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; | ||||
| return status; | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| continue; | |||||
| } | } | ||||
| // op->fmkType = FmkType_CAFFE; | // op->fmkType = FmkType_CAFFE; | ||||
| subGraphDef->nodes.emplace_back(move(op)); | subGraphDef->nodes.emplace_back(move(op)); | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| return status; | |||||
| } | } | ||||
| STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { | STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { | ||||
| @@ -249,6 +249,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, | schema::CNodeT *dst_op, schema::TensorT *dst_tensor, | ||||
| TensorCache *tensor_cache, const QuantType &quantType) { | TensorCache *tensor_cache, const QuantType &quantType) { | ||||
| // change op_type() to name(), that is unique | // change op_type() to name(), that is unique | ||||
| static bool interrupt = false; | |||||
| dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); | dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); | ||||
| dst_op->quantType = quantType; | dst_op->quantType = quantType; | ||||
| // dst_op->fmkType = FmkType_ONNX; | // dst_op->fmkType = FmkType_ONNX; | ||||
| @@ -256,15 +257,25 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||||
| << onnx_node.input_size(); | << onnx_node.input_size(); | ||||
| // get the real op type | // get the real op type | ||||
| SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); | SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); | ||||
| auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op); | |||||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); | |||||
| if (node_parser == nullptr || interrupt) { | |||||
| interrupt = true; | |||||
| if (node_parser == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); | |||||
| } | |||||
| return RET_NOT_FIND_OP; | |||||
| } | |||||
| auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "parser onnx node attr failed"; | |||||
| interrupt = true; | |||||
| MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; | |||||
| return status; | return status; | ||||
| } | } | ||||
| // set op input index | // set op input index | ||||
| std::vector<string> node_inputs; | std::vector<string> node_inputs; | ||||
| (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); | (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); | ||||
| if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { | if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { | ||||
| interrupt = true; | |||||
| MS_LOG(ERROR) << "SetOpInputIndex failed"; | MS_LOG(ERROR) << "SetOpInputIndex failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -273,6 +284,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||||
| (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); | (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); | ||||
| if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { | if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { | ||||
| interrupt = true; | |||||
| MS_LOG(ERROR) << "SetOpOutputIndex failed"; | MS_LOG(ERROR) << "SetOpOutputIndex failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -340,8 +352,7 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co | |||||
| const string &onnx_op_type, schema::CNodeT *dst_op) { | const string &onnx_op_type, schema::CNodeT *dst_op) { | ||||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); | auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); | ||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| MS_LOG(ERROR) << "not find " << onnx_op_type << ", node parser is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| return node_parser->Parse(onnx_graph, onnx_node, dst_op); | return node_parser->Parse(onnx_graph, onnx_node, dst_op); | ||||
| } | } | ||||
| @@ -503,32 +514,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con | |||||
| } | } | ||||
| // init op node input/output tensor, and dst_op attr | // init op node input/output tensor, and dst_op attr | ||||
| for (const auto &onnx_node : onnx_graph.node()) { | for (const auto &onnx_node : onnx_graph.node()) { | ||||
| int status_node = RET_OK; | |||||
| if (onnx_node.op_type() == "Constant") { | if (onnx_node.op_type() == "Constant") { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (onnx_node.op_type() == "Gemm") { | if (onnx_node.op_type() == "Gemm") { | ||||
| ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); | |||||
| if (status == RET_OK) { | |||||
| ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); | |||||
| } | |||||
| continue; | continue; | ||||
| } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { | } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { | ||||
| status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| if (status == RET_OK) { | |||||
| status_node = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); | |||||
| if (status_node != RET_OK) { | |||||
| MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node; | |||||
| status = (status == RET_OK ? status_node : status); | |||||
| } | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>(); | std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>(); | ||||
| std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>(); | std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>(); | ||||
| status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType); | |||||
| if (status_node != RET_OK) { | |||||
| status = (status == RET_OK ? status_node : status); | |||||
| continue; | |||||
| } | } | ||||
| dst_graph->nodes.emplace_back(std::move(dst_op)); | dst_graph->nodes.emplace_back(std::move(dst_op)); | ||||
| } | } | ||||
| if (status != RET_OK) { | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| for (auto &tensor : tensor_cache.GetCachedTensor()) { | |||||
| delete tensor; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| SetAllTensors(tensor_cache, dst_graph.get()); | SetAllTensors(tensor_cache, dst_graph.get()); | ||||
| dst_graph->name = GetModelName(modelFile); | dst_graph->name = GetModelName(modelFile); | ||||
| return dst_graph.release(); | return dst_graph.release(); | ||||
| @@ -300,6 +300,15 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Floor; | op->primitive->value.type = schema::PrimitiveType_Floor; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else if (std::strcmp(node_name, "Neg") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteNegParser"; | |||||
| auto attr = std::make_unique<schema::NegT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Neg; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | } | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | ||||
| @@ -415,6 +424,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | ||||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | ||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | ||||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); | |||||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | ||||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | ||||
| @@ -157,6 +157,11 @@ class TfliteFloorParser : public TfliteSingleInputOpParser { | |||||
| TfliteFloorParser() : TfliteSingleInputOpParser() {} | TfliteFloorParser() : TfliteSingleInputOpParser() {} | ||||
| }; | }; | ||||
| class TfliteNegParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteNegParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCompareOpParser : public TfliteNodeParser { | class TfliteCompareOpParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -98,6 +98,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | ||||
| const QuantType &quant_type, schema::MetaGraphT *sub_graph) { | const QuantType &quant_type, schema::MetaGraphT *sub_graph) { | ||||
| int idx = 0; | int idx = 0; | ||||
| int status = RET_OK; | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | for (const auto &tflite_op : tflite_subgraph->operators) { | ||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | ||||
| auto op_type = GetMSOpType(tflite_op_type); | auto op_type = GetMSOpType(tflite_op_type); | ||||
| @@ -114,21 +115,24 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | ||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str(); | |||||
| return RET_NOT_FIND_OP; | |||||
| } | |||||
| int status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, | |||||
| &tensorsFormat, &tensorsIdMap); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | |||||
| return status; | |||||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| continue; | |||||
| } | } | ||||
| if (status == RET_OK) { | |||||
| status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, | |||||
| &tensorsFormat, &tensorsIdMap); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | |||||
| continue; | |||||
| } | |||||
| sub_graph->nodes.emplace_back(op.release()); | |||||
| opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | |||||
| tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | |||||
| sub_graph->nodes.emplace_back(op.release()); | |||||
| opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | |||||
| tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | |||||
| return status; | |||||
| } | } | ||||
| STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | ||||
| @@ -162,8 +166,8 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> | |||||
| if (isConst) { | if (isConst) { | ||||
| int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); | int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "obtain const tensor failed"; | |||||
| return status; | |||||
| MS_LOG(ERROR) << "obtain const tensor failed"; | |||||
| return status; | |||||
| } | } | ||||
| } | } | ||||
| // set tensor attr | // set tensor attr | ||||
| @@ -118,6 +118,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{ | |||||
| {tflite::BuiltinOperator_UNPACK, "Unstack"}, | {tflite::BuiltinOperator_UNPACK, "Unstack"}, | ||||
| {tflite::BuiltinOperator_CUSTOM, "Custom"}, | {tflite::BuiltinOperator_CUSTOM, "Custom"}, | ||||
| {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, | {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, | ||||
| {tflite::BuiltinOperator_NEG, "Neg"}, | |||||
| }; | }; | ||||
| std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{ | std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{ | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include "backend/optimizer/common/pattern_engine.h" | #include "backend/optimizer/common/pattern_engine.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "tools/converter/return_code.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>; | using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -73,7 +73,7 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||||
| size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); | size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); | ||||
| ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); | |||||
| ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); | |||||
| enum kTransFilterType { | enum kTransFilterType { | ||||
| kKCHW2HWCK, // 0 | kKCHW2HWCK, // 0 | ||||
| @@ -105,11 +105,11 @@ STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, | |||||
| STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | ||||
| int32_t filterH, int32_t filterW); | int32_t filterH, int32_t filterW); | ||||
| template<typename T> | |||||
| template <typename T> | |||||
| static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | ||||
| int32_t filterH, int32_t filterW); | int32_t filterH, int32_t filterW); | ||||
| template<typename T> | |||||
| template <typename T> | |||||
| static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); | static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); | ||||
| STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); | STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); | ||||
| @@ -18,7 +18,7 @@ | |||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | ||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| #include "tools/converter/return_code.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||