| @@ -228,6 +228,7 @@ class PrimitiveC { | |||||
| bool infer_flag_ = true; | bool infer_flag_ = true; | ||||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | ||||
| }; | }; | ||||
| using PrimitiveCPtr = std::shared_ptr<PrimitiveC>; | |||||
| typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); | typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); | ||||
| #endif | #endif | ||||
| typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive); | typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive); | ||||
| @@ -203,6 +203,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| ### train | ### train | ||||
| @@ -14,15 +14,15 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <utility> | |||||
| #include "tools/anf_importer/anf_importer.h" | #include "tools/anf_importer/anf_importer.h" | ||||
| #include <utility> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "ir/dtype.h" | #include "ir/dtype.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| int AnfImporter::Import(const schema::QuantType &quantType) { | |||||
| int AnfImporter::Import(const converter::Flags *flag) { | |||||
| auto ret = ConverterConstTensor(); | auto ret = ConverterConstTensor(); | ||||
| if (RET_OK != ret) { | if (RET_OK != ret) { | ||||
| MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; | MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "tools/converter/converter_flags.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporter { | class AnfImporter { | ||||
| @@ -30,7 +31,7 @@ class AnfImporter { | |||||
| virtual ~AnfImporter() = default; | virtual ~AnfImporter() = default; | ||||
| virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE); | |||||
| virtual int Import(const converter::Flags *flag = nullptr); | |||||
| virtual FuncGraphPtr GetResult() = 0; | virtual FuncGraphPtr GetResult() = 0; | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/anf_importer/import_from_protobuf.h" | |||||
| #include "tools/anf_importer/import_from_mindir.h" | |||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -36,6 +36,7 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "tools/common/protobuf_utils.h" | #include "tools/common/protobuf_utils.h" | ||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "load_mindir/load_model.h" | |||||
| using string = std::string; | using string = std::string; | ||||
| using int32 = int32_t; | using int32 = int32_t; | ||||
| @@ -199,8 +200,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | ||||
| int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const onnx::ValueInfoProto &value_proto) { | |||||
| int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const onnx::ValueInfoProto &value_proto) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| @@ -274,8 +275,8 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto) { | |||||
| int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto) { | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| @@ -303,8 +304,8 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output | |||||
| return status; | return status; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -317,7 +318,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim | |||||
| return true; | return true; | ||||
| } | } | ||||
| ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { | |||||
| ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| switch (attr_tensor_type) { | switch (attr_tensor_type) { | ||||
| case onnx::TensorProto_DataType_STRING: { | case onnx::TensorProto_DataType_STRING: { | ||||
| @@ -347,8 +348,8 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor | |||||
| } | } | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -405,7 +406,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr | |||||
| return ret == EOK; | return ret == EOK; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -460,8 +461,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| std::vector<int> shape; | std::vector<int> shape; | ||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | for (int i = 0; i < attr_tensor.dims_size(); ++i) { | ||||
| @@ -501,8 +502,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | ||||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | ||||
| @@ -515,8 +516,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name, | |||||
| const onnx::AttributeProto &attr_proto) { | |||||
| bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name, | |||||
| const onnx::AttributeProto &attr_proto) { | |||||
| if (!attr_proto.has_ref_attr_name()) { | if (!attr_proto.has_ref_attr_name()) { | ||||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | ||||
| return false; | return false; | ||||
| @@ -572,7 +573,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_ | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||||
| bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||||
| const std::string &value_node_name = node_proto.output(0); | const std::string &value_node_name = node_proto.output(0); | ||||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | ||||
| if (!attr_proto.has_ref_attr_name()) { | if (!attr_proto.has_ref_attr_name()) { | ||||
| @@ -582,7 +583,7 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto & | |||||
| return GetAttrValueForValueNode(value_node_name, attr_proto); | return GetAttrValueForValueNode(value_node_name, attr_proto); | ||||
| } | } | ||||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProtobuf::GetAbstractForCNode( | |||||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromMindir::GetAbstractForCNode( | |||||
| const onnx::AttributeProto &attr_proto) { | const onnx::AttributeProto &attr_proto) { | ||||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; | std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; | ||||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | for (int i = 0; i < attr_proto.tensors_size(); i++) { | ||||
| @@ -601,9 +602,9 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt | |||||
| return kv; | return kv; | ||||
| } | } | ||||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType) { | |||||
| CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType) { | |||||
| static bool interrupt = false; | static bool interrupt = false; | ||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "output funcgraph is nullptr"; | MS_LOG(ERROR) << "output funcgraph is nullptr"; | ||||
| @@ -685,8 +686,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||||
| return cnode_ptr; | return cnode_ptr; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||||
| bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||||
| if (outputFuncGraph == nullptr || cnode_ptr == nullptr) { | if (outputFuncGraph == nullptr || cnode_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "output funcgraph or cnode is nullptr"; | MS_LOG(ERROR) << "output funcgraph or cnode is nullptr"; | ||||
| return false; | return false; | ||||
| @@ -765,9 +766,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| return true; | return true; | ||||
| } | } | ||||
| int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | MS_LOG(ERROR) << "funcgraph is nullptr"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -809,8 +809,8 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG | |||||
| return status; | return status; | ||||
| } | } | ||||
| int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| if (outputFuncGraph == nullptr) { | if (outputFuncGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "fundgraph is nullptr"; | MS_LOG(ERROR) << "fundgraph is nullptr"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -833,7 +833,7 @@ int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | ||||
| } | } | ||||
| int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||||
| int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||||
| if (!model_proto.has_producer_name()) { | if (!model_proto.has_producer_name()) { | ||||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | ||||
| return RET_GRAPH_FILE_ERR; | return RET_GRAPH_FILE_ERR; | ||||
| @@ -854,7 +854,17 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||||
| int AnfImporterFromMindir::Import(const converter::Flags *flag) { | |||||
| onnx_model_ = ReadOnnxFromBinary(flag->modelFile); | |||||
| if (onnx_model_ == nullptr) { | |||||
| MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; | |||||
| func_graph_ = LoadMindIR(flag->modelFile); | |||||
| if (func_graph_ == nullptr) { | |||||
| MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file."; | |||||
| return RET_GRAPH_FILE_ERR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | ||||
| if (dstGraph == nullptr) { | if (dstGraph == nullptr) { | ||||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | MS_LOG(ERROR) << "funcgraph is nullptr"; | ||||
| @@ -865,10 +875,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | ||||
| return status; | return status; | ||||
| } | } | ||||
| if (onnx_model_ == nullptr) { | |||||
| MS_LOG(ERROR) << "onnx_model_ is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto quantType = flag->quantType; | |||||
| const onnx::GraphProto &graphBuild = onnx_model_->graph(); | const onnx::GraphProto &graphBuild = onnx_model_->graph(); | ||||
| status = BuildFuncGraph(dstGraph, graphBuild, quantType); | status = BuildFuncGraph(dstGraph, graphBuild, quantType); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -881,25 +888,22 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | |||||
| onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) { | |||||
| auto onnx_model = new (std::nothrow) onnx::ModelProto; | auto onnx_model = new (std::nothrow) onnx::ModelProto; | ||||
| if (onnx_model == nullptr) { | if (onnx_model == nullptr) { | ||||
| MS_LOG(ERROR) << "New onnx ModelProto failed!"; | MS_LOG(ERROR) << "New onnx ModelProto failed!"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (RET_OK != ValidateFileStr(model_path, ".mindir")) { | if (RET_OK != ValidateFileStr(model_path, ".mindir")) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; | MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { | if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { | ||||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return onnx_model; | return onnx_model; | ||||
| } | } | ||||
| FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; } | |||||
| FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -29,18 +29,17 @@ | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporterFromProtobuf : public AnfImporter { | |||||
| class AnfImporterFromMindir : public AnfImporter { | |||||
| public: | public: | ||||
| AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) | |||||
| : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} | |||||
| AnfImporterFromMindir() = default; | |||||
| ~AnfImporterFromProtobuf() override = default; | |||||
| ~AnfImporterFromMindir() override { delete onnx_model_; } | |||||
| static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); | static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); | ||||
| FuncGraphPtr GetResult() override; | FuncGraphPtr GetResult() override; | ||||
| int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override; | |||||
| int Import(const converter::Flags *flag) override; | |||||
| private: | private: | ||||
| int ConverterConstTensor() override { return RET_ERROR; }; | int ConverterConstTensor() override { return RET_ERROR; }; | ||||
| @@ -57,6 +57,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/graph/identity_remove_pass.cc | ../optimizer/graph/identity_remove_pass.cc | ||||
| ../optimizer/graph/infershape_pass.cc | ../optimizer/graph/infershape_pass.cc | ||||
| ../optimizer/graph/slice_prepose_pass.cc | ../optimizer/graph/slice_prepose_pass.cc | ||||
| ../optimizer/graph/mindir_adjust_pass.cc | |||||
| ) | ) | ||||
| add_subdirectory(../anf_importer anf_importer) | add_subdirectory(../anf_importer anf_importer) | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "tools/optimizer/fusion/batchmatmul_fusion.h" | #include "tools/optimizer/fusion/batchmatmul_fusion.h" | ||||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | ||||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | #include "tools/optimizer/fusion/conv_conv_fusion.h" | ||||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||||
| #include "tools/optimizer/graph/identity_remove_pass.h" | #include "tools/optimizer/graph/identity_remove_pass.h" | ||||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | ||||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | #include "tools/optimizer/graph/weight_format_transform_pass.h" | ||||
| @@ -61,6 +62,18 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | ||||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | ||||
| // mindir pre adjustment | |||||
| if (config->fmk == converter::FmkType_MS) { | |||||
| auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); | |||||
| mindir_adjust_pass->SetFmkType(config->fmk); | |||||
| mindir_adjust_pass->SetQuantType(config->quantType); | |||||
| if (!mindir_adjust_pass->Run(old_graph)) { | |||||
| MS_LOG(ERROR) << "mindir adjust failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| // for now - trainning is not supporting fuse operations | // for now - trainning is not supporting fuse operations | ||||
| if (!config->trainModel) { | if (!config->trainModel) { | ||||
| // remove quantdtype when awaretraining | // remove quantdtype when awaretraining | ||||
| @@ -30,7 +30,7 @@ | |||||
| #include "parser/onnx/onnx_converter.h" | #include "parser/onnx/onnx_converter.h" | ||||
| #include "parser/tf/tf_converter.h" | #include "parser/tf/tf_converter.h" | ||||
| #include "tools/anf_exporter/anf_exporter.h" | #include "tools/anf_exporter/anf_exporter.h" | ||||
| #include "tools/anf_importer/import_from_protobuf.h" | |||||
| #include "tools/anf_importer/import_from_mindir.h" | |||||
| #include "proto/onnx.pb.h" | #include "proto/onnx.pb.h" | ||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | #include "tools/converter/quantizer/post_training_quantizer.h" | ||||
| #include "tools/converter/quantizer/quant_cast.h" | #include "tools/converter/quantizer/quant_cast.h" | ||||
| @@ -54,9 +54,7 @@ Converter::~Converter() { | |||||
| class MindsporeImporter : public Converter { | class MindsporeImporter : public Converter { | ||||
| public: | public: | ||||
| MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) { | |||||
| modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph)); | |||||
| } | |||||
| MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); } | |||||
| ~MindsporeImporter() override = default; | ~MindsporeImporter() override = default; | ||||
| }; | }; | ||||
| @@ -66,7 +64,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| FuncGraphPtr graph = nullptr; | FuncGraphPtr graph = nullptr; | ||||
| if (flag->fmk == converter::FmkType_MS) { | if (flag->fmk == converter::FmkType_MS) { | ||||
| MS_ASSERT(nullptr != modelImporter); | MS_ASSERT(nullptr != modelImporter); | ||||
| int status = modelImporter->Import(flag->quantType); | |||||
| int status = modelImporter->Import(flag); | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| graph = modelImporter->GetResult(); | graph = modelImporter->GetResult(); | ||||
| } else { | } else { | ||||
| @@ -127,15 +125,8 @@ int RunConverter(int argc, const char **argv) { | |||||
| MetaGraphT *fb_graph = nullptr; | MetaGraphT *fb_graph = nullptr; | ||||
| switch (flags->fmk) { | switch (flags->fmk) { | ||||
| case FmkType::FmkType_MS: { | case FmkType::FmkType_MS: { | ||||
| auto graph = std::make_shared<FuncGraph>(); | |||||
| auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile); | |||||
| if (onnx_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Read MINDIR from binary return nullptr"; | |||||
| break; | |||||
| } | |||||
| MindsporeImporter mindsporeImporter(onnx_graph, graph); | |||||
| MindsporeImporter mindsporeImporter; | |||||
| fb_graph = mindsporeImporter.Convert(flags.get()); | fb_graph = mindsporeImporter.Convert(flags.get()); | ||||
| delete onnx_graph; | |||||
| break; | break; | ||||
| } | } | ||||
| case FmkType::FmkType_CAFFE: { | case FmkType::FmkType_CAFFE: { | ||||
| @@ -26,22 +26,6 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| constexpr auto kAnfPrimitiveIndex = 0; | constexpr auto kAnfPrimitiveIndex = 0; | ||||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { | |||||
| if (node == nullptr) { | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return false; | |||||
| } | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return false; | |||||
| } | |||||
| return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); | |||||
| } | |||||
| bool IsRealKernel(const AnfNodePtr &node) { | bool IsRealKernel(const AnfNodePtr &node) { | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | ||||
| @@ -136,6 +120,22 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { | |||||
| if (node == nullptr) { | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return false; | |||||
| } | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return false; | |||||
| } | |||||
| return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); | |||||
| } | |||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b) { | bool AnfEqual(const BaseRef &a, const BaseRef &b) { | ||||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | ||||
| auto a_node = utils::cast<AnfNodePtr>(a); | auto a_node = utils::cast<AnfNodePtr>(a); | ||||
| @@ -34,6 +34,8 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::lite::STATUS; | using mindspore::lite::STATUS; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); | |||||
| bool IsRealCNodeKernel(const AnfNodePtr &node); | bool IsRealCNodeKernel(const AnfNodePtr &node); | ||||
| bool IsGraphKernel(const AnfNodePtr &node); | bool IsGraphKernel(const AnfNodePtr &node); | ||||
| @@ -0,0 +1,147 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||||
| #include <algorithm> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "tools/converter/quantizer/quant_cast.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/tensor.h" | |||||
| using mindspore::lite::PrimitiveC; | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) { | |||||
| if (!utils::isa<ParameterPtr>(anf_node)) { | |||||
| MS_LOG(INFO) << "only parameter node need to convert tensor."; | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| auto param_node = anf_node->cast<ParameterPtr>(); | |||||
| if (!param_node->has_default()) { | |||||
| MS_LOG(INFO) << "this is graph input, don't need to convert."; | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| if (utils::isa<ParamValueLitePtr>(param_node->default_param())) { | |||||
| MS_LOG(INFO) << "the tensor has been a paramvalueLite."; | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "fail to new a ParamValueLite."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| param_node->set_name(param_node->debug_info()->name()); | |||||
| auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>(); | |||||
| if (tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "the node is not a tensor::TensorPtr."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| param_value->set_tensor_size(tensor_info->Size()); | |||||
| param_value->set_tensor_type(tensor_info->data_type()); | |||||
| auto tensor_shape = tensor_info->shape(); | |||||
| std::vector<int> shape; | |||||
| std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape), | |||||
| [](int64_t value) { return static_cast<int>(value); }); | |||||
| param_value->set_tensor_shape(shape); | |||||
| auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape); | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr."; | |||||
| return lite::RET_MEMORY_FAILED; | |||||
| } | |||||
| auto *tensor_data_buf = tensor->MutableData(); | |||||
| if (tensor_data_buf == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc tensor data failed."; | |||||
| delete tensor; | |||||
| return lite::RET_MEMORY_FAILED; | |||||
| } | |||||
| if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s error."; | |||||
| delete tensor; | |||||
| return lite::RET_MEMORY_FAILED; | |||||
| } | |||||
| tensor->set_data(nullptr); | |||||
| param_value->set_tensor_addr(tensor_data_buf); | |||||
| param_node->set_default_param(param_value); | |||||
| delete tensor; | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) { | |||||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||||
| MS_LOG(INFO) << "only cnode need to convert primitive."; | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| if (cnode->inputs().empty() || cnode->input(0) == nullptr) { | |||||
| MS_LOG(ERROR) << "the cnode is invalid."; | |||||
| return lite::RET_NULL_PTR; | |||||
| } | |||||
| auto value_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| if (value_node == nullptr || value_node->value() == nullptr) { | |||||
| MS_LOG(ERROR) << "value node is invalid."; | |||||
| return lite::RET_NULL_PTR; | |||||
| } | |||||
| if (utils::isa<PrimitiveCPtr>(value_node->value())) { | |||||
| MS_LOG(INFO) << "the value has been primitiveC."; | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| auto primitive = value_node->value()->cast<PrimitivePtr>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "the value is not primitive."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| auto inputs = cnode->inputs(); | |||||
| inputs.erase(inputs.begin()); | |||||
| if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { | |||||
| auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| value_node->set_value(primitive_c); | |||||
| } else { | |||||
| auto primitiveT = std::make_unique<schema::PrimitiveT>(); | |||||
| primitiveT->value.type = (CheckPrimitiveType(anf_node, prim::kPrimReturn) ? schema::PrimitiveType_Return | |||||
| : schema::PrimitiveType_MakeTuple); | |||||
| value_node->set_value(std::make_shared<PrimitiveC>(primitiveT.release())); | |||||
| } | |||||
| return lite::RET_OK; | |||||
| } | |||||
| bool MindirAdjustPass::Run(const FuncGraphPtr &graph) { | |||||
| if (this->fmk_type_ != lite::converter::FmkType_MS) { | |||||
| MS_LOG(INFO) << "The framework type of model should be mindir."; | |||||
| return lite::RET_OK; | |||||
| } | |||||
| MS_ASSERT(graph != nullptr); | |||||
| auto node_list = TopoSort(graph->get_return()); | |||||
| int status = lite::RET_OK; | |||||
| for (auto &node : node_list) { | |||||
| if (utils::isa<ParameterPtr>(node)) { | |||||
| status = ParameterNodeConvert(node); | |||||
| } else if (utils::isa<CNodePtr>(node)) { | |||||
| status = PrimitiveConvert(node); | |||||
| } | |||||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||||
| #include <string> | |||||
| #include "backend/optimizer/common/pass.h" | |||||
| #include "tools/converter/converter_flags.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| using mindspore::lite::converter::FmkType; | |||||
| using mindspore::schema::QuantType; | |||||
| namespace mindspore::opt { | |||||
| class MindirAdjustPass : public Pass { | |||||
| public: | |||||
| MindirAdjustPass() : Pass("mindir_adjust_pass") {} | |||||
| ~MindirAdjustPass() override = default; | |||||
| void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } | |||||
| void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } | |||||
| int ParameterNodeConvert(AnfNodePtr anf_node); | |||||
| int PrimitiveConvert(AnfNodePtr anf_node); | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| protected: | |||||
| QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; | |||||
| FmkType fmk_type_ = FmkType::FmkType_MS; | |||||
| }; | |||||
| } // namespace mindspore::opt | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||||