From: @hangangqiang Reviewed-by: @zhanghaibo5 Signed-off-by:tags/v1.1.1
| @@ -254,12 +254,12 @@ if (ENABLE_CONVERTER) | |||
| ${TEST_DIR}/st/converter_test.cc | |||
| ${TEST_DIR}/st/control_flow_test.cc | |||
| ${TEST_DIR}/st/sub_graph_test.cc | |||
| ${TEST_DIR}/common/import_from_meta_graphT.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/import_from_meta_graphT.cc | |||
| ) | |||
| endif() | |||
| @@ -301,6 +301,8 @@ endif() | |||
| if (PLATFORM_ARM) | |||
| target_link_libraries(lite-test log) | |||
| else() | |||
| target_link_libraries(lite-test ${SECUREC_LIBRARY} pthread) | |||
| endif() | |||
| if (SUPPORT_NPU) | |||
| @@ -310,7 +312,6 @@ endif () | |||
| if (ENABLE_CONVERTER) | |||
| add_dependencies(lite-test fbs_inner_src) | |||
| target_link_libraries(lite-test | |||
| anf_importer_mid | |||
| anf_exporter_mid | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| @@ -320,12 +321,10 @@ if (ENABLE_CONVERTER) | |||
| fusion_mid | |||
| quantizer_mid | |||
| proto_mid | |||
| pthread | |||
| mindspore::protobuf | |||
| mindspore::eigen | |||
| mindspore::json | |||
| -Wl,--whole-archive mindspore_core -Wl,--no-whole-archive | |||
| mindspore::glog | |||
| ${SECUREC_LIBRARY} | |||
| ) | |||
| endif() | |||
| @@ -0,0 +1,175 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "include/errorcode.h" | |||
| #include "test/common/import_from_meta_graphT.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore::lite { | |||
| AnfNodePtr AnfImporterFromMetaGraphT::GetNode(int tensor_id) { | |||
| auto n = nodes_.find(tensor_id); | |||
| if (n == nodes_.end()) { | |||
| return nullptr; | |||
| } | |||
| return n->second; | |||
| } | |||
| void AnfImporterFromMetaGraphT::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } | |||
| int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| MS_ASSERT(nullptr != meta_graph_); | |||
| MS_ASSERT(nullptr != func_graph_); | |||
| for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { | |||
| auto &tensor = meta_graph_->allTensors.at(i); | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (tensor->nodeType != schema::NodeType::NodeType_ValueNode) { | |||
| continue; | |||
| } | |||
| auto parameter = func_graph_->add_parameter(); | |||
| std::vector<int> shape(tensor->dims.size()); | |||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||
| auto type_ptr = TypeIdToType(type_id); | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != abstract_tensor); | |||
| parameter->set_abstract(abstract_tensor); | |||
| if (!tensor->name.empty()) { | |||
| parameter->set_name(tensor->name); | |||
| } else { | |||
| parameter->set_name("const-" + std::to_string(i)); | |||
| } | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| MS_ASSERT(nullptr != param_value); | |||
| param_value->set_tensor_shape(shape); | |||
| param_value->set_tensor_type(type_id); | |||
| param_value->set_format(tensor->format); | |||
| if (!tensor->data.empty()) { | |||
| auto size = tensor->data.size(); | |||
| char *tensor_data = new (std::nothrow) char[size]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "new char[] failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| auto ret = memcpy_s(tensor_data, size, tensor->data.data(), size); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| delete[] tensor_data; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| param_value->SetTensorData(tensor_data, size); | |||
| parameter->set_default_param(param_value); | |||
| } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == | |||
| meta_graph_->inputIndex.end()) { | |||
| parameter->set_default_param(param_value); | |||
| } | |||
| AddNode(i, parameter); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode) { | |||
| return nullptr; | |||
| } | |||
| abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTensor( | |||
| const std::unique_ptr<schema::TensorT> &tensor) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| std::vector<int> shape(tensor->dims.size()); | |||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||
| auto type_ptr = TypeIdToType(type_id); | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto ptr = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != ptr); | |||
| return ptr; | |||
| } | |||
| int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, | |||
| const CNodePtr &dst_cnode) { | |||
| return RET_ERROR; | |||
| } | |||
| int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| MS_ASSERT(nullptr != meta_graph_); | |||
| MS_ASSERT(nullptr != func_graph_); | |||
| for (const auto &cNode : meta_graph_->nodes) { | |||
| MS_ASSERT(nullptr != cNode); | |||
| auto anf_primitive = ConvertPrimitive(cNode); | |||
| if (anf_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "cannot obtain anf primitive"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs = {anf_primitive}; | |||
| for (int j : cNode->inputIndex) { | |||
| auto node = GetNode(j); | |||
| if (nullptr == node) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op_inputs.push_back(node); | |||
| } | |||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||
| MS_ASSERT(nullptr != new_cnode); | |||
| new_cnode->set_fullname_with_scope(cNode->name); | |||
| auto status = ConvertAbstract(cNode, new_cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertAbstract failed."; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromMetaGraphT::AddReturnCNode() { return RET_ERROR; } | |||
| FuncGraphPtr AnfImporterFromMetaGraphT::Fb2Anf(schema::MetaGraphT *meta_graph) { | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "meta_graph is null"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| AnfImporterFromMetaGraphT anfImporterFromMetaGraphT(meta_graph); | |||
| auto ret = anfImporterFromMetaGraphT.ConverterConstTensor(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; | |||
| return nullptr; | |||
| } | |||
| ret = anfImporterFromMetaGraphT.ConverterCNode(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "ConverterCNode failed " << ret; | |||
| return nullptr; | |||
| } | |||
| ret = anfImporterFromMetaGraphT.AddReturnCNode(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "AddReturnCNode failed " << ret; | |||
| return nullptr; | |||
| } | |||
| return anfImporterFromMetaGraphT.func_graph_; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -19,24 +19,26 @@ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore::lite { | |||
| class AnfImporterFromMetaGraphT : public AnfImporter { | |||
| class AnfImporterFromMetaGraphT { | |||
| public: | |||
| AnfImporterFromMetaGraphT(schema::MetaGraphT *meta_graph, FuncGraphPtr func_graph) | |||
| : meta_graph_(meta_graph), func_graph_(std::move(func_graph)) {} | |||
| virtual ~AnfImporterFromMetaGraphT() = default; | |||
| ~AnfImporterFromMetaGraphT() override = default; | |||
| FuncGraphPtr GetResult() override; | |||
| static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph); | |||
| private: | |||
| int ConverterConstTensor() override; | |||
| explicit AnfImporterFromMetaGraphT(schema::MetaGraphT *meta_graph) : meta_graph_(meta_graph) { | |||
| this->func_graph_ = std::make_shared<FuncGraph>(); | |||
| } | |||
| int ConverterConstTensor(); | |||
| int ConverterCNode() override; | |||
| int ConverterCNode(); | |||
| ValueNodePtr ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode); | |||
| @@ -44,9 +46,14 @@ class AnfImporterFromMetaGraphT : public AnfImporter { | |||
| int ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, const CNodePtr &dst_cnode); | |||
| int AddReturnCNode() override; | |||
| int AddReturnCNode(); | |||
| AnfNodePtr GetNode(int tensor_id); | |||
| void AddNode(int tensor_id, AnfNodePtr node); | |||
| private: | |||
| std::unordered_map<int, AnfNodePtr> nodes_; | |||
| schema::MetaGraphT *meta_graph_; | |||
| FuncGraphPtr func_graph_; | |||
| }; | |||
| @@ -17,18 +17,11 @@ | |||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | |||
| namespace mindspore { | |||
| schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) { | |||
| lite::TfliteModelParser parser; | |||
| meta_graph = parser.ParseToFb(model_path, weight_path, schema::QuantType_QUANT_NONE); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; | |||
| return nullptr; | |||
| } | |||
| return meta_graph; | |||
| schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const std::string &model_path, const std::string &weight_path) { | |||
| return nullptr; | |||
| } | |||
| void TestTfliteParser::TearDown() { free(meta_graph); } | |||
| @@ -21,12 +21,12 @@ | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "import_from_meta_graphT.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "test/common/import_from_meta_graphT.h" | |||
| namespace mindspore { | |||
| class ConstantFoldingFusionTest : public mindspore::CommonTest { | |||
| @@ -370,7 +370,7 @@ MetaGraphTptr BuildSplitGraph() { | |||
| } // namespace | |||
| TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_AddFusion, new schema::AddFusionT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -383,7 +383,7 @@ TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestMixedConstantFold) { | |||
| auto meta_graph = BuildMixGraph(); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -396,7 +396,7 @@ TEST_F(ConstantFoldingFusionTest, TestMixedConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestSubConstantFold) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_SubFusion, new schema::SubFusionT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -409,7 +409,7 @@ TEST_F(ConstantFoldingFusionTest, TestSubConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestMulConstantFold) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_MulFusion, new schema::MulFusionT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -423,7 +423,7 @@ TEST_F(ConstantFoldingFusionTest, TestMulConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestTransposeConstantFold) { | |||
| auto transposeT = new schema::TransposeT; | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Transpose, transposeT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -470,7 +470,7 @@ TEST_F(ConstantFoldingFusionTest, TestStackConstantFold) { | |||
| auto stackT = new schema::StackT; | |||
| stackT->axis[0] = 1; | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Stack, stackT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -484,7 +484,7 @@ TEST_F(ConstantFoldingFusionTest, TestStackConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestSliceConstantFold) { | |||
| auto sliceT = new schema::SliceFusionT; | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_SliceFusion, sliceT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -498,7 +498,7 @@ TEST_F(ConstantFoldingFusionTest, TestSliceConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestShapeConstantFold) { | |||
| auto shapeT = new schema::ShapeT; | |||
| auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Shape, shapeT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -512,7 +512,7 @@ TEST_F(ConstantFoldingFusionTest, TestShapeConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestRsqrtConstantFold) { | |||
| auto rsqrtT = new schema::RsqrtT; | |||
| auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Rsqrt, rsqrtT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -526,7 +526,7 @@ TEST_F(ConstantFoldingFusionTest, TestRsqrtConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestReshapeConstantFold) { | |||
| auto reshapeT = new schema::ReshapeT; | |||
| auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Reshape, reshapeT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -543,7 +543,7 @@ TEST_F(ConstantFoldingFusionTest, TestRangeConstantFold) { | |||
| rangeT->start = 1; | |||
| rangeT->delta = 1; | |||
| auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Range, rangeT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -556,7 +556,7 @@ TEST_F(ConstantFoldingFusionTest, TestRangeConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestMatmulConstantFold) { | |||
| auto matmulT = new schema::MatMulT; | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_MatMul, matmulT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -570,7 +570,7 @@ TEST_F(ConstantFoldingFusionTest, TestMatmulConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestExpandDimsConstantFold) { | |||
| auto expandDimsT = new schema::ExpandDimsT; | |||
| auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_ExpandDims, expandDimsT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -584,7 +584,7 @@ TEST_F(ConstantFoldingFusionTest, TestExpandDimsConstantFold) { | |||
| TEST_F(ConstantFoldingFusionTest, TestConcatDimsConstantFold) { | |||
| auto concatT = new schema::ConcatT; | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Concat, concatT); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -600,7 +600,7 @@ TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) { | |||
| auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Cast, castT); | |||
| auto input_tensor = meta_graph->allTensors.at(0).get(); | |||
| input_tensor->dataType = kNumberTypeUInt8; | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -615,7 +615,7 @@ TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) { | |||
| auto meta_graph = BuildSplitGraph(); | |||
| auto input_tensor = meta_graph->allTensors.at(0).get(); | |||
| input_tensor->dataType = kNumberTypeFloat32; | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>("test", false); | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| @@ -21,11 +21,11 @@ | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "import_from_meta_graphT.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "test/common/import_from_meta_graphT.h" | |||
| namespace mindspore { | |||
| class ConvActivationFusionTest : public mindspore::CommonTest { | |||
| @@ -136,7 +136,7 @@ MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, schema::ActivationType | |||
| } // namespace | |||
| TEST_F(ConvActivationFusionTest, TestConvReluNode) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -149,7 +149,7 @@ TEST_F(ConvActivationFusionTest, TestConvReluNode) { | |||
| TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU6); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -162,7 +162,7 @@ TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { | |||
| TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_LEAKY_RELU); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -21,11 +21,11 @@ | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "import_from_meta_graphT.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "test/common/import_from_meta_graphT.h" | |||
| namespace mindspore { | |||
| class ConvBiasAddFusionTest : public mindspore::CommonTest { | |||
| @@ -145,7 +145,7 @@ MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, schema::PrimitiveType | |||
| } // namespace | |||
| TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_BiasAdd); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -156,7 +156,7 @@ TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { | |||
| TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_AddFusion); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -166,7 +166,7 @@ TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { | |||
| TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_MatMul); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -21,11 +21,11 @@ | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "import_from_meta_graphT.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "test/common/import_from_meta_graphT.h" | |||
| namespace mindspore { | |||
| class ConvBNFusionTest : public mindspore::CommonTest { | |||
| @@ -262,7 +262,7 @@ MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) { | |||
| } // namespace | |||
| TEST_F(ConvBNFusionTest, TestConvAddNode) { | |||
| auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2DFusion); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -272,7 +272,7 @@ TEST_F(ConvBNFusionTest, TestConvAddNode) { | |||
| TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) { | |||
| auto meta_graph = BuildTFGraph(schema::PrimitiveType_Conv2DFusion); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -21,11 +21,11 @@ | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "import_from_meta_graphT.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "test/common/import_from_meta_graphT.h" | |||
| namespace mindspore { | |||
| class ConvScaleFusionTest : public mindspore::CommonTest { | |||
| @@ -187,7 +187,7 @@ MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) { | |||
| } // namespace | |||
| TEST_F(ConvScaleFusionTest, TestConvScaleNode) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, true); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -198,7 +198,7 @@ TEST_F(ConvScaleFusionTest, TestConvScaleNode) { | |||
| TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) { | |||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, false); | |||
| auto func_graph = lite::Fb2Anf(meta_graph.get()); | |||
| auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); | |||
| auto anf_transform = new lite::AnfTransform(); | |||
| auto new_graph = anf_transform->Transform(func_graph); | |||
| ASSERT_NE(nullptr, new_graph); | |||
| @@ -1,302 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/quant_param_holder.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "include/errorcode.h" | |||
| #include "import_from_meta_graphT.h" | |||
| namespace mindspore::lite { | |||
| int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| MS_ASSERT(nullptr != meta_graph_); | |||
| MS_ASSERT(nullptr != func_graph_); | |||
| for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { | |||
| auto &tensor = meta_graph_->allTensors.at(i); | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (tensor->nodeType != schema::NodeType::NodeType_ValueNode) { | |||
| continue; | |||
| } | |||
| auto parameter = func_graph_->add_parameter(); | |||
| std::vector<int> shape(tensor->dims.size()); | |||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||
| auto type_ptr = TypeIdToType(type_id); | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != abstract_tensor); | |||
| parameter->set_abstract(abstract_tensor); | |||
| if (!tensor->name.empty()) { | |||
| parameter->set_name(tensor->name); | |||
| } else { | |||
| parameter->set_name("const-" + std::to_string(i)); | |||
| } | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| MS_ASSERT(nullptr != param_value); | |||
| param_value->set_tensor_shape(shape); | |||
| param_value->set_tensor_type(type_id); | |||
| param_value->set_format(tensor->format); | |||
| if (!tensor->data.empty()) { | |||
| auto size = tensor->data.size(); | |||
| char *tensor_data = new (std::nothrow) char[size]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "new char[] failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| auto ret = memcpy_s(tensor_data, size, tensor->data.data(), size); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| delete[] tensor_data; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| param_value->SetTensorData(tensor_data, size); | |||
| parameter->set_default_param(param_value); | |||
| } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == | |||
| meta_graph_->inputIndex.end()) { | |||
| parameter->set_default_param(param_value); | |||
| } | |||
| AddNode(i, parameter); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode) { | |||
| // MS_ASSERT(nullptr != meta_graph_); | |||
| // MS_ASSERT(nullptr != cNode); | |||
| // auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release()); | |||
| // if (primitiveCValue == nullptr) { | |||
| // MS_LOG(ERROR) << "fail to convert primitive"; | |||
| // return nullptr; | |||
| // } | |||
| // cNode->primitive = nullptr; | |||
| // // add quant parameter | |||
| // auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| // for (auto index : cNode->inputIndex) { | |||
| // if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||
| // std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||
| // std::transform( | |||
| // meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||
| // quant_params.begin(), | |||
| // [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||
| // quant_params_holder->AddInputQuantParam(quant_params); | |||
| // } else { | |||
| // std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| // quant_params_holder->AddInputQuantParam(notinited_quant_params); | |||
| // } | |||
| // } | |||
| // for (auto index : cNode->outputIndex) { | |||
| // if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||
| // std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||
| // std::transform( | |||
| // meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||
| // quant_params.begin(), | |||
| // [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||
| // quant_params_holder->AddOutputQuantParam(quant_params); | |||
| // } else { | |||
| // std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| // quant_params_holder->AddOutputQuantParam(notinited_quant_params); | |||
| // } | |||
| // } | |||
| // primitiveCValue->AddAttr("quant_params", quant_params_holder); | |||
| // auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveCValue)); | |||
| // return value_node; | |||
| return nullptr; | |||
| } | |||
| abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTensor( | |||
| const std::unique_ptr<schema::TensorT> &tensor) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| std::vector<int> shape(tensor->dims.size()); | |||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||
| auto type_ptr = TypeIdToType(type_id); | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto ptr = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != ptr); | |||
| return ptr; | |||
| } | |||
| int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, | |||
| const CNodePtr &dst_cnode) { | |||
| // MS_ASSERT(nullptr != meta_graph_); | |||
| // MS_ASSERT(nullptr != src_cnode); | |||
| // MS_ASSERT(nullptr != dst_cnode); | |||
| // std::vector<uint32_t> out_tensor_ids = src_cnode->outputIndex; | |||
| // if (out_tensor_ids.size() == 1) { | |||
| // auto out_tensor_id = out_tensor_ids.front(); | |||
| // MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); | |||
| // auto &tensor = meta_graph_->allTensors.at(out_tensor_id); | |||
| // MS_ASSERT(nullptr != tensor); | |||
| // dst_cnode->set_abstract(ConvertTensorToAbstractTensor(tensor)); | |||
| // AddNode(out_tensor_id, dst_cnode); | |||
| // } else { | |||
| // AbstractBasePtrList abstract_list; | |||
| // for (size_t i = 0; i < out_tensor_ids.size(); i++) { | |||
| // auto out_tensor_id = out_tensor_ids.at(i); | |||
| // MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); | |||
| // auto &tensor = meta_graph_->allTensors.at(out_tensor_id); | |||
| // MS_ASSERT(nullptr != tensor); | |||
| // abstract_list.emplace_back(ConvertTensorToAbstractTensor(tensor)); | |||
| // auto tuple_get_item_prim_ptr = GetTupleGetItemPrim(); | |||
| // if (tuple_get_item_prim_ptr == nullptr) { | |||
| // MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||
| // auto get_item_value = NewValueNode(MakeValue<int>(i)); | |||
| // if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { | |||
| // MS_LOG(ERROR) << "NewValueNode is nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value}; | |||
| // CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); | |||
| // if (get_item_cnode == nullptr) { | |||
| // MS_LOG(ERROR) << "NewCNode is nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // get_item_cnode->set_fullname_with_scope(src_cnode->name + "_getitem_" + std::to_string(i)); | |||
| // AddNode(out_tensor_id, get_item_cnode); | |||
| // } | |||
| // dst_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| // } | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| MS_ASSERT(nullptr != meta_graph_); | |||
| MS_ASSERT(nullptr != func_graph_); | |||
| for (const auto &cNode : meta_graph_->nodes) { | |||
| MS_ASSERT(nullptr != cNode); | |||
| auto anf_primitive = ConvertPrimitive(cNode); | |||
| if (anf_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "cannot obtain anf primitive"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs = {anf_primitive}; | |||
| for (int j : cNode->inputIndex) { | |||
| auto node = GetNode(j); | |||
| if (nullptr == node) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op_inputs.push_back(node); | |||
| } | |||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||
| MS_ASSERT(nullptr != new_cnode); | |||
| new_cnode->set_fullname_with_scope(cNode->name); | |||
| auto status = ConvertAbstract(cNode, new_cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertAbstract failed."; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||
| // MS_ASSERT(nullptr != meta_graph_); | |||
| // MS_ASSERT(nullptr != func_graph_); | |||
| // if (meta_graph_->outputIndex.size() > 1) { | |||
| // std::vector<AnfNodePtr> make_tuple_inputs; | |||
| // auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||
| // if (make_tuple_prim_ptr == nullptr) { | |||
| // MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||
| // make_tuple_inputs.emplace_back(make_tuple_prim); | |||
| // for (auto tensor_id : meta_graph_->outputIndex) { | |||
| // auto cNode = GetNode(tensor_id); | |||
| // if (nullptr == cNode) { | |||
| // MS_LOG(ERROR) << "Can't find input node."; | |||
| // return RET_ERROR; | |||
| // } | |||
| // make_tuple_inputs.emplace_back(cNode); | |||
| // } | |||
| // auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); | |||
| // if (make_tuple_cnode == nullptr) { | |||
| // MS_LOG(ERROR) << "NewCNode is nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||
| // std::vector<AnfNodePtr> op_inputs; | |||
| // auto return_prim_ptr = GetReturnPrim(); | |||
| // if (return_prim_ptr == nullptr) { | |||
| // MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // auto value_node = NewValueNode(return_prim_ptr); | |||
| // op_inputs.emplace_back(value_node); | |||
| // op_inputs.emplace_back(make_tuple_cnode); | |||
| // auto cnode = func_graph_->NewCNode(op_inputs); | |||
| // MS_ASSERT(nullptr != cnode); | |||
| // cnode->set_fullname_with_scope("return"); | |||
| // func_graph_->set_return(cnode); | |||
| // } else { | |||
| // auto return_prim_ptr = GetReturnPrim(); | |||
| // if (return_prim_ptr == nullptr) { | |||
| // MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // auto value_node = NewValueNode(return_prim_ptr); | |||
| // std::vector<AnfNodePtr> op_inputs{value_node}; | |||
| // auto cnode = GetNode(meta_graph_->outputIndex.front()); | |||
| // if (nullptr == cnode) { | |||
| // MS_LOG(ERROR) << "Can't find input node."; | |||
| // return RET_ERROR; | |||
| // } | |||
| // op_inputs.emplace_back(cnode); | |||
| // auto return_cnode = func_graph_->NewCNode(op_inputs); | |||
| // if (return_cnode == nullptr) { | |||
| // MS_LOG(ERROR) << "NewCNode is nullptr"; | |||
| // return RET_NULL_PTR; | |||
| // } | |||
| // return_cnode->set_fullname_with_scope("return"); | |||
| // func_graph_->set_return(return_cnode); | |||
| // } | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } | |||
| FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "meta_graph is null"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| auto func_graph = std::make_shared<FuncGraph>(); | |||
| AnfImporterFromMetaGraphT importer(meta_graph, func_graph); | |||
| auto status = importer.Import(); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| return func_graph; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,11 +0,0 @@ | |||
| file(GLOB ANF_IMPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| *.cc | |||
| ) | |||
| set_property(SOURCE ${ANF_IMPORTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| add_library(anf_importer_mid OBJECT | |||
| ${ANF_IMPORTER_SRC_LIST} | |||
| ) | |||
| add_dependencies(anf_importer_mid proto_mid) | |||
| add_dependencies(anf_importer_mid fbs_src) | |||
| add_dependencies(anf_importer_mid fbs_inner_src) | |||
| @@ -1,54 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include <utility> | |||
| #include "schema/model_generated.h" | |||
| #include "ir/dtype.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int AnfImporter::Import(const converter::Flags *flag) { | |||
| auto ret = ConverterConstTensor(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; | |||
| return ret; | |||
| } | |||
| ret = ConverterCNode(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "ConverterCNode failed " << ret; | |||
| return ret; | |||
| } | |||
| ret = AddReturnCNode(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "AddReturnCNode failed " << ret; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| AnfNodePtr AnfImporter::GetNode(int tensor_id) { | |||
| auto n = nodes_.find(tensor_id); | |||
| if (n == nodes_.end()) { | |||
| return nullptr; | |||
| } | |||
| return n->second; | |||
| } | |||
| void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,55 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| #include <unordered_map> | |||
| #include "ir/func_graph.h" | |||
| #include "ir/anf.h" | |||
| #include "base/base.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore::lite { | |||
| class AnfImporter { | |||
| public: | |||
| AnfImporter() = default; | |||
| virtual ~AnfImporter() = default; | |||
| virtual int Import(const converter::Flags *flag = nullptr); | |||
| virtual FuncGraphPtr GetResult() = 0; | |||
| protected: | |||
| // convert const tensor into parameter and save in nodes_ | |||
| virtual int ConverterConstTensor() = 0; | |||
| // convert other node into cnode and save in nodes_ | |||
| virtual int ConverterCNode() = 0; | |||
| virtual int AddReturnCNode() = 0; | |||
| AnfNodePtr GetNode(int tensor_id); | |||
| void AddNode(int tensor_id, AnfNodePtr node); | |||
| protected: | |||
| std::unordered_map<int, AnfNodePtr> nodes_; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ | |||
| @@ -1,903 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/anf_importer/import_from_mindir.h" | |||
| #include <unistd.h> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "ops/make_tuple.h" | |||
| #include "ops/return.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "include/errorcode.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "securec/include/securec.h" | |||
| #include "src/tensor.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/protobuf_utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "load_mindir/load_model.h" | |||
| using string = std::string; | |||
| using int32 = int32_t; | |||
| using int64 = int64_t; | |||
| using uint64 = uint64_t; | |||
| namespace mindspore::lite { | |||
| static constexpr char kConstantValueNode[] = "Constant"; | |||
| enum ParseForm : int { | |||
| FORM_PARSE_TYPE = 0, | |||
| FORM_PARSE_SCALAR = 1, | |||
| FORM_PARSE_TENSOR = 2, | |||
| }; | |||
| static std::map<std::string, ParseForm> kParseTypeSwitchMap{ | |||
| {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; | |||
| static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | |||
| {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||
| {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||
| {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||
| {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||
| {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||
| {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||
| {onnx::TensorProto_DataType_STRING, kObjectTypeString}, | |||
| }; | |||
| std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name, | |||
| const std::unordered_map<string, ValuePtr> &kv) { | |||
| std::string str = attr_name; | |||
| auto replace = [&](const string &orgStr, const string &newStr) { | |||
| std::string::size_type pos(0); | |||
| while ((pos = str.find(orgStr)) != std::string::npos) { | |||
| str.replace(pos, orgStr.length(), newStr); | |||
| } | |||
| return str; | |||
| }; | |||
| // remove "scalar:" | |||
| str = replace("scalar:", ""); | |||
| // remove "Tuple" | |||
| str = replace("Tuple", ""); | |||
| // remove "List" | |||
| str = replace("List", ""); | |||
| std::stack<std::string> rules; | |||
| std::stack<ValuePtr> value; | |||
| int num = 0, count = 0; | |||
| for (size_t i = 0; i < str.length(); i++) { | |||
| if (str[i] == '[') { | |||
| rules.push("["); | |||
| } else if (str[i] == ']') { | |||
| // rules | |||
| std::vector<ValuePtr> vec; | |||
| while (rules.top() != "[") { | |||
| rules.pop(); | |||
| vec.push_back(value.top()); | |||
| value.pop(); | |||
| } | |||
| // pop "[" | |||
| rules.pop(); | |||
| // make tuple for names | |||
| std::string res = "dummy"; | |||
| // make tuple for values | |||
| reverse(vec.begin(), vec.end()); | |||
| auto vt = std::make_shared<ValueTuple>(vec); | |||
| if (rules.empty() && value.empty()) { | |||
| return vt; | |||
| } | |||
| rules.push(res); | |||
| value.push(vt); | |||
| } else if (str[i] == ',') { | |||
| continue; | |||
| } else { | |||
| count++; | |||
| if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { | |||
| auto value_name = str.substr(i - count + 1, count); | |||
| value.push(kv.at(value_name)); | |||
| rules.push(value_name); | |||
| count = 0; | |||
| num++; | |||
| } | |||
| } | |||
| } | |||
| return {}; | |||
| } | |||
| std::shared_ptr<abstract::AbstractTuple> ParserAttrShape( | |||
| const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) { | |||
| std::string str = attr_name; | |||
| auto replace = [&](const string &orgStr, const string &newStr) { | |||
| std::string::size_type pos(0); | |||
| while ((pos = str.find(orgStr)) != std::string::npos) { | |||
| str.replace(pos, orgStr.length(), newStr); | |||
| } | |||
| return str; | |||
| }; | |||
| // remove "scalar:" | |||
| str = replace("shape:", ""); | |||
| // remove "Tuple" | |||
| str = replace("Tuple", ""); | |||
| // remove "List" | |||
| str = replace("List", ""); | |||
| std::stack<std::string> rules; | |||
| std::stack<abstract::AbstractBasePtr> value; | |||
| int num = 0, count = 0; | |||
| for (size_t i = 0; i < str.length(); i++) { | |||
| if (str[i] == '[') { | |||
| rules.push("["); | |||
| } else if (str[i] == ']') { | |||
| // rules | |||
| std::vector<abstract::AbstractBasePtr> vec; | |||
| while (rules.top() != "[") { | |||
| rules.pop(); | |||
| vec.push_back(value.top()); | |||
| value.pop(); | |||
| } | |||
| // pop "[" | |||
| rules.pop(); | |||
| // make tuple for names | |||
| std::string res = "dummy"; | |||
| // make tuple for values | |||
| reverse(vec.begin(), vec.end()); | |||
| auto vt = std::make_shared<abstract::AbstractTuple>(vec); | |||
| if (rules.empty() && value.empty()) { | |||
| return vt; | |||
| } | |||
| rules.push(res); | |||
| value.push(vt); | |||
| } else if (str[i] == ',') { | |||
| continue; | |||
| } else { | |||
| count++; | |||
| if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { | |||
| auto value_name = str.substr(i - count + 1, count); | |||
| value.push(kv.at(value_name)); | |||
| rules.push(value_name); | |||
| count = 0; | |||
| num++; | |||
| } | |||
| } | |||
| } | |||
| return {}; | |||
| } | |||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||
| ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ | |||
| if (attr_tensor.type##_data_size() == 1) { \ | |||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ | |||
| return MakeValue<valuetype>(value); \ | |||
| } else { \ | |||
| MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ | |||
| } \ | |||
| return {}; \ | |||
| } | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | |||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | |||
| int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node, | |||
| const onnx::ValueInfoProto &value_proto) { | |||
| if (node == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (!value_proto.has_type() || !value_proto.has_name()) { | |||
| MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| node->set_name(value_proto.name()); | |||
| const auto &type_proto = value_proto.type(); | |||
| if (!type_proto.has_tensor_type()) { | |||
| MS_LOG(ERROR) << "onnx TypeProto has no tensor_type! "; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); | |||
| if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { | |||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); | |||
| std::vector<int> shape; | |||
| for (int i = 0; i < tensor_shape.dim_size(); ++i) { | |||
| shape.push_back(tensor_shape.dim(i).dim_value()); | |||
| } | |||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { | |||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| node->set_abstract(abstract_tensor); | |||
| if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | |||
| auto *tensor_info = new (std::nothrow) Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||
| if (tensor_info == nullptr) { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| tensor_info->MallocData(); | |||
| const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; | |||
| std::string initial_data = initialize_proto.raw_data(); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->MutableData()); | |||
| if (tensor_data_buf == nullptr) { | |||
| delete tensor_info; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| tensor_info->set_data(nullptr); | |||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| delete tensor_data_buf; | |||
| delete tensor_info; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| delete tensor_info; | |||
| return RET_NULL_PTR; | |||
| } | |||
| param_value->SetTensorData(tensor_data_buf, tensor_info->Size()); | |||
| param_value->set_tensor_type(tensor_info->data_type()); | |||
| param_value->set_tensor_shape(tensor_info->shape()); | |||
| node->set_default_param(param_value); | |||
| delete tensor_info; | |||
| } | |||
| anfnode_build_map_[value_proto.name()] = node; | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto) { | |||
| if (outputFuncGraph == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); | |||
| for (int i = 0; i < importProto.initializer_size(); ++i) { | |||
| const onnx::TensorProto &initializer_proto = importProto.initializer(i); | |||
| if (!initializer_proto.has_name()) { | |||
| MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| default_para_map_[initializer_proto.name()] = initializer_proto; | |||
| } | |||
| int status = RET_OK; | |||
| MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | |||
| for (int i = 0; i < importProto.input_size(); ++i) { | |||
| const onnx::ValueInfoProto &input_proto = importProto.input(i); | |||
| status = BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | |||
| break; | |||
| } | |||
| } | |||
| return status; | |||
| } | |||
| bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; | |||
| return false; | |||
| } | |||
| prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||
| return true; | |||
| } | |||
| ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| switch (attr_tensor_type) { | |||
| case onnx::TensorProto_DataType_STRING: { | |||
| return ParseAttrInScalar_string_string(attr_tensor); | |||
| } | |||
| case onnx::TensorProto_DataType_INT32: { | |||
| return ParseAttrInScalar_int32_int32(attr_tensor); | |||
| } | |||
| case onnx::TensorProto_DataType_INT64: { | |||
| return ParseAttrInScalar_int64_int64(attr_tensor); | |||
| } | |||
| case onnx::TensorProto_DataType_UINT64: { | |||
| return ParseAttrInScalar_uint64_uint64(attr_tensor); | |||
| } | |||
| case onnx::TensorProto_DataType_FLOAT: { | |||
| return ParseAttrInScalar_float_float(attr_tensor); | |||
| } | |||
| case onnx::TensorProto_DataType_DOUBLE: { | |||
| return ParseAttrInScalar_double_double(attr_tensor); | |||
| } | |||
| case onnx::TensorProto_DataType_BOOL: { | |||
| return ParseAttrInScalar_int32_bool(attr_tensor); | |||
| } | |||
| default: | |||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||
| return {}; | |||
| } | |||
| } | |||
| bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||
| std::vector<int> shape; | |||
| auto ret = EOK; | |||
| if (attr_tensor.dims_size() != 0) { | |||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||
| shape.push_back(attr_tensor.dims(i)); | |||
| } | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| tensor::TensorPtr tensor_info = | |||
| std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape_vector); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||
| ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| return false; | |||
| } | |||
| prim->set_attr(attr_name, MakeValue(tensor_info)); | |||
| } else { | |||
| if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) { | |||
| size_t data_size = sizeof(double); | |||
| double attr_value = 0.0; | |||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| return false; | |||
| } | |||
| prim->set_attr(attr_name, MakeValue<double>(attr_value)); | |||
| } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) { | |||
| size_t data_size = sizeof(int64_t); | |||
| int64_t attr_value = 0; | |||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| return false; | |||
| } | |||
| prim->set_attr(attr_name, MakeValue<int64_t>(attr_value)); | |||
| } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) { | |||
| size_t data_size = sizeof(bool); | |||
| bool attr_value = false; | |||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); | |||
| if (EOK != ret) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| return false; | |||
| } | |||
| prim->set_attr(attr_name, MakeValue<bool>(attr_value)); | |||
| } | |||
| } | |||
| return ret == EOK; | |||
| } | |||
| bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| const std::string &attr_name = attr_proto.name(); | |||
| if (!attr_proto.has_ref_attr_name()) { | |||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||
| return false; | |||
| } | |||
| const std::string &ref_attr_name = attr_proto.ref_attr_name(); | |||
| if (ref_attr_name.empty()) { | |||
| MS_LOG(ERROR) << "ref_attr_name is empty"; | |||
| return false; | |||
| } | |||
| string type = ""; | |||
| std::size_t pos(0); | |||
| if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { | |||
| type = ref_attr_name.substr(pos, string("scalar:").length() - 1); | |||
| } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { | |||
| type = ref_attr_name.substr(pos, string("type:").length() - 1); | |||
| } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { | |||
| type = ref_attr_name.substr(pos, string("tensor:").length() - 1); | |||
| } | |||
| std::unordered_map<std::string, ValuePtr> kv; | |||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | |||
| const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); | |||
| switch (kParseTypeSwitchMap[type]) { | |||
| case FORM_PARSE_TYPE: { | |||
| return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); | |||
| } | |||
| case FORM_PARSE_SCALAR: { | |||
| auto res = ObtainCNodeAttrInScalarForm(attr_tensor); | |||
| kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res)); | |||
| break; | |||
| } | |||
| case FORM_PARSE_TENSOR: { | |||
| return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); | |||
| } | |||
| default: | |||
| MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; | |||
| return false; | |||
| } | |||
| } | |||
| if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { | |||
| if (kv.size() == 1) { | |||
| auto iter = kv.begin(); | |||
| prim->AddAttr(attr_name, iter->second); | |||
| } else { | |||
| auto res = ParserScalarAttrValue(ref_attr_name, kv); | |||
| prim->AddAttr(attr_name, res); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| std::vector<int> shape; | |||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||
| shape.push_back(attr_tensor.dims(i)); | |||
| } | |||
| std::vector<int> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int>(value); }); | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| param_value->set_tensor_shape(shape_vector); | |||
| param_value->set_tensor_type(kDefaultValueSwitchMap[attr_tensor_type]); | |||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||
| auto tensor_data = new (std::nothrow) char[tensor_buf.size()]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "Tensor_data is nullptr"; | |||
| return false; | |||
| } | |||
| auto ret = memcpy_s(tensor_data, tensor_buf.size(), tensor_buf.data(), tensor_buf.size()); | |||
| if (ret != EOK) { | |||
| delete[] tensor_data; | |||
| MS_LOG(ERROR) << "Memcpy error: " << ret; | |||
| return false; | |||
| } | |||
| param_value->SetTensorData(tensor_data, tensor_buf.size()); | |||
| auto new_value_node = NewValueNode(MakeValue(param_value)); | |||
| if (new_value_node == nullptr) { | |||
| MS_LOG(ERROR) << "Make valuenode fail"; | |||
| return false; | |||
| } | |||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); | |||
| std::vector<int64_t> shape_vector_int64; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector_int64), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector_int64); | |||
| new_value_node->set_abstract(abstract_tensor); | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| return true; | |||
| } | |||
| bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor) { | |||
| const int attr_tensor_type = attr_tensor.data_type(); | |||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||
| return false; | |||
| } | |||
| auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||
| abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||
| new_value_node->set_abstract(abs_type); | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| return true; | |||
| } | |||
| bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name, | |||
| const onnx::AttributeProto &attr_proto) { | |||
| if (!attr_proto.has_ref_attr_name()) { | |||
| MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; | |||
| return false; | |||
| } | |||
| const std::string &ref_attr_name = attr_proto.ref_attr_name(); | |||
| if (ref_attr_name.empty()) { | |||
| MS_LOG(ERROR) << "ref_attr_name is empty"; | |||
| return false; | |||
| } | |||
| string type = ""; | |||
| std::size_t pos(0); | |||
| if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { | |||
| type = ref_attr_name.substr(pos, string("scalar:").length() - 1); | |||
| } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { | |||
| type = ref_attr_name.substr(pos, string("type:").length() - 1); | |||
| } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { | |||
| type = ref_attr_name.substr(pos, string("tensor:").length() - 1); | |||
| } | |||
| std::unordered_map<std::string, ValuePtr> kv; | |||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | |||
| const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); | |||
| switch (kParseTypeSwitchMap[type]) { | |||
| case FORM_PARSE_TYPE: { | |||
| return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); | |||
| } | |||
| case FORM_PARSE_SCALAR: { | |||
| auto res = ObtainCNodeAttrInScalarForm(attr_tensor); | |||
| kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res)); | |||
| break; | |||
| } | |||
| case FORM_PARSE_TENSOR: { | |||
| return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); | |||
| } | |||
| default: | |||
| MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; | |||
| return false; | |||
| } | |||
| } | |||
| ValueNodePtr new_value_node; | |||
| if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { | |||
| if (kv.size() == 1) { | |||
| auto iter = kv.begin(); | |||
| new_value_node = NewValueNode(iter->second); | |||
| new_value_node->set_abstract(iter->second->ToAbstract()); | |||
| } else { | |||
| auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); | |||
| new_value_node = NewValueNode(value_ptr); | |||
| new_value_node->set_abstract(value_ptr->ToAbstract()); | |||
| } | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| } | |||
| return true; | |||
| } | |||
| bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||
| const std::string &value_node_name = node_proto.output(0); | |||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | |||
| if (!attr_proto.has_ref_attr_name()) { | |||
| MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; | |||
| return false; | |||
| } | |||
| return GetAttrValueForValueNode(value_node_name, attr_proto); | |||
| } | |||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromMindir::GetAbstractForCNode( | |||
| const onnx::AttributeProto &attr_proto) { | |||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; | |||
| for (int i = 0; i < attr_proto.tensors_size(); i++) { | |||
| std::vector<int> shape; | |||
| const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); | |||
| for (int j = 0; j < attr_tensor.dims_size(); ++j) { | |||
| shape.push_back(attr_tensor.dims(j)); | |||
| } | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor)); | |||
| } | |||
| return kv; | |||
| } | |||
| CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::NodeProto &node_proto, | |||
| const schema::QuantType &quantType) { | |||
| static bool interrupt = false; | |||
| if (outputFuncGraph == nullptr) { | |||
| MS_LOG(ERROR) << "output funcgraph is nullptr"; | |||
| return nullptr; | |||
| } | |||
| if (!node_proto.has_op_type()) { | |||
| MS_LOG(ERROR) << "Get CNode op_type failed!"; | |||
| return nullptr; | |||
| } | |||
| const std::string &node_name = node_proto.output(0); | |||
| const std::string &fullname_with_scope = node_proto.domain(); | |||
| // const std::string &node_type = node_proto.op_type(); | |||
| PrimitivePtr prim; | |||
| // NOTE: can not find OpPrimCRegister | |||
| // auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap(); | |||
| // if (op_primc_fns.find(node_type) != op_primc_fns.end()) { | |||
| // prim = op_primc_fns[node_type](); | |||
| // } else { | |||
| // prim = std::make_shared<Primitive>(node_type); | |||
| // prim->set_instance_name(node_type); | |||
| // } | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; | |||
| string shape_ref_attr_name; | |||
| for (int i = 0; i < node_proto.attribute_size(); ++i) { | |||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(i); | |||
| if (attr_proto.ref_attr_name().find("shape:") != string::npos) { | |||
| shape_ref_attr_name = attr_proto.ref_attr_name(); | |||
| kv = GetAbstractForCNode(attr_proto); | |||
| continue; | |||
| } | |||
| if (!GetAttrValueForCNode(prim, attr_proto)) { | |||
| MS_LOG(ERROR) << "Get CNode attr failed!"; | |||
| return nullptr; | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.clear(); | |||
| for (int i = 0; i < node_proto.input_size(); ++i) { | |||
| const std::string &input_name = node_proto.input(i); | |||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | |||
| if (!interrupt) { | |||
| MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; | |||
| interrupt = true; | |||
| } | |||
| inputs.push_back(nullptr); | |||
| } else { | |||
| inputs.push_back(anfnode_build_map_[input_name]); | |||
| } | |||
| } | |||
| CNodePtr cnode_ptr = outputFuncGraph->NewCNode(prim, inputs); | |||
| if (cnode_ptr == nullptr) { | |||
| interrupt = true; | |||
| MS_LOG(ERROR) << "funcgraph new cnode failed"; | |||
| return nullptr; | |||
| } | |||
| if (kv.empty()) { | |||
| AbstractBasePtrList elem; | |||
| for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { | |||
| elem.push_back(cnode_ptr->input(index)->abstract()); | |||
| } | |||
| cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||
| } else if (1 == kv.size()) { | |||
| auto iter = kv.begin(); | |||
| cnode_ptr->set_abstract(iter->second); | |||
| } else { | |||
| auto abstract = ParserAttrShape(shape_ref_attr_name, kv); | |||
| cnode_ptr->set_abstract(abstract); | |||
| } | |||
| cnode_ptr->set_fullname_with_scope(fullname_with_scope); | |||
| anfnode_build_map_[node_name] = cnode_ptr; | |||
| return cnode_ptr; | |||
| } | |||
| bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||
| if (outputFuncGraph == nullptr || cnode_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "output funcgraph or cnode is nullptr"; | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| if (importProto.output_size() > 1) { | |||
| inputs.clear(); | |||
| auto make_tuple_prim = std::make_shared<ops::MakeTuple>(); | |||
| inputs.push_back(NewValueNode(make_tuple_prim)); | |||
| AbstractBasePtrList elem; | |||
| for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { | |||
| const onnx::ValueInfoProto &output_node = importProto.output(out_size); | |||
| const std::string &out_tuple = output_node.name(); | |||
| inputs.push_back(anfnode_build_map_[out_tuple]); | |||
| if (anfnode_build_map_[out_tuple] == nullptr) { | |||
| MS_LOG(ERROR) << "AnfNode is nullptr"; | |||
| return false; | |||
| } | |||
| elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | |||
| } | |||
| auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | |||
| if (maketuple_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "maketuple_ptr is nullptr"; | |||
| return false; | |||
| } | |||
| maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||
| inputs.clear(); | |||
| auto return_prim = std::make_shared<ops::Return>(); | |||
| inputs.push_back(NewValueNode(return_prim)); | |||
| inputs.push_back(maketuple_ptr); | |||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||
| if (return_node == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph new cnode failed"; | |||
| return false; | |||
| } | |||
| outputFuncGraph->set_return(return_node); | |||
| MS_LOG(INFO) << "Construct funcgraph finined, all success."; | |||
| } else { | |||
| const onnx::ValueInfoProto &output_node = importProto.output(0); | |||
| const onnx::TypeProto &output_typeproto = output_node.type(); | |||
| int output_type = output_typeproto.tensor_type().elem_type(); | |||
| std::vector<int> output_shape; | |||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { | |||
| output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); | |||
| } | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| inputs.clear(); | |||
| auto return_prim = std::make_shared<ops::Return>(); | |||
| inputs.push_back(NewValueNode(return_prim)); | |||
| inputs.push_back(cnode_ptr); | |||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||
| if (return_node == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph new cnode failed"; | |||
| return false; | |||
| } | |||
| return_node->set_abstract(abstract_tensor); | |||
| outputFuncGraph->set_return(return_node); | |||
| MS_LOG(INFO) << "Construct funcgraph finined, all success!"; | |||
| } | |||
| return true; | |||
| } | |||
| int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| if (outputFuncGraph == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | |||
| CNodePtr cnode_ptr = nullptr; | |||
| CNodePtr last_cnode_ptr = nullptr; | |||
| int status = RET_OK; | |||
| NoSupportOp::GetInstance()->SetFmkType("MINDIR"); | |||
| for (int i = 0; i < importProto.node_size(); ++i) { | |||
| const onnx::NodeProto &node_proto = importProto.node(i); | |||
| const std::string &node_type = node_proto.op_type(); | |||
| MS_LOG(INFO) << "parse op : " << node_type; | |||
| if (node_type == kConstantValueNode) { | |||
| if (status == RET_OK && !BuildValueNodeForFuncGraph(node_proto)) { | |||
| MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; | |||
| status = RET_ERROR; | |||
| } | |||
| continue; | |||
| } | |||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); | |||
| if (cnode_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | |||
| return RET_ERROR; | |||
| } | |||
| auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode_ptr->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||
| status = RET_ERROR; | |||
| } | |||
| return status; | |||
| } | |||
| int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType) { | |||
| if (outputFuncGraph == nullptr) { | |||
| MS_LOG(ERROR) << "fundgraph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | |||
| if (debug_info_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph's debug info is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (importProto.has_name()) { | |||
| debug_info_ptr->set_name(importProto.name()); | |||
| } else { | |||
| MS_LOG(INFO) << "FuncGraph under converting has not name!"; | |||
| } | |||
| auto status = ImportParametersForGraph(outputFuncGraph, importProto); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | |||
| } | |||
| int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||
| if (!model_proto.has_producer_name()) { | |||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| producer_name_ = model_proto.producer_name(); | |||
| if (!model_proto.has_model_version()) { | |||
| MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| model_version_ = model_proto.model_version(); | |||
| if (!model_proto.has_ir_version()) { | |||
| MS_LOG(ERROR) << "Parse model version from pb file failed!"; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| ir_version_ = model_proto.ir_version(); | |||
| return RET_OK; | |||
| } | |||
| int AnfImporterFromMindir::Import(const converter::Flags *flag) { | |||
| #if SUPPORT_TRAIN | |||
| func_graph_ = LoadMindIR(flag->modelFile); | |||
| if (func_graph_ != nullptr) { | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format"; | |||
| } | |||
| #endif | |||
| onnx_model_ = ReadOnnxFromBinary(flag->modelFile); | |||
| if (onnx_model_ == nullptr) { | |||
| MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; | |||
| func_graph_ = LoadMindIR(flag->modelFile); | |||
| if (func_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file."; | |||
| return RET_GRAPH_FILE_ERR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | |||
| if (dstGraph == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| int status = ParseModelConfigureInfo(*onnx_model_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | |||
| return status; | |||
| } | |||
| auto quantType = flag->quantType; | |||
| const onnx::GraphProto &graphBuild = onnx_model_->graph(); | |||
| status = BuildFuncGraph(dstGraph, graphBuild, quantType); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Build funcgraph failed!"; | |||
| func_graph_ = nullptr; | |||
| return status; | |||
| } | |||
| func_graph_ = dstGraph; | |||
| MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; | |||
| return RET_OK; | |||
| } | |||
| onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) { | |||
| auto onnx_model = new (std::nothrow) onnx::ModelProto; | |||
| if (onnx_model == nullptr) { | |||
| MS_LOG(ERROR) << "New onnx ModelProto failed!"; | |||
| return nullptr; | |||
| } | |||
| if (RET_OK != ValidateFileStr(model_path, ".mindir")) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; | |||
| return nullptr; | |||
| } | |||
| if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { | |||
| MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model"; | |||
| return nullptr; | |||
| } | |||
| return onnx_model; | |||
| } | |||
| FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; } | |||
| } // namespace mindspore::lite | |||
| @@ -1,83 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "include/errorcode.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore::lite { | |||
| class AnfImporterFromMindir : public AnfImporter { | |||
| public: | |||
| AnfImporterFromMindir() = default; | |||
| ~AnfImporterFromMindir() override { delete onnx_model_; } | |||
| static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); | |||
| FuncGraphPtr GetResult() override; | |||
| int Import(const converter::Flags *flag) override; | |||
| private: | |||
| int ConverterConstTensor() override { return RET_ERROR; }; | |||
| int ConverterCNode() override { return RET_ERROR; }; | |||
| int AddReturnCNode() override { return RET_ERROR; }; | |||
| int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | |||
| int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType); | |||
| int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||
| int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const schema::QuantType &quantType); | |||
| int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||
| const schema::QuantType &quantType); | |||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const CNodePtr &cnode_ptr); | |||
| static bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); | |||
| static bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor); | |||
| static ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); | |||
| static bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor); | |||
| bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); | |||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||
| bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto); | |||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||
| static std::unordered_map<std::string, abstract::AbstractTensorPtr> GetAbstractForCNode( | |||
| const onnx::AttributeProto &attr_proto); | |||
| private: | |||
| std::string producer_name_; | |||
| int model_version_{}; | |||
| int ir_version_{}; | |||
| std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; | |||
| std::map<std::string, onnx::TensorProto> default_para_map_; | |||
| onnx::ModelProto *onnx_model_; | |||
| FuncGraphPtr func_graph_; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | |||
| @@ -37,7 +37,6 @@ using schema::QuantParamT; | |||
| using schema::TensorT; | |||
| using schema::Format::Format_NCHW; | |||
| using schema::Format::Format_NHWC; | |||
| using STATUS = int; | |||
| std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor); | |||
| @@ -64,7 +64,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/primitive_adjust_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| add_subdirectory(../anf_exporter anf_exporter) | |||
| add_subdirectory(parser/caffe) | |||
| add_subdirectory(parser/tflite) | |||
| @@ -158,7 +157,6 @@ target_link_libraries(converter_lite PRIVATE | |||
| tf_parser_mid | |||
| caffe_parser_mid | |||
| onnx_parser_mid | |||
| anf_importer_mid | |||
| anf_exporter_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| @@ -17,12 +17,7 @@ | |||
| #include "tools/converter/converter.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "ir/func_graph.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/storage.h" | |||
| #include "parser/caffe/caffe_converter.h" | |||
| @@ -30,68 +25,48 @@ | |||
| #include "parser/onnx/onnx_converter.h" | |||
| #include "parser/tf/tf_converter.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/anf_importer/import_from_mindir.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "include/version.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| using FmkType = converter::FmkType; | |||
| static const char *DELIM_SLASH = "/"; | |||
| Converter::Converter() { | |||
| this->transform = new GraphDefTransform; | |||
| this->anfTransform = new AnfTransform; | |||
| } | |||
| Converter::~Converter() { | |||
| delete modelParser; | |||
| delete modelImporter; | |||
| delete transform; | |||
| delete anfTransform; | |||
| } | |||
| class MindsporeImporter : public Converter { | |||
| public: | |||
| MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); } | |||
| ~MindsporeImporter() override = default; | |||
| }; | |||
| MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| // parse the model and weight file to generate inference data structure | |||
| FuncGraphPtr graph = nullptr; | |||
| if (flag->fmk == converter::FmkType_MS) { | |||
| MS_ASSERT(nullptr != modelImporter); | |||
| int status = modelImporter->Import(flag); | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| graph = modelImporter->GetResult(); | |||
| if (graph == nullptr) { | |||
| std::unique_ptr<Converter> Converter::CreateConverter(converter::FmkType fmk) { | |||
| switch (fmk) { | |||
| case FmkType::FmkType_MS: | |||
| return std::make_unique<MindsporeImporter>(); | |||
| case FmkType::FmkType_CAFFE: | |||
| return std::make_unique<CaffeConverter>(); | |||
| case FmkType::FmkType_TFLITE: | |||
| return std::make_unique<TfliteConverter>(); | |||
| case FmkType::FmkType_ONNX: | |||
| return std::make_unique<OnnxConverter>(); | |||
| case FmkType::FmkType_TF: | |||
| return std::make_unique<TFConverter>(); | |||
| default: { | |||
| return nullptr; | |||
| } | |||
| graph->set_attr("graph_name", MakeValue("main_graph")); | |||
| graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS))); | |||
| } else { | |||
| MS_ASSERT(nullptr != modelParser); | |||
| const std::string modelFile = flag->modelFile; | |||
| const std::string weightFile = flag->weightFile; | |||
| graph = modelParser->Parse(modelFile, weightFile, flag->quantType); | |||
| } | |||
| } | |||
| MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) { | |||
| if (flag == nullptr) { | |||
| MS_LOG(ERROR) << "Input flag is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto graph = BuildFuncGraph(flag->modelFile, flag->weightFile, flag->quantType); | |||
| if (graph == nullptr) { | |||
| MS_LOG(ERROR) << "Parser/Import model return nullptr"; | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "import success"; | |||
| graph = anfTransform->Transform(graph, flag); | |||
| // funcgraph compile | |||
| graph = funcgraph_transform_->Transform(graph, flag.get()); | |||
| if (graph == nullptr) { | |||
| MS_LOG(ERROR) << "Transform anf graph return nullptr"; | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "Run anfTransform success"; | |||
| // anf -- fb | |||
| // protobuf -> flatbuf | |||
| auto meta_graph = Export(graph); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Export to meta graph return nullptr"; | |||
| @@ -99,91 +74,75 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| } | |||
| MS_LOG(INFO) << "export success"; | |||
| // transform | |||
| transform->SetGraphDef(meta_graph); | |||
| auto status = transform->Transform(*flag); | |||
| // metagraph compile | |||
| metagraph_transform_->SetGraphDef(meta_graph); | |||
| auto status = metagraph_transform_->Transform(*flag); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Transform meta graph failed " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "run fbTransform success"; | |||
| return meta_graph; | |||
| } | |||
| int RunConverter(int argc, const char **argv) { | |||
| std::ostringstream oss; | |||
| std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags); | |||
| if (flags == nullptr) { | |||
| MS_LOG(ERROR) << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED); | |||
| std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED) << std::endl; | |||
| oss.clear(); | |||
| oss << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED); | |||
| MS_LOG(ERROR) << oss.str(); | |||
| std::cout << oss.str() << std::endl; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| auto status = flags->Init(argc, argv); | |||
| if (status != RET_OK) { | |||
| if (status != RET_SUCCESS_EXIT) { | |||
| MS_LOG(ERROR) << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status) << std::endl; | |||
| std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status) << std::endl; | |||
| oss.clear(); | |||
| oss << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status); | |||
| MS_LOG(ERROR) << oss.str(); | |||
| std::cout << oss.str() << std::endl; | |||
| } | |||
| std::cout << GetErrorInfo(status) << std::endl; | |||
| return status; | |||
| } | |||
| // Load graph | |||
| std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1); | |||
| MS_LOG(INFO) << "start reading model file"; | |||
| auto fb_graph = new (std::nothrow) MetaGraphT; | |||
| switch (flags->fmk) { | |||
| case FmkType::FmkType_MS: { | |||
| MindsporeImporter mindsporeImporter; | |||
| fb_graph = mindsporeImporter.Convert(flags.get()); | |||
| break; | |||
| } | |||
| case FmkType::FmkType_CAFFE: { | |||
| CaffeConverter caffeConverter; | |||
| fb_graph = caffeConverter.Convert(flags.get()); | |||
| } break; | |||
| case FmkType::FmkType_TFLITE: { | |||
| TfliteConverter tfLiteConverter; | |||
| fb_graph = tfLiteConverter.Convert(flags.get()); | |||
| } break; | |||
| case FmkType::FmkType_ONNX: { | |||
| OnnxConverter onnxConverter; | |||
| fb_graph = onnxConverter.Convert(flags.get()); | |||
| } break; | |||
| case FmkType::FmkType_TF: { | |||
| TFConverter tfConverter; | |||
| fb_graph = tfConverter.Convert(flags.get()); | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " | |||
| << GetErrorInfo(RET_INPUT_PARAM_INVALID); | |||
| std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " | |||
| << GetErrorInfo(RET_INPUT_PARAM_INVALID) << std::endl; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| MS_LOG(DEBUG) << "start reading model file"; | |||
| auto converter = Converter::CreateConverter(flags->fmk); | |||
| if (converter == nullptr) { | |||
| oss.clear(); | |||
| oss << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " | |||
| << GetErrorInfo(RET_INPUT_PARAM_INVALID); | |||
| MS_LOG(ERROR) << oss.str(); | |||
| std::cout << oss.str() << std::endl; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| auto meta_graph = converter->Convert(flags); | |||
| NoSupportOp::GetInstance()->PrintOps(); | |||
| status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); | |||
| if (fb_graph == nullptr) { | |||
| MS_LOG(ERROR) << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); | |||
| std::cout << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status) << std::endl; | |||
| if (meta_graph == nullptr) { | |||
| oss.clear(); | |||
| oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); | |||
| MS_LOG(ERROR) << oss.str(); | |||
| std::cout << oss.str() << std::endl; | |||
| return status; | |||
| } | |||
| // save graph to file | |||
| Storage storage; | |||
| fb_graph->version = Version(); | |||
| status = storage.Save(*fb_graph, flags->outputFile); | |||
| meta_graph->version = Version(); | |||
| status = Storage::Save(*meta_graph, flags->outputFile); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SAVE GRAPH FAILED:" << status << " " << GetErrorInfo(status); | |||
| std::cout << "SAVE GRAPH FAILED:" << status << " " << GetErrorInfo(status) << std::endl; | |||
| oss.clear(); | |||
| oss << "SAVE GRAPH FAILED:" << status << " " << GetErrorInfo(status); | |||
| MS_LOG(ERROR) << oss.str(); | |||
| std::cout << oss.str() << std::endl; | |||
| return status; | |||
| } | |||
| delete fb_graph; | |||
| MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status; | |||
| std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl; | |||
| delete meta_graph; | |||
| oss.clear(); | |||
| oss << "CONVERT RESULT SUCCESS:" << status; | |||
| MS_LOG(INFO) << oss.str(); | |||
| std::cout << oss.str() << std::endl; | |||
| return status; | |||
| } | |||
| } // namespace lite | |||
| @@ -22,24 +22,41 @@ | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/anf_importer/anf_importer.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "load_mindir/load_model.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Converter { | |||
| public: | |||
| Converter(); | |||
| virtual ~Converter(); | |||
| virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags); | |||
| static std::unique_ptr<Converter> CreateConverter(converter::FmkType fmk); | |||
| virtual ~Converter() = default; | |||
| virtual schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag); | |||
| virtual FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, | |||
| schema::QuantType quant_type) = 0; | |||
| protected: | |||
| ModelParser *modelParser = nullptr; | |||
| AnfImporter *modelImporter = nullptr; | |||
| GraphDefTransform *transform = nullptr; | |||
| AnfTransform *anfTransform = nullptr; | |||
| Converter() = default; | |||
| std::unique_ptr<GraphDefTransform> metagraph_transform_ = std::make_unique<GraphDefTransform>(); | |||
| std::unique_ptr<AnfTransform> funcgraph_transform_ = std::make_unique<AnfTransform>(); | |||
| }; | |||
| class MindsporeImporter : public Converter { | |||
| public: | |||
| MindsporeImporter() = default; | |||
| ~MindsporeImporter() override = default; | |||
| FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, | |||
| schema::QuantType quant_type) override { | |||
| return LoadMindIR(model_file); | |||
| } | |||
| }; | |||
| int RunConverter(int argc, const char **argv); | |||
| @@ -35,13 +35,7 @@ class ModelParser { | |||
| virtual ~ModelParser() = default; | |||
| virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) { | |||
| return nullptr; | |||
| } | |||
| protected: | |||
| virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type = QuantType_QUANT_NONE) = 0; | |||
| const QuantType &quant_type) = 0; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -1,22 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/caffe/caffe_converter.h" | |||
| #include "tools/converter/parser/caffe/caffe_model_parser.h" | |||
| namespace mindspore::lite { | |||
| CaffeConverter::CaffeConverter() { modelParser = new CaffeModelParser(); } | |||
| } // namespace mindspore::lite | |||
| @@ -21,13 +21,20 @@ | |||
| #include <memory> | |||
| #include "tools/converter/converter.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include "tools/converter/parser/caffe/caffe_model_parser.h" | |||
| namespace mindspore::lite { | |||
| class CaffeConverter : public Converter { | |||
| public: | |||
| CaffeConverter(); | |||
| CaffeConverter() = default; | |||
| ~CaffeConverter() override = default; | |||
| FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, | |||
| schema::QuantType quant_type) override { | |||
| CaffeModelParser parser; | |||
| return parser.Parse(model_file, weight_file, quant_type); | |||
| } | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -26,6 +26,8 @@ | |||
| #include "ops/return.h" | |||
| #include "ops/make_tuple.h" | |||
| #include "ops/tuple_get_item.h" | |||
| #include "ir/func_graph.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore::lite { | |||
| CaffeModelParser::CaffeModelParser() = default; | |||
| @@ -426,10 +428,4 @@ bool CaffeModelParser::IsSkipedLayer(const caffe::LayerParameter &layer) { | |||
| } | |||
| return layer.include_size() == 1 && layer.include(0).phase() == caffe::TRAIN; | |||
| } | |||
| MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) { | |||
| return nullptr; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -35,9 +35,6 @@ class CaffeModelParser : public ModelParser { | |||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override; | |||
| MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override; | |||
| private: | |||
| STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <algorithm> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| #include "ops/constant.h" | |||
| #include "src/param_value_lite.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -1,25 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_converter.h" | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OnnxConverter::OnnxConverter() { modelParser = new OnnxModelParser(); } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,14 +21,21 @@ | |||
| #include <memory> | |||
| #include "tools/converter/converter.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxConverter : public Converter { | |||
| public: | |||
| OnnxConverter(); | |||
| OnnxConverter() = default; | |||
| ~OnnxConverter() override = default; | |||
| FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, | |||
| schema::QuantType quant_type) override { | |||
| OnnxModelParser parser; | |||
| return parser.Parse(model_file, weight_file, quant_type); | |||
| } | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -26,6 +26,9 @@ | |||
| #include "ops/return.h" | |||
| #include "ops/make_tuple.h" | |||
| #include "ops/tuple_get_item.h" | |||
| #include "ir/func_graph.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -41,16 +41,10 @@ class OnnxModelParser : public ModelParser { | |||
| ~OnnxModelParser() override = default; | |||
| MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override { | |||
| return nullptr; | |||
| } | |||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override; | |||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||
| static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, | |||
| const ParamValueLitePtr ¶m_value_lite); | |||
| static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, const ParamValueLitePtr ¶m_value); | |||
| private: | |||
| STATUS InitOriginModel(const std::string &model_file); | |||
| @@ -1,22 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_converter.h" | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| TFConverter::TFConverter() { modelParser = new TFModelParser(); } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,13 +18,21 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "tools/converter/converter.h" | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFConverter : public Converter { | |||
| public: | |||
| TFConverter(); | |||
| TFConverter() = default; | |||
| ~TFConverter() override = default; | |||
| ~TFConverter() = default; | |||
| FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, | |||
| schema::QuantType quant_type) override { | |||
| TFModelParser parser; | |||
| return parser.Parse(model_file, weight_file, quant_type); | |||
| } | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -30,6 +30,7 @@ | |||
| #include "ops/tuple_get_item.h" | |||
| #include "ops/while.h" | |||
| #include "ir/anf.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -635,12 +636,6 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; | |||
| return nullptr; | |||
| } | |||
| STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| const std::vector<std::string> &input_names, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| @@ -28,7 +28,7 @@ | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "mindspore/lite/src/param_value_lite.h" | |||
| #include "src/param_value_lite.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -39,10 +39,6 @@ class TFModelParser : public ModelParser { | |||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | |||
| protected: | |||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, const ParamValueLitePtr ¶m_value); | |||
| STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, | |||
| @@ -1,22 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_converter.h" | |||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | |||
| namespace mindspore::lite { | |||
| TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); } | |||
| } // namespace mindspore::lite | |||
| @@ -22,13 +22,20 @@ | |||
| #include <map> | |||
| #include "tools/converter/converter.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | |||
| namespace mindspore::lite { | |||
| class TfliteConverter : public Converter { | |||
| public: | |||
| TfliteConverter(); | |||
| TfliteConverter() = default; | |||
| ~TfliteConverter() override = default; | |||
| FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, | |||
| schema::QuantType quant_type) override { | |||
| TfliteModelParser parser; | |||
| return parser.Parse(model_file, weight_file, quant_type); | |||
| } | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -19,15 +19,16 @@ | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "ops/return.h" | |||
| #include "ops/make_tuple.h" | |||
| #include "ops/tuple_get_item.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) { | |||
| size_t size = 0; | |||
| tflite_model_buf_ = ReadFile(model_path, &size); | |||
| @@ -453,10 +454,4 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const | |||
| } | |||
| return RET_OK; | |||
| } | |||
| MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) { | |||
| return nullptr; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| @@ -34,8 +34,6 @@ class TfliteModelParser : public ModelParser { | |||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override; | |||
| MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override; | |||
| private: | |||
| std::unordered_map<int, AnfNodePtr> nodes_; | |||