| @@ -228,6 +228,7 @@ class PrimitiveC { | |||
| bool infer_flag_ = true; | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| }; | |||
| using PrimitiveCPtr = std::shared_ptr<PrimitiveC>; | |||
| typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); | |||
| #endif | |||
| typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive); | |||
| @@ -203,6 +203,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| @@ -14,15 +14,15 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <utility> | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include <utility> | |||
| #include "schema/model_generated.h" | |||
| #include "ir/dtype.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int AnfImporter::Import(const schema::QuantType &quantType) { | |||
| int AnfImporter::Import(const converter::Flags *flag) { | |||
| auto ret = ConverterConstTensor(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; | |||
| @@ -22,6 +22,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "base/base.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore::lite { | |||
| class AnfImporter { | |||
| @@ -30,7 +31,7 @@ class AnfImporter { | |||
| virtual ~AnfImporter() = default; | |||
| virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE); | |||
| virtual int Import(const converter::Flags *flag = nullptr); | |||
| virtual FuncGraphPtr GetResult() = 0; | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/anf_importer/import_from_protobuf.h" | |||
| #include "tools/anf_importer/import_from_mindir.h" | |||
| #include <unistd.h> | |||
| #include <map> | |||
| #include <memory> | |||
| @@ -36,6 +36,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/protobuf_utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "load_mindir/load_model.h" | |||
| using string = std::string; | |||
| using int32 = int32_t; | |||
| @@ -199,8 +200,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | |||
| int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | |||
| const onnx::ValueInfoProto &value_proto) { | |||
| int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node, | |||
| const onnx::ValueInfoProto &value_proto) { | |||
| if (node == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| @@ -274,8 +275,8 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto) { | |||
| int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto) { | |||
| if (outputFuncGraph == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| @@ -303,8 +304,8 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output | |||
| return status; | |||
| } | |||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -317,7 +318,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim | |||
| return true; | |||
| } | |||
| ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { | |||
| ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| switch (attr_tensor_type) { | |||
| case onnx::TensorProto_DataType_STRING: { | |||
| @@ -347,8 +348,8 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor | |||
| } | |||
| } | |||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -405,7 +406,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr | |||
| return ret == EOK; | |||
| } | |||
| bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||
| bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -460,8 +461,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con | |||
| return true; | |||
| } | |||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| std::vector<int> shape; | |||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||
| @@ -501,8 +502,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val | |||
| return true; | |||
| } | |||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||
| @@ -515,8 +516,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value | |||
| return true; | |||
| } | |||
| bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name, | |||
| const onnx::AttributeProto &attr_proto) { | |||
| bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name, | |||
| const onnx::AttributeProto &attr_proto) { | |||
| if (!attr_proto.has_ref_attr_name()) { | |||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||
| return false; | |||
| @@ -572,7 +573,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_ | |||
| return true; | |||
| } | |||
| bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||
| bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||
| const std::string &value_node_name = node_proto.output(0); | |||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | |||
| if (!attr_proto.has_ref_attr_name()) { | |||
| @@ -582,7 +583,7 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto & | |||
| return GetAttrValueForValueNode(value_node_name, attr_proto); | |||
| } | |||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProtobuf::GetAbstractForCNode( | |||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromMindir::GetAbstractForCNode( | |||
| const onnx::AttributeProto &attr_proto) { | |||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; | |||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | |||
| @@ -601,9 +602,9 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt | |||
| return kv; | |||
| } | |||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::NodeProto &node_proto, | |||
| const schema::QuantType &quantType) { | |||
| CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::NodeProto &node_proto, | |||
| const schema::QuantType &quantType) { | |||
| static bool interrupt = false; | |||
| if (outputFuncGraph == nullptr) { | |||
| MS_LOG(ERROR) << "output funcgraph is nullptr"; | |||
| @@ -685,8 +686,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||
| return cnode_ptr; | |||
| } | |||
| bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||
| bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||
| if (outputFuncGraph == nullptr || cnode_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "output funcgraph or cnode is nullptr"; | |||
| return false; | |||
| @@ -765,9 +766,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||
| return true; | |||
| } | |||
| int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| if (outputFuncGraph == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| @@ -809,8 +809,8 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG | |||
| return status; | |||
| } | |||
| int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| if (outputFuncGraph == nullptr) { | |||
| MS_LOG(ERROR) << "fundgraph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| @@ -833,7 +833,7 @@ int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | |||
| } | |||
| int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||
| int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||
| if (!model_proto.has_producer_name()) { | |||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | |||
| return RET_GRAPH_FILE_ERR; | |||
| @@ -854,7 +854,17 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||
| int AnfImporterFromMindir::Import(const converter::Flags *flag) { | |||
| onnx_model_ = ReadOnnxFromBinary(flag->modelFile); | |||
| if (onnx_model_ == nullptr) { | |||
| MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; | |||
| func_graph_ = LoadMindIR(flag->modelFile); | |||
| if (func_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file."; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | |||
| if (dstGraph == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | |||
| @@ -865,10 +875,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | |||
| return status; | |||
| } | |||
| if (onnx_model_ == nullptr) { | |||
| MS_LOG(ERROR) << "onnx_model_ is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto quantType = flag->quantType; | |||
| const onnx::GraphProto &graphBuild = onnx_model_->graph(); | |||
| status = BuildFuncGraph(dstGraph, graphBuild, quantType); | |||
| if (status != RET_OK) { | |||
| @@ -881,25 +888,22 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||
| return RET_OK; | |||
| } | |||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | |||
| onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) { | |||
| auto onnx_model = new (std::nothrow) onnx::ModelProto; | |||
| if (onnx_model == nullptr) { | |||
| MS_LOG(ERROR) << "New onnx ModelProto failed!"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| if (RET_OK != ValidateFileStr(model_path, ".mindir")) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID); | |||
| return nullptr; | |||
| } | |||
| if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { | |||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||
| MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model"; | |||
| return nullptr; | |||
| } | |||
| return onnx_model; | |||
| } | |||
| FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; } | |||
| FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; } | |||
| } // namespace mindspore::lite | |||
| @@ -29,18 +29,17 @@ | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore::lite { | |||
| class AnfImporterFromProtobuf : public AnfImporter { | |||
| class AnfImporterFromMindir : public AnfImporter { | |||
| public: | |||
| AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) | |||
| : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} | |||
| AnfImporterFromMindir() = default; | |||
| ~AnfImporterFromProtobuf() override = default; | |||
| ~AnfImporterFromMindir() override { delete onnx_model_; } | |||
| static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); | |||
| FuncGraphPtr GetResult() override; | |||
| int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override; | |||
| int Import(const converter::Flags *flag) override; | |||
| private: | |||
| int ConverterConstTensor() override { return RET_ERROR; }; | |||
| @@ -57,6 +57,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/identity_remove_pass.cc | |||
| ../optimizer/graph/infershape_pass.cc | |||
| ../optimizer/graph/slice_prepose_pass.cc | |||
| ../optimizer/graph/mindir_adjust_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -29,6 +29,7 @@ | |||
| #include "tools/optimizer/fusion/batchmatmul_fusion.h" | |||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | |||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| @@ -61,6 +62,18 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | |||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | |||
| // mindir pre adjustment | |||
| if (config->fmk == converter::FmkType_MS) { | |||
| auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); | |||
| mindir_adjust_pass->SetFmkType(config->fmk); | |||
| mindir_adjust_pass->SetQuantType(config->quantType); | |||
| if (!mindir_adjust_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "mindir adjust failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| } | |||
| // for now - trainning is not supporting fuse operations | |||
| if (!config->trainModel) { | |||
| // remove quantdtype when awaretraining | |||
| @@ -30,7 +30,7 @@ | |||
| #include "parser/onnx/onnx_converter.h" | |||
| #include "parser/tf/tf_converter.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/anf_importer/import_from_protobuf.h" | |||
| #include "tools/anf_importer/import_from_mindir.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| @@ -54,9 +54,7 @@ Converter::~Converter() { | |||
| class MindsporeImporter : public Converter { | |||
| public: | |||
| MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) { | |||
| modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph)); | |||
| } | |||
| MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); } | |||
| ~MindsporeImporter() override = default; | |||
| }; | |||
| @@ -66,7 +64,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| FuncGraphPtr graph = nullptr; | |||
| if (flag->fmk == converter::FmkType_MS) { | |||
| MS_ASSERT(nullptr != modelImporter); | |||
| int status = modelImporter->Import(flag->quantType); | |||
| int status = modelImporter->Import(flag); | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| graph = modelImporter->GetResult(); | |||
| } else { | |||
| @@ -127,15 +125,8 @@ int RunConverter(int argc, const char **argv) { | |||
| MetaGraphT *fb_graph = nullptr; | |||
| switch (flags->fmk) { | |||
| case FmkType::FmkType_MS: { | |||
| auto graph = std::make_shared<FuncGraph>(); | |||
| auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile); | |||
| if (onnx_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Read MINDIR from binary return nullptr"; | |||
| break; | |||
| } | |||
| MindsporeImporter mindsporeImporter(onnx_graph, graph); | |||
| MindsporeImporter mindsporeImporter; | |||
| fb_graph = mindsporeImporter.Convert(flags.get()); | |||
| delete onnx_graph; | |||
| break; | |||
| } | |||
| case FmkType::FmkType_CAFFE: { | |||
| @@ -26,22 +26,6 @@ namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr auto kAnfPrimitiveIndex = 0; | |||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { | |||
| if (node == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); | |||
| } | |||
| bool IsRealKernel(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| @@ -136,6 +120,22 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive | |||
| } | |||
| } // namespace | |||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { | |||
| if (node == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); | |||
| } | |||
| bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||
| @@ -34,6 +34,8 @@ using mindspore::lite::RET_OK; | |||
| using mindspore::lite::STATUS; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); | |||
| bool IsRealCNodeKernel(const AnfNodePtr &node); | |||
| bool IsGraphKernel(const AnfNodePtr &node); | |||
| @@ -0,0 +1,147 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/tensor.h" | |||
| using mindspore::lite::PrimitiveC; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) { | |||
| if (!utils::isa<ParameterPtr>(anf_node)) { | |||
| MS_LOG(INFO) << "only parameter node need to convert tensor."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto param_node = anf_node->cast<ParameterPtr>(); | |||
| if (!param_node->has_default()) { | |||
| MS_LOG(INFO) << "this is graph input, don't need to convert."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| if (utils::isa<ParamValueLitePtr>(param_node->default_param())) { | |||
| MS_LOG(INFO) << "the tensor has been a paramvalueLite."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "fail to new a ParamValueLite."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| param_node->set_name(param_node->debug_info()->name()); | |||
| auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>(); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "the node is not a tensor::TensorPtr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| param_value->set_tensor_size(tensor_info->Size()); | |||
| param_value->set_tensor_type(tensor_info->data_type()); | |||
| auto tensor_shape = tensor_info->shape(); | |||
| std::vector<int> shape; | |||
| std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape), | |||
| [](int64_t value) { return static_cast<int>(value); }); | |||
| param_value->set_tensor_shape(shape); | |||
| auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr."; | |||
| return lite::RET_MEMORY_FAILED; | |||
| } | |||
| auto *tensor_data_buf = tensor->MutableData(); | |||
| if (tensor_data_buf == nullptr) { | |||
| MS_LOG(ERROR) << "malloc tensor data failed."; | |||
| delete tensor; | |||
| return lite::RET_MEMORY_FAILED; | |||
| } | |||
| if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s error."; | |||
| delete tensor; | |||
| return lite::RET_MEMORY_FAILED; | |||
| } | |||
| tensor->set_data(nullptr); | |||
| param_value->set_tensor_addr(tensor_data_buf); | |||
| param_node->set_default_param(param_value); | |||
| delete tensor; | |||
| return lite::RET_OK; | |||
| } | |||
| int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) { | |||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||
| MS_LOG(INFO) << "only cnode need to convert primitive."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode->inputs().empty() || cnode->input(0) == nullptr) { | |||
| MS_LOG(ERROR) << "the cnode is invalid."; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| auto value_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr || value_node->value() == nullptr) { | |||
| MS_LOG(ERROR) << "value node is invalid."; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (utils::isa<PrimitiveCPtr>(value_node->value())) { | |||
| MS_LOG(INFO) << "the value has been primitiveC."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto primitive = value_node->value()->cast<PrimitivePtr>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "the value is not primitive."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| inputs.erase(inputs.begin()); | |||
| if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { | |||
| auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| value_node->set_value(primitive_c); | |||
| } else { | |||
| auto primitiveT = std::make_unique<schema::PrimitiveT>(); | |||
| primitiveT->value.type = (CheckPrimitiveType(anf_node, prim::kPrimReturn) ? schema::PrimitiveType_Return | |||
| : schema::PrimitiveType_MakeTuple); | |||
| value_node->set_value(std::make_shared<PrimitiveC>(primitiveT.release())); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool MindirAdjustPass::Run(const FuncGraphPtr &graph) { | |||
| if (this->fmk_type_ != lite::converter::FmkType_MS) { | |||
| MS_LOG(INFO) << "The framework type of model should be mindir."; | |||
| return lite::RET_OK; | |||
| } | |||
| MS_ASSERT(graph != nullptr); | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| int status = lite::RET_OK; | |||
| for (auto &node : node_list) { | |||
| if (utils::isa<ParameterPtr>(node)) { | |||
| status = ParameterNodeConvert(node); | |||
| } else if (utils::isa<CNodePtr>(node)) { | |||
| status = PrimitiveConvert(node); | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "src/param_value_lite.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| using mindspore::schema::QuantType; | |||
| namespace mindspore::opt { | |||
| class MindirAdjustPass : public Pass { | |||
| public: | |||
| MindirAdjustPass() : Pass("mindir_adjust_pass") {} | |||
| ~MindirAdjustPass() override = default; | |||
| void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } | |||
| void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } | |||
| int ParameterNodeConvert(AnfNodePtr anf_node); | |||
| int PrimitiveConvert(AnfNodePtr anf_node); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| protected: | |||
| QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; | |||
| FmkType fmk_type_ = FmkType::FmkType_MS; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||