diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index b4431517bd..aada19fc4c 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -254,12 +254,12 @@ if (ENABLE_CONVERTER) ${TEST_DIR}/st/converter_test.cc ${TEST_DIR}/st/control_flow_test.cc ${TEST_DIR}/st/sub_graph_test.cc + ${TEST_DIR}/common/import_from_meta_graphT.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc - ${TEST_DIR}/ut/tools/optimizer/fusion/import_from_meta_graphT.cc ) endif() @@ -301,6 +301,8 @@ endif() if (PLATFORM_ARM) target_link_libraries(lite-test log) +else() + target_link_libraries(lite-test ${SECUREC_LIBRARY} pthread) endif() if (SUPPORT_NPU) @@ -310,7 +312,6 @@ endif () if (ENABLE_CONVERTER) add_dependencies(lite-test fbs_inner_src) target_link_libraries(lite-test - anf_importer_mid anf_exporter_mid tflite_parser_mid caffe_parser_mid @@ -320,12 +321,10 @@ if (ENABLE_CONVERTER) fusion_mid quantizer_mid proto_mid - pthread mindspore::protobuf mindspore::eigen mindspore::json -Wl,--whole-archive mindspore_core -Wl,--no-whole-archive mindspore::glog - ${SECUREC_LIBRARY} ) endif() diff --git a/mindspore/lite/test/common/import_from_meta_graphT.cc b/mindspore/lite/test/common/import_from_meta_graphT.cc new file mode 100644 index 0000000000..86a9331899 --- /dev/null +++ b/mindspore/lite/test/common/import_from_meta_graphT.cc @@ -0,0 +1,175 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/inner/model_generated.h" +#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" +#include "src/common/log_adapter.h" +#include "tools/converter/converter_context.h" +#include "include/errorcode.h" +#include "test/common/import_from_meta_graphT.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +AnfNodePtr AnfImporterFromMetaGraphT::GetNode(int tensor_id) { + auto n = nodes_.find(tensor_id); + if (n == nodes_.end()) { + return nullptr; + } + return n->second; +} + +void AnfImporterFromMetaGraphT::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } + +int AnfImporterFromMetaGraphT::ConverterConstTensor() { + MS_ASSERT(nullptr != meta_graph_); + MS_ASSERT(nullptr != func_graph_); + for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { + auto &tensor = meta_graph_->allTensors.at(i); + MS_ASSERT(tensor != nullptr); + if (tensor->nodeType != schema::NodeType::NodeType_ValueNode) { + continue; + } + auto parameter = func_graph_->add_parameter(); + std::vector shape(tensor->dims.size()); + std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + std::vector shape_vector; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + MS_ASSERT(nullptr != abstract_tensor); + parameter->set_abstract(abstract_tensor); + if (!tensor->name.empty()) { + parameter->set_name(tensor->name); + } else { + parameter->set_name("const-" + std::to_string(i)); + } + + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(nullptr != param_value); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + param_value->set_format(tensor->format); + if (!tensor->data.empty()) { + auto size = tensor->data.size(); + char *tensor_data = new (std::nothrow) char[size]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_MEMORY_FAILED; + } + auto ret = memcpy_s(tensor_data, size, tensor->data.data(), size); + if (EOK != ret) { + MS_LOG(ERROR) << "memcpy_s error"; + delete[] tensor_data; + return RET_MEMORY_FAILED; + } + param_value->SetTensorData(tensor_data, size); + parameter->set_default_param(param_value); + } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == + meta_graph_->inputIndex.end()) { + parameter->set_default_param(param_value); + } + AddNode(i, parameter); + } + return RET_OK; +} + +ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr &cNode) { + return nullptr; +} + +abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTensor( + const std::unique_ptr &tensor) { + MS_ASSERT(nullptr != tensor); + std::vector shape(tensor->dims.size()); + std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + std::vector shape_vector; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto ptr = std::make_shared(type_ptr, shape_vector); + MS_ASSERT(nullptr != ptr); + return ptr; +} + +int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr &src_cnode, + const CNodePtr &dst_cnode) { + return RET_ERROR; +} + +int AnfImporterFromMetaGraphT::ConverterCNode() { + MS_ASSERT(nullptr != meta_graph_); + MS_ASSERT(nullptr != func_graph_); + for (const auto &cNode : meta_graph_->nodes) { + MS_ASSERT(nullptr != cNode); + auto anf_primitive = ConvertPrimitive(cNode); + if (anf_primitive == nullptr) { + MS_LOG(ERROR) << "cannot obtain anf primitive"; + return RET_NULL_PTR; + } + std::vector op_inputs = {anf_primitive}; + for (int j : cNode->inputIndex) { + auto node = GetNode(j); + if (nullptr == node) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NULL_PTR; + } + op_inputs.push_back(node); + } + auto new_cnode = func_graph_->NewCNode(op_inputs); + MS_ASSERT(nullptr != new_cnode); + new_cnode->set_fullname_with_scope(cNode->name); + auto status = ConvertAbstract(cNode, new_cnode); + if (status != RET_OK) { + MS_LOG(ERROR) << "ConvertAbstract failed."; + return status; + } + } + return RET_OK; +} + +int AnfImporterFromMetaGraphT::AddReturnCNode() { return RET_ERROR; } + +FuncGraphPtr AnfImporterFromMetaGraphT::Fb2Anf(schema::MetaGraphT *meta_graph) { + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "meta_graph is null"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); + return nullptr; + } + AnfImporterFromMetaGraphT anfImporterFromMetaGraphT(meta_graph); + auto ret = anfImporterFromMetaGraphT.ConverterConstTensor(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; + return nullptr; + } + ret = anfImporterFromMetaGraphT.ConverterCNode(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "ConverterCNode failed " << ret; + return nullptr; + } + ret = anfImporterFromMetaGraphT.AddReturnCNode(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "AddReturnCNode failed " << ret; + return nullptr; + } + return anfImporterFromMetaGraphT.func_graph_; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/import_from_meta_graphT.h b/mindspore/lite/test/common/import_from_meta_graphT.h similarity index 72% rename from mindspore/lite/test/ut/tools/optimizer/fusion/import_from_meta_graphT.h rename to mindspore/lite/test/common/import_from_meta_graphT.h index f3ea8c533a..5e75004f4f 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/import_from_meta_graphT.h +++ b/mindspore/lite/test/common/import_from_meta_graphT.h @@ -19,24 +19,26 @@ #include #include +#include #include "schema/inner/model_generated.h" -#include "tools/anf_importer/anf_importer.h" #include "abstract/abstract_value.h" +#include "ir/func_graph.h" namespace mindspore::lite { -class AnfImporterFromMetaGraphT : public AnfImporter { +class AnfImporterFromMetaGraphT { public: - AnfImporterFromMetaGraphT(schema::MetaGraphT *meta_graph, FuncGraphPtr func_graph) - : meta_graph_(meta_graph), func_graph_(std::move(func_graph)) {} + virtual ~AnfImporterFromMetaGraphT() = default; - ~AnfImporterFromMetaGraphT() override = default; - - FuncGraphPtr GetResult() override; + static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph); private: - int ConverterConstTensor() override; + explicit AnfImporterFromMetaGraphT(schema::MetaGraphT *meta_graph) : meta_graph_(meta_graph) { + this->func_graph_ = std::make_shared(); + } + + int ConverterConstTensor(); - int ConverterCNode() override; + int ConverterCNode(); ValueNodePtr ConvertPrimitive(const std::unique_ptr &cNode); @@ -44,9 +46,14 @@ class AnfImporterFromMetaGraphT : public AnfImporter { int ConvertAbstract(const std::unique_ptr &src_cnode, const CNodePtr &dst_cnode); - int AddReturnCNode() override; + int AddReturnCNode(); + + AnfNodePtr GetNode(int tensor_id); + + void AddNode(int tensor_id, AnfNodePtr node); private: + std::unordered_map nodes_; schema::MetaGraphT *meta_graph_; FuncGraphPtr func_graph_; }; diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc index 204b1c550d..7e5cd0be01 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc @@ -17,18 +17,11 @@ #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" #include #include "schema/inner/model_generated.h" -#include "tools/converter/parser/tflite/tflite_model_parser.h" namespace mindspore { -schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) { - lite::TfliteModelParser parser; - meta_graph = parser.ParseToFb(model_path, weight_path, schema::QuantType_QUANT_NONE); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; - return nullptr; - } - return meta_graph; +schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const std::string &model_path, const std::string &weight_path) { + return nullptr; } void TestTfliteParser::TearDown() { free(meta_graph); } diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc index 6cd45a308f..60f9f3ce8c 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc @@ -21,12 +21,12 @@ #include "include/lite_session.h" #include "include/context.h" #include "include/errorcode.h" -#include "import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" #include "tools/anf_exporter/anf_exporter.h" +#include "test/common/import_from_meta_graphT.h" namespace mindspore { class ConstantFoldingFusionTest : public mindspore::CommonTest { @@ -370,7 +370,7 @@ MetaGraphTptr BuildSplitGraph() { } // namespace TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) { auto meta_graph = BuildGraph(schema::PrimitiveType_AddFusion, new schema::AddFusionT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -383,7 +383,7 @@ TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) { TEST_F(ConstantFoldingFusionTest, TestMixedConstantFold) { auto meta_graph = BuildMixGraph(); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -396,7 +396,7 @@ TEST_F(ConstantFoldingFusionTest, TestMixedConstantFold) { TEST_F(ConstantFoldingFusionTest, TestSubConstantFold) { auto meta_graph = BuildGraph(schema::PrimitiveType_SubFusion, new schema::SubFusionT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -409,7 +409,7 @@ TEST_F(ConstantFoldingFusionTest, TestSubConstantFold) { TEST_F(ConstantFoldingFusionTest, TestMulConstantFold) { auto meta_graph = BuildGraph(schema::PrimitiveType_MulFusion, new schema::MulFusionT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -423,7 +423,7 @@ TEST_F(ConstantFoldingFusionTest, TestMulConstantFold) { TEST_F(ConstantFoldingFusionTest, TestTransposeConstantFold) { auto transposeT = new schema::TransposeT; auto meta_graph = BuildGraph(schema::PrimitiveType_Transpose, transposeT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -470,7 +470,7 @@ TEST_F(ConstantFoldingFusionTest, TestStackConstantFold) { auto stackT = new schema::StackT; stackT->axis[0] = 1; auto meta_graph = BuildGraph(schema::PrimitiveType_Stack, stackT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -484,7 +484,7 @@ TEST_F(ConstantFoldingFusionTest, TestStackConstantFold) { TEST_F(ConstantFoldingFusionTest, TestSliceConstantFold) { auto sliceT = new schema::SliceFusionT; auto meta_graph = BuildGraph(schema::PrimitiveType_SliceFusion, sliceT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -498,7 +498,7 @@ TEST_F(ConstantFoldingFusionTest, TestSliceConstantFold) { TEST_F(ConstantFoldingFusionTest, TestShapeConstantFold) { auto shapeT = new schema::ShapeT; auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Shape, shapeT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -512,7 +512,7 @@ TEST_F(ConstantFoldingFusionTest, TestShapeConstantFold) { TEST_F(ConstantFoldingFusionTest, TestRsqrtConstantFold) { auto rsqrtT = new schema::RsqrtT; auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Rsqrt, rsqrtT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -526,7 +526,7 @@ TEST_F(ConstantFoldingFusionTest, TestRsqrtConstantFold) { TEST_F(ConstantFoldingFusionTest, TestReshapeConstantFold) { auto reshapeT = new schema::ReshapeT; auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Reshape, reshapeT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -543,7 +543,7 @@ TEST_F(ConstantFoldingFusionTest, TestRangeConstantFold) { rangeT->start = 1; rangeT->delta = 1; auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Range, rangeT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -556,7 +556,7 @@ TEST_F(ConstantFoldingFusionTest, TestRangeConstantFold) { TEST_F(ConstantFoldingFusionTest, TestMatmulConstantFold) { auto matmulT = new schema::MatMulT; auto meta_graph = BuildGraph(schema::PrimitiveType_MatMul, matmulT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -570,7 +570,7 @@ TEST_F(ConstantFoldingFusionTest, TestMatmulConstantFold) { TEST_F(ConstantFoldingFusionTest, TestExpandDimsConstantFold) { auto expandDimsT = new schema::ExpandDimsT; auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_ExpandDims, expandDimsT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -584,7 +584,7 @@ TEST_F(ConstantFoldingFusionTest, TestExpandDimsConstantFold) { TEST_F(ConstantFoldingFusionTest, TestConcatDimsConstantFold) { auto concatT = new schema::ConcatT; auto meta_graph = BuildGraph(schema::PrimitiveType_Concat, concatT); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -600,7 +600,7 @@ TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) { auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Cast, castT); auto input_tensor = meta_graph->allTensors.at(0).get(); input_tensor->dataType = kNumberTypeUInt8; - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared(); pm->AddPass(std::make_shared()); @@ -615,7 +615,7 @@ TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) { auto meta_graph = BuildSplitGraph(); auto input_tensor = meta_graph->allTensors.at(0).get(); input_tensor->dataType = kNumberTypeFloat32; - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto optimizer = std::make_shared(); auto pm = std::make_shared("test", false); pm->AddPass(std::make_shared()); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc index 4fe51aa9dc..327cca8a00 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc @@ -21,11 +21,11 @@ #include "include/lite_session.h" #include "include/context.h" #include "include/errorcode.h" -#include "import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" +#include "test/common/import_from_meta_graphT.h" namespace mindspore { class ConvActivationFusionTest : public mindspore::CommonTest { @@ -136,7 +136,7 @@ MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, schema::ActivationType } // namespace TEST_F(ConvActivationFusionTest, TestConvReluNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); @@ -149,7 +149,7 @@ TEST_F(ConvActivationFusionTest, TestConvReluNode) { TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU6); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); @@ -162,7 +162,7 @@ TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_LEAKY_RELU); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc index 3a3b58e0c4..2e51aa20df 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc @@ -21,11 +21,11 @@ #include "include/lite_session.h" #include "include/context.h" #include "include/errorcode.h" -#include "import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" +#include "test/common/import_from_meta_graphT.h" namespace mindspore { class ConvBiasAddFusionTest : public mindspore::CommonTest { @@ -145,7 +145,7 @@ MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, schema::PrimitiveType } // namespace TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_BiasAdd); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); @@ -156,7 +156,7 @@ TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_AddFusion); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); @@ -166,7 +166,7 @@ TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_MatMul); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc index 014d92842f..8ae5b47c85 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc @@ -21,11 +21,11 @@ #include "include/lite_session.h" #include "include/context.h" #include "include/errorcode.h" -#include "import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" +#include "test/common/import_from_meta_graphT.h" namespace mindspore { class ConvBNFusionTest : public mindspore::CommonTest { @@ -262,7 +262,7 @@ MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) { } // namespace TEST_F(ConvBNFusionTest, TestConvAddNode) { auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2DFusion); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); @@ -272,7 +272,7 @@ TEST_F(ConvBNFusionTest, TestConvAddNode) { TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) { auto meta_graph = BuildTFGraph(schema::PrimitiveType_Conv2DFusion); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc index ab10756009..ce3c36cd54 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc @@ -21,11 +21,11 @@ #include "include/lite_session.h" #include "include/context.h" #include "include/errorcode.h" -#include "import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" +#include "test/common/import_from_meta_graphT.h" namespace mindspore { class ConvScaleFusionTest : public mindspore::CommonTest { @@ -187,7 +187,7 @@ MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) { } // namespace TEST_F(ConvScaleFusionTest, TestConvScaleNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, true); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); @@ -198,7 +198,7 @@ TEST_F(ConvScaleFusionTest, TestConvScaleNode) { TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, false); - auto func_graph = lite::Fb2Anf(meta_graph.get()); + auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); auto new_graph = anf_transform->Transform(func_graph); ASSERT_NE(nullptr, new_graph); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/import_from_meta_graphT.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/import_from_meta_graphT.cc deleted file mode 100644 index 4d7cb892cd..0000000000 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/import_from_meta_graphT.cc +++ /dev/null @@ -1,302 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "schema/inner/model_generated.h" -#include "frontend/operator/ops.h" -#include "src/param_value_lite.h" -#include "src/common/log_adapter.h" -#include "tools/converter/quant_param_holder.h" -#include "tools/converter/converter_context.h" -#include "include/errorcode.h" -#include "import_from_meta_graphT.h" - -namespace mindspore::lite { -int AnfImporterFromMetaGraphT::ConverterConstTensor() { - MS_ASSERT(nullptr != meta_graph_); - MS_ASSERT(nullptr != func_graph_); - for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { - auto &tensor = meta_graph_->allTensors.at(i); - MS_ASSERT(tensor != nullptr); - if (tensor->nodeType != schema::NodeType::NodeType_ValueNode) { - continue; - } - auto parameter = func_graph_->add_parameter(); - std::vector shape(tensor->dims.size()); - std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); - auto type_id = static_cast(tensor->dataType); - auto type_ptr = TypeIdToType(type_id); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - MS_ASSERT(nullptr != abstract_tensor); - parameter->set_abstract(abstract_tensor); - if (!tensor->name.empty()) { - parameter->set_name(tensor->name); - } else { - parameter->set_name("const-" + std::to_string(i)); - } - - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(nullptr != param_value); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - param_value->set_format(tensor->format); - if (!tensor->data.empty()) { - auto size = tensor->data.size(); - char *tensor_data = new (std::nothrow) char[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - auto ret = memcpy_s(tensor_data, size, tensor->data.data(), size); - if (EOK != ret) { - MS_LOG(ERROR) << "memcpy_s error"; - delete[] tensor_data; - return RET_MEMORY_FAILED; - } - param_value->SetTensorData(tensor_data, size); - parameter->set_default_param(param_value); - } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == - meta_graph_->inputIndex.end()) { - parameter->set_default_param(param_value); - } - AddNode(i, parameter); - } - return RET_OK; -} - -ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr &cNode) { - // MS_ASSERT(nullptr != meta_graph_); - // MS_ASSERT(nullptr != cNode); - // auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release()); - // if (primitiveCValue == nullptr) { - // MS_LOG(ERROR) << "fail to convert primitive"; - // return nullptr; - // } - // cNode->primitive = nullptr; - // // add quant parameter - // auto quant_params_holder = std::make_shared(); - // for (auto index : cNode->inputIndex) { - // if (!meta_graph_->allTensors[index]->quantParams.empty()) { - // std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); - // std::transform( - // meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), - // quant_params.begin(), - // [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); - // quant_params_holder->AddInputQuantParam(quant_params); - // } else { - // std::vector notinited_quant_params(1); - // quant_params_holder->AddInputQuantParam(notinited_quant_params); - // } - // } - // for (auto index : cNode->outputIndex) { - // if (!meta_graph_->allTensors[index]->quantParams.empty()) { - // std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); - // std::transform( - // meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), - // quant_params.begin(), - // [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); - // quant_params_holder->AddOutputQuantParam(quant_params); - // } else { - // std::vector notinited_quant_params(1); - // quant_params_holder->AddOutputQuantParam(notinited_quant_params); - // } - // } - // primitiveCValue->AddAttr("quant_params", quant_params_holder); - // auto value_node = NewValueNode(std::shared_ptr(primitiveCValue)); - // return value_node; - return nullptr; -} - -abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTensor( - const std::unique_ptr &tensor) { - MS_ASSERT(nullptr != tensor); - std::vector shape(tensor->dims.size()); - std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); - auto type_id = static_cast(tensor->dataType); - auto type_ptr = TypeIdToType(type_id); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto ptr = std::make_shared(type_ptr, shape_vector); - MS_ASSERT(nullptr != ptr); - return ptr; -} - -int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr &src_cnode, - const CNodePtr &dst_cnode) { - // MS_ASSERT(nullptr != meta_graph_); - // MS_ASSERT(nullptr != src_cnode); - // MS_ASSERT(nullptr != dst_cnode); - // std::vector out_tensor_ids = src_cnode->outputIndex; - // if (out_tensor_ids.size() == 1) { - // auto out_tensor_id = out_tensor_ids.front(); - // MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); - // auto &tensor = meta_graph_->allTensors.at(out_tensor_id); - // MS_ASSERT(nullptr != tensor); - // dst_cnode->set_abstract(ConvertTensorToAbstractTensor(tensor)); - // AddNode(out_tensor_id, dst_cnode); - // } else { - // AbstractBasePtrList abstract_list; - // for (size_t i = 0; i < out_tensor_ids.size(); i++) { - // auto out_tensor_id = out_tensor_ids.at(i); - // MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); - // auto &tensor = meta_graph_->allTensors.at(out_tensor_id); - // MS_ASSERT(nullptr != tensor); - // abstract_list.emplace_back(ConvertTensorToAbstractTensor(tensor)); - // auto tuple_get_item_prim_ptr = GetTupleGetItemPrim(); - // if (tuple_get_item_prim_ptr == nullptr) { - // MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; - // return RET_NULL_PTR; - // } - // auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); - // auto get_item_value = NewValueNode(MakeValue(i)); - // if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { - // MS_LOG(ERROR) << "NewValueNode is nullptr"; - // return RET_NULL_PTR; - // } - // std::vector inputs{tuple_get_item_prim, dst_cnode, get_item_value}; - // CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); - // if (get_item_cnode == nullptr) { - // MS_LOG(ERROR) << "NewCNode is nullptr"; - // return RET_NULL_PTR; - // } - // get_item_cnode->set_fullname_with_scope(src_cnode->name + "_getitem_" + std::to_string(i)); - // AddNode(out_tensor_id, get_item_cnode); - // } - // dst_cnode->set_abstract(std::make_shared(abstract_list)); - // } - return RET_OK; -} - -int AnfImporterFromMetaGraphT::ConverterCNode() { - MS_ASSERT(nullptr != meta_graph_); - MS_ASSERT(nullptr != func_graph_); - for (const auto &cNode : meta_graph_->nodes) { - MS_ASSERT(nullptr != cNode); - auto anf_primitive = ConvertPrimitive(cNode); - if (anf_primitive == nullptr) { - MS_LOG(ERROR) << "cannot obtain anf primitive"; - return RET_NULL_PTR; - } - std::vector op_inputs = {anf_primitive}; - for (int j : cNode->inputIndex) { - auto node = GetNode(j); - if (nullptr == node) { - MS_LOG(ERROR) << "Can't find input node."; - return RET_NULL_PTR; - } - op_inputs.push_back(node); - } - auto new_cnode = func_graph_->NewCNode(op_inputs); - MS_ASSERT(nullptr != new_cnode); - new_cnode->set_fullname_with_scope(cNode->name); - auto status = ConvertAbstract(cNode, new_cnode); - if (status != RET_OK) { - MS_LOG(ERROR) << "ConvertAbstract failed."; - return status; - } - } - return RET_OK; -} - -int AnfImporterFromMetaGraphT::AddReturnCNode() { - // MS_ASSERT(nullptr != meta_graph_); - // MS_ASSERT(nullptr != func_graph_); - // if (meta_graph_->outputIndex.size() > 1) { - // std::vector make_tuple_inputs; - // auto make_tuple_prim_ptr = GetMakeTuplePrim(); - // if (make_tuple_prim_ptr == nullptr) { - // MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; - // return RET_NULL_PTR; - // } - // auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); - // make_tuple_inputs.emplace_back(make_tuple_prim); - // for (auto tensor_id : meta_graph_->outputIndex) { - // auto cNode = GetNode(tensor_id); - // if (nullptr == cNode) { - // MS_LOG(ERROR) << "Can't find input node."; - // return RET_ERROR; - // } - // make_tuple_inputs.emplace_back(cNode); - // } - // auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); - // if (make_tuple_cnode == nullptr) { - // MS_LOG(ERROR) << "NewCNode is nullptr"; - // return RET_NULL_PTR; - // } - // make_tuple_cnode->set_fullname_with_scope("return tuple"); - - // std::vector op_inputs; - // auto return_prim_ptr = GetReturnPrim(); - // if (return_prim_ptr == nullptr) { - // MS_LOG(ERROR) << "GetReturnPrim return nullptr"; - // return RET_NULL_PTR; - // } - // auto value_node = NewValueNode(return_prim_ptr); - // op_inputs.emplace_back(value_node); - // op_inputs.emplace_back(make_tuple_cnode); - // auto cnode = func_graph_->NewCNode(op_inputs); - // MS_ASSERT(nullptr != cnode); - // cnode->set_fullname_with_scope("return"); - // func_graph_->set_return(cnode); - // } else { - // auto return_prim_ptr = GetReturnPrim(); - // if (return_prim_ptr == nullptr) { - // MS_LOG(ERROR) << "GetReturnPrim return nullptr"; - // return RET_NULL_PTR; - // } - // auto value_node = NewValueNode(return_prim_ptr); - // std::vector op_inputs{value_node}; - // auto cnode = GetNode(meta_graph_->outputIndex.front()); - // if (nullptr == cnode) { - // MS_LOG(ERROR) << "Can't find input node."; - // return RET_ERROR; - // } - // op_inputs.emplace_back(cnode); - // auto return_cnode = func_graph_->NewCNode(op_inputs); - // if (return_cnode == nullptr) { - // MS_LOG(ERROR) << "NewCNode is nullptr"; - // return RET_NULL_PTR; - // } - // return_cnode->set_fullname_with_scope("return"); - // func_graph_->set_return(return_cnode); - // } - return RET_OK; -} - -FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } - -FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "meta_graph is null"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); - return nullptr; - } - auto func_graph = std::make_shared(); - AnfImporterFromMetaGraphT importer(meta_graph, func_graph); - auto status = importer.Import(); - if (RET_OK != status) { - MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << status; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - return func_graph; -} -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/CMakeLists.txt b/mindspore/lite/tools/anf_importer/CMakeLists.txt deleted file mode 100644 index 10f4f0d8d8..0000000000 --- a/mindspore/lite/tools/anf_importer/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -file(GLOB ANF_IMPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - *.cc - ) -set_property(SOURCE ${ANF_IMPORTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) -add_library(anf_importer_mid OBJECT - ${ANF_IMPORTER_SRC_LIST} - ) -add_dependencies(anf_importer_mid proto_mid) - -add_dependencies(anf_importer_mid fbs_src) -add_dependencies(anf_importer_mid fbs_inner_src) diff --git a/mindspore/lite/tools/anf_importer/anf_importer.cc b/mindspore/lite/tools/anf_importer/anf_importer.cc deleted file mode 100644 index 88a8f7c057..0000000000 --- a/mindspore/lite/tools/anf_importer/anf_importer.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_importer.h" -#include -#include "schema/model_generated.h" -#include "ir/dtype.h" -#include "include/errorcode.h" -#include "schema/inner/model_generated.h" -namespace mindspore { -namespace lite { -int AnfImporter::Import(const converter::Flags *flag) { - auto ret = ConverterConstTensor(); - if (RET_OK != ret) { - MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; - return ret; - } - ret = ConverterCNode(); - if (RET_OK != ret) { - MS_LOG(ERROR) << "ConverterCNode failed " << ret; - return ret; - } - ret = AddReturnCNode(); - if (RET_OK != ret) { - MS_LOG(ERROR) << "AddReturnCNode failed " << ret; - return ret; - } - return RET_OK; -} - -AnfNodePtr AnfImporter::GetNode(int tensor_id) { - auto n = nodes_.find(tensor_id); - if (n == nodes_.end()) { - return nullptr; - } - return n->second; -} - -void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/anf_importer/anf_importer.h b/mindspore/lite/tools/anf_importer/anf_importer.h deleted file mode 100644 index 5d55b665f8..0000000000 --- a/mindspore/lite/tools/anf_importer/anf_importer.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ -#define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ - -#include -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "base/base.h" -#include "schema/inner/model_generated.h" -#include "tools/converter/converter_flags.h" - -namespace mindspore::lite { -class AnfImporter { - public: - AnfImporter() = default; - - virtual ~AnfImporter() = default; - - virtual int Import(const converter::Flags *flag = nullptr); - - virtual FuncGraphPtr GetResult() = 0; - - protected: - // convert const tensor into parameter and save in nodes_ - virtual int ConverterConstTensor() = 0; - // convert other node into cnode and save in nodes_ - virtual int ConverterCNode() = 0; - - virtual int AddReturnCNode() = 0; - - AnfNodePtr GetNode(int tensor_id); - - void AddNode(int tensor_id, AnfNodePtr node); - - protected: - std::unordered_map nodes_; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_ANF_IMPORTER_H_ diff --git a/mindspore/lite/tools/anf_importer/import_from_mindir.cc b/mindspore/lite/tools/anf_importer/import_from_mindir.cc deleted file mode 100644 index 9733be93be..0000000000 --- a/mindspore/lite/tools/anf_importer/import_from_mindir.cc +++ /dev/null @@ -1,903 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/import_from_mindir.h" -#include -#include -#include -#include -#include -#include -#include -#include "ops/make_tuple.h" -#include "ops/return.h" -#include "frontend/operator/ops.h" -#include "include/errorcode.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "securec/include/securec.h" -#include "src/tensor.h" -#include "src/param_value_lite.h" -#include "proto/onnx.pb.h" -#include "src/common/log_adapter.h" -#include "tools/common/protobuf_utils.h" -#include "tools/common/graph_util.h" -#include "load_mindir/load_model.h" - -using string = std::string; -using int32 = int32_t; -using int64 = int64_t; -using uint64 = uint64_t; - -namespace mindspore::lite { -static constexpr char kConstantValueNode[] = "Constant"; - -enum ParseForm : int { - FORM_PARSE_TYPE = 0, - FORM_PARSE_SCALAR = 1, - FORM_PARSE_TENSOR = 2, -}; - -static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; - -static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, -}; - -std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, - const std::unordered_map &kv) { - std::string str = attr_name; - auto replace = [&](const string &orgStr, const string &newStr) { - std::string::size_type pos(0); - while ((pos = str.find(orgStr)) != std::string::npos) { - str.replace(pos, orgStr.length(), newStr); - } - return str; - }; - // remove "scalar:" - str = replace("scalar:", ""); - // remove "Tuple" - str = replace("Tuple", ""); - // remove "List" - str = replace("List", ""); - std::stack rules; - std::stack value; - int num = 0, count = 0; - for (size_t i = 0; i < str.length(); i++) { - if (str[i] == '[') { - rules.push("["); - } else if (str[i] == ']') { - // rules - std::vector vec; - while (rules.top() != "[") { - rules.pop(); - vec.push_back(value.top()); - value.pop(); - } - // pop "[" - rules.pop(); - // make tuple for names - std::string res = "dummy"; - // make tuple for values - reverse(vec.begin(), vec.end()); - auto vt = std::make_shared(vec); - if (rules.empty() && value.empty()) { - return vt; - } - rules.push(res); - value.push(vt); - } else if (str[i] == ',') { - continue; - } else { - count++; - if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { - auto value_name = str.substr(i - count + 1, count); - value.push(kv.at(value_name)); - rules.push(value_name); - count = 0; - num++; - } - } - } - return {}; -} - -std::shared_ptr ParserAttrShape( - const std::string &attr_name, const std::unordered_map &kv) { - std::string str = attr_name; - auto replace = [&](const string &orgStr, const string &newStr) { - std::string::size_type pos(0); - while ((pos = str.find(orgStr)) != std::string::npos) { - str.replace(pos, orgStr.length(), newStr); - } - return str; - }; - // remove "scalar:" - str = replace("shape:", ""); - // remove "Tuple" - str = replace("Tuple", ""); - // remove "List" - str = replace("List", ""); - std::stack rules; - std::stack value; - int num = 0, count = 0; - for (size_t i = 0; i < str.length(); i++) { - if (str[i] == '[') { - rules.push("["); - } else if (str[i] == ']') { - // rules - std::vector vec; - while (rules.top() != "[") { - rules.pop(); - vec.push_back(value.top()); - value.pop(); - } - // pop "[" - rules.pop(); - // make tuple for names - std::string res = "dummy"; - // make tuple for values - reverse(vec.begin(), vec.end()); - auto vt = std::make_shared(vec); - if (rules.empty() && value.empty()) { - return vt; - } - rules.push(res); - value.push(vt); - } else if (str[i] == ',') { - continue; - } else { - count++; - if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { - auto value_name = str.substr(i - count + 1, count); - value.push(kv.at(value_name)); - rules.push(value_name); - count = 0; - num++; - } - } - } - return {}; -} - -#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ - ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ - if (attr_tensor.type##_data_size() == 1) { \ - auto value = static_cast(attr_tensor.type##_data(0)); \ - return MakeValue(value); \ - } else { \ - MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ - } \ - return {}; \ - } - -PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) -PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) -PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) -PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) -PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) -PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) -PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) - -int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node, - const onnx::ValueInfoProto &value_proto) { - if (node == nullptr) { - return RET_NULL_PTR; - } - if (!value_proto.has_type() || !value_proto.has_name()) { - MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; - return RET_PARAM_INVALID; - } - node->set_name(value_proto.name()); - const auto &type_proto = value_proto.type(); - if (!type_proto.has_tensor_type()) { - MS_LOG(ERROR) << "onnx TypeProto has no tensor_type! "; - return RET_PARAM_INVALID; - } - const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); - if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { - MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; - return RET_INPUT_TENSOR_ERROR; - } - const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); - std::vector shape; - for (int i = 0; i < tensor_shape.dim_size(); ++i) { - shape.push_back(tensor_shape.dim(i).dim_value()); - } - - if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; - return RET_PARAM_INVALID; - } - - auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - node->set_abstract(abstract_tensor); - - if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { - auto *tensor_info = new (std::nothrow) Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); - if (tensor_info == nullptr) { - return RET_MEMORY_FAILED; - } - tensor_info->MallocData(); - const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; - std::string initial_data = initialize_proto.raw_data(); - auto *tensor_data_buf = reinterpret_cast(tensor_info->MutableData()); - if (tensor_data_buf == nullptr) { - delete tensor_info; - return RET_MEMORY_FAILED; - } - tensor_info->set_data(nullptr); - auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); - if (EOK != ret) { - MS_LOG(ERROR) << "memcpy_s error"; - delete tensor_data_buf; - delete tensor_info; - return RET_MEMORY_FAILED; - } - - ParamValueLitePtr param_value = std::make_shared(); - if (param_value == nullptr) { - delete tensor_info; - return RET_NULL_PTR; - } - param_value->SetTensorData(tensor_data_buf, tensor_info->Size()); - param_value->set_tensor_type(tensor_info->data_type()); - param_value->set_tensor_shape(tensor_info->shape()); - node->set_default_param(param_value); - delete tensor_info; - } - anfnode_build_map_[value_proto.name()] = node; - return RET_OK; -} - -int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto) { - if (outputFuncGraph == nullptr) { - return RET_NULL_PTR; - } - MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); - - for (int i = 0; i < importProto.initializer_size(); ++i) { - const onnx::TensorProto &initializer_proto = importProto.initializer(i); - if (!initializer_proto.has_name()) { - MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; - return RET_PARAM_INVALID; - } - default_para_map_[initializer_proto.name()] = initializer_proto; - } - - int status = RET_OK; - MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); - for (int i = 0; i < importProto.input_size(); ++i) { - const onnx::ValueInfoProto &input_proto = importProto.input(i); - status = BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto); - if (status != RET_OK) { - MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; - break; - } - } - return status; -} - -bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { - if (prim == nullptr) { - return false; - } - const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; - return false; - } - prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - return true; -} - -ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - switch (attr_tensor_type) { - case onnx::TensorProto_DataType_STRING: { - return ParseAttrInScalar_string_string(attr_tensor); - } - case onnx::TensorProto_DataType_INT32: { - return ParseAttrInScalar_int32_int32(attr_tensor); - } - case onnx::TensorProto_DataType_INT64: { - return ParseAttrInScalar_int64_int64(attr_tensor); - } - case onnx::TensorProto_DataType_UINT64: { - return ParseAttrInScalar_uint64_uint64(attr_tensor); - } - case onnx::TensorProto_DataType_FLOAT: { - return ParseAttrInScalar_float_float(attr_tensor); - } - case onnx::TensorProto_DataType_DOUBLE: { - return ParseAttrInScalar_double_double(attr_tensor); - } - case onnx::TensorProto_DataType_BOOL: { - return ParseAttrInScalar_int32_bool(attr_tensor); - } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; - return {}; - } -} - -bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { - if (prim == nullptr) { - return false; - } - const int attr_tensor_type = attr_tensor.data_type(); - const std::string &tensor_buf = attr_tensor.raw_data(); - std::vector shape; - auto ret = EOK; - if (attr_tensor.dims_size() != 0) { - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape.push_back(attr_tensor.dims(i)); - } - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - tensor::TensorPtr tensor_info = - std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape_vector); - auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); - if (EOK != ret) { - MS_LOG(ERROR) << "memcpy_s error"; - return false; - } - prim->set_attr(attr_name, MakeValue(tensor_info)); - } else { - if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) { - size_t data_size = sizeof(double); - double attr_value = 0.0; - ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); - if (EOK != ret) { - MS_LOG(ERROR) << "memcpy_s error"; - return false; - } - prim->set_attr(attr_name, MakeValue(attr_value)); - } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) { - size_t data_size = sizeof(int64_t); - int64_t attr_value = 0; - ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); - if (EOK != ret) { - MS_LOG(ERROR) << "memcpy_s error"; - return false; - } - prim->set_attr(attr_name, MakeValue(attr_value)); - } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) { - size_t data_size = sizeof(bool); - bool attr_value = false; - ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); - if (EOK != ret) { - MS_LOG(ERROR) << "memcpy_s error"; - return false; - } - prim->set_attr(attr_name, MakeValue(attr_value)); - } - } - return ret == EOK; -} - -bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { - if (prim == nullptr) { - return false; - } - const std::string &attr_name = attr_proto.name(); - if (!attr_proto.has_ref_attr_name()) { - MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; - return false; - } - const std::string &ref_attr_name = attr_proto.ref_attr_name(); - if (ref_attr_name.empty()) { - MS_LOG(ERROR) << "ref_attr_name is empty"; - return false; - } - string type = ""; - std::size_t pos(0); - if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("scalar:").length() - 1); - } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("type:").length() - 1); - } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("tensor:").length() - 1); - } - std::unordered_map kv; - for (int i = 0; i < attr_proto.tensors_size(); i++) { - const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); - switch (kParseTypeSwitchMap[type]) { - case FORM_PARSE_TYPE: { - return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); - } - case FORM_PARSE_SCALAR: { - auto res = ObtainCNodeAttrInScalarForm(attr_tensor); - kv.insert(std::pair(attr_tensor.name(), res)); - break; - } - case FORM_PARSE_TENSOR: { - return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); - } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; - return false; - } - } - if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { - if (kv.size() == 1) { - auto iter = kv.begin(); - prim->AddAttr(attr_name, iter->second); - } else { - auto res = ParserScalarAttrValue(ref_attr_name, kv); - prim->AddAttr(attr_name, res); - } - } - return true; -} - -bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - std::vector shape; - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape.push_back(attr_tensor.dims(i)); - } - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - ParamValueLitePtr param_value = std::make_shared(); - param_value->set_tensor_shape(shape_vector); - param_value->set_tensor_type(kDefaultValueSwitchMap[attr_tensor_type]); - const std::string &tensor_buf = attr_tensor.raw_data(); - auto tensor_data = new (std::nothrow) char[tensor_buf.size()]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "Tensor_data is nullptr"; - return false; - } - auto ret = memcpy_s(tensor_data, tensor_buf.size(), tensor_buf.data(), tensor_buf.size()); - if (ret != EOK) { - delete[] tensor_data; - MS_LOG(ERROR) << "Memcpy error: " << ret; - return false; - } - param_value->SetTensorData(tensor_data, tensor_buf.size()); - auto new_value_node = NewValueNode(MakeValue(param_value)); - if (new_value_node == nullptr) { - MS_LOG(ERROR) << "Make valuenode fail"; - return false; - } - auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); - std::vector shape_vector_int64; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector_int64), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector_int64); - new_value_node->set_abstract(abstract_tensor); - anfnode_build_map_[value_node_name] = new_value_node; - return true; -} - -bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; - return false; - } - auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); - new_value_node->set_abstract(abs_type); - anfnode_build_map_[value_node_name] = new_value_node; - return true; -} - -bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name, - const onnx::AttributeProto &attr_proto) { - if (!attr_proto.has_ref_attr_name()) { - MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; - return false; - } - const std::string &ref_attr_name = attr_proto.ref_attr_name(); - if (ref_attr_name.empty()) { - MS_LOG(ERROR) << "ref_attr_name is empty"; - return false; - } - string type = ""; - std::size_t pos(0); - if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("scalar:").length() - 1); - } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("type:").length() - 1); - } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { - type = ref_attr_name.substr(pos, string("tensor:").length() - 1); - } - std::unordered_map kv; - for (int i = 0; i < attr_proto.tensors_size(); i++) { - const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); - switch (kParseTypeSwitchMap[type]) { - case FORM_PARSE_TYPE: { - return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); - } - case FORM_PARSE_SCALAR: { - auto res = ObtainCNodeAttrInScalarForm(attr_tensor); - kv.insert(std::pair(attr_tensor.name(), res)); - break; - } - case FORM_PARSE_TENSOR: { - return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); - } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; - return false; - } - } - - ValueNodePtr new_value_node; - if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { - if (kv.size() == 1) { - auto iter = kv.begin(); - new_value_node = NewValueNode(iter->second); - new_value_node->set_abstract(iter->second->ToAbstract()); - } else { - auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); - new_value_node = NewValueNode(value_ptr); - new_value_node->set_abstract(value_ptr->ToAbstract()); - } - anfnode_build_map_[value_node_name] = new_value_node; - } - return true; -} - -bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { - const std::string &value_node_name = node_proto.output(0); - const onnx::AttributeProto &attr_proto = node_proto.attribute(0); - if (!attr_proto.has_ref_attr_name()) { - MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; - return false; - } - return GetAttrValueForValueNode(value_node_name, attr_proto); -} - -std::unordered_map AnfImporterFromMindir::GetAbstractForCNode( - const onnx::AttributeProto &attr_proto) { - std::unordered_map kv; - for (int i = 0; i < attr_proto.tensors_size(); i++) { - std::vector shape; - const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); - for (int j = 0; j < attr_tensor.dims_size(); ++j) { - shape.push_back(attr_tensor.dims(j)); - } - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - kv.insert(std::pair(attr_tensor.name(), abstract_tensor)); - } - return kv; -} - -CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::NodeProto &node_proto, - const schema::QuantType &quantType) { - static bool interrupt = false; - if (outputFuncGraph == nullptr) { - MS_LOG(ERROR) << "output funcgraph is nullptr"; - return nullptr; - } - if (!node_proto.has_op_type()) { - MS_LOG(ERROR) << "Get CNode op_type failed!"; - return nullptr; - } - const std::string &node_name = node_proto.output(0); - const std::string &fullname_with_scope = node_proto.domain(); - // const std::string &node_type = node_proto.op_type(); - PrimitivePtr prim; - // NOTE: can not find OpPrimCRegister - // auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap(); - // if (op_primc_fns.find(node_type) != op_primc_fns.end()) { - // prim = op_primc_fns[node_type](); - // } else { - // prim = std::make_shared(node_type); - // prim->set_instance_name(node_type); - // } - if (prim == nullptr) { - MS_LOG(ERROR) << "new primitive failed"; - return nullptr; - } - std::unordered_map kv; - string shape_ref_attr_name; - for (int i = 0; i < node_proto.attribute_size(); ++i) { - const onnx::AttributeProto &attr_proto = node_proto.attribute(i); - if (attr_proto.ref_attr_name().find("shape:") != string::npos) { - shape_ref_attr_name = attr_proto.ref_attr_name(); - kv = GetAbstractForCNode(attr_proto); - continue; - } - if (!GetAttrValueForCNode(prim, attr_proto)) { - MS_LOG(ERROR) << "Get CNode attr failed!"; - return nullptr; - } - } - - std::vector inputs; - inputs.clear(); - for (int i = 0; i < node_proto.input_size(); ++i) { - const std::string &input_name = node_proto.input(i); - if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { - if (!interrupt) { - MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; - interrupt = true; - } - inputs.push_back(nullptr); - } else { - inputs.push_back(anfnode_build_map_[input_name]); - } - } - CNodePtr cnode_ptr = outputFuncGraph->NewCNode(prim, inputs); - if (cnode_ptr == nullptr) { - interrupt = true; - MS_LOG(ERROR) << "funcgraph new cnode failed"; - return nullptr; - } - if (kv.empty()) { - AbstractBasePtrList elem; - for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { - elem.push_back(cnode_ptr->input(index)->abstract()); - } - cnode_ptr->set_abstract(std::make_shared(elem)); - } else if (1 == kv.size()) { - auto iter = kv.begin(); - cnode_ptr->set_abstract(iter->second); - } else { - auto abstract = ParserAttrShape(shape_ref_attr_name, kv); - cnode_ptr->set_abstract(abstract); - } - - cnode_ptr->set_fullname_with_scope(fullname_with_scope); - anfnode_build_map_[node_name] = cnode_ptr; - return cnode_ptr; -} - -bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { - if (outputFuncGraph == nullptr || cnode_ptr == nullptr) { - MS_LOG(ERROR) << "output funcgraph or cnode is nullptr"; - return false; - } - std::vector inputs; - if (importProto.output_size() > 1) { - inputs.clear(); - auto make_tuple_prim = std::make_shared(); - inputs.push_back(NewValueNode(make_tuple_prim)); - AbstractBasePtrList elem; - for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { - const onnx::ValueInfoProto &output_node = importProto.output(out_size); - const std::string &out_tuple = output_node.name(); - inputs.push_back(anfnode_build_map_[out_tuple]); - if (anfnode_build_map_[out_tuple] == nullptr) { - MS_LOG(ERROR) << "AnfNode is nullptr"; - return false; - } - elem.push_back(anfnode_build_map_[out_tuple]->abstract()); - } - auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); - if (maketuple_ptr == nullptr) { - MS_LOG(ERROR) << "maketuple_ptr is nullptr"; - return false; - } - maketuple_ptr->set_abstract(std::make_shared(elem)); - inputs.clear(); - auto return_prim = std::make_shared(); - inputs.push_back(NewValueNode(return_prim)); - inputs.push_back(maketuple_ptr); - auto return_node = outputFuncGraph->NewCNode(inputs); - if (return_node == nullptr) { - MS_LOG(ERROR) << "funcgraph new cnode failed"; - return false; - } - outputFuncGraph->set_return(return_node); - MS_LOG(INFO) << "Construct funcgraph finined, all success."; - } else { - const onnx::ValueInfoProto &output_node = importProto.output(0); - const onnx::TypeProto &output_typeproto = output_node.type(); - int output_type = output_typeproto.tensor_type().elem_type(); - std::vector output_shape; - for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { - output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); - } - std::vector shape_vector; - (void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - inputs.clear(); - auto return_prim = std::make_shared(); - inputs.push_back(NewValueNode(return_prim)); - inputs.push_back(cnode_ptr); - auto return_node = outputFuncGraph->NewCNode(inputs); - if (return_node == nullptr) { - MS_LOG(ERROR) << "funcgraph new cnode failed"; - return false; - } - return_node->set_abstract(abstract_tensor); - outputFuncGraph->set_return(return_node); - MS_LOG(INFO) << "Construct funcgraph finined, all success!"; - } - return true; -} - -int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType) { - if (outputFuncGraph == nullptr) { - MS_LOG(ERROR) << "funcgraph is nullptr"; - return RET_NULL_PTR; - } - MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); - CNodePtr cnode_ptr = nullptr; - CNodePtr last_cnode_ptr = nullptr; - int status = RET_OK; - NoSupportOp::GetInstance()->SetFmkType("MINDIR"); - for (int i = 0; i < importProto.node_size(); ++i) { - const onnx::NodeProto &node_proto = importProto.node(i); - const std::string &node_type = node_proto.op_type(); - MS_LOG(INFO) << "parse op : " << node_type; - if (node_type == kConstantValueNode) { - if (status == RET_OK && !BuildValueNodeForFuncGraph(node_proto)) { - MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; - status = RET_ERROR; - } - continue; - } - cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); - if (cnode_ptr == nullptr) { - MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; - return RET_ERROR; - } - - auto primitive_c = GetValueNode>(cnode_ptr->input(0)); - if (primitive_c == nullptr) { - MS_LOG(ERROR) << "primitive_c is nullptr"; - return RET_ERROR; - } - } - if (status != RET_OK) { - return status; - } - if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { - MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; - status = RET_ERROR; - } - return status; -} - -int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType) { - if (outputFuncGraph == nullptr) { - MS_LOG(ERROR) << "fundgraph is nullptr"; - return RET_NULL_PTR; - } - GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); - if (debug_info_ptr == nullptr) { - MS_LOG(ERROR) << "funcgraph's debug info is nullptr"; - return RET_NULL_PTR; - } - if (importProto.has_name()) { - debug_info_ptr->set_name(importProto.name()); - } else { - MS_LOG(INFO) << "FuncGraph under converting has not name!"; - } - - auto status = ImportParametersForGraph(outputFuncGraph, importProto); - if (status != RET_OK) { - return status; - } - return ImportNodesForGraph(outputFuncGraph, importProto, quantType); -} - -int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { - if (!model_proto.has_producer_name()) { - MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; - return RET_GRAPH_FILE_ERR; - } - producer_name_ = model_proto.producer_name(); - - if (!model_proto.has_model_version()) { - MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; - return RET_GRAPH_FILE_ERR; - } - model_version_ = model_proto.model_version(); - - if (!model_proto.has_ir_version()) { - MS_LOG(ERROR) << "Parse model version from pb file failed!"; - return RET_GRAPH_FILE_ERR; - } - ir_version_ = model_proto.ir_version(); - return RET_OK; -} - -int AnfImporterFromMindir::Import(const converter::Flags *flag) { -#if SUPPORT_TRAIN - func_graph_ = LoadMindIR(flag->modelFile); - if (func_graph_ != nullptr) { - return RET_OK; - } else { - MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format"; - } -#endif - onnx_model_ = ReadOnnxFromBinary(flag->modelFile); - if (onnx_model_ == nullptr) { - MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; - func_graph_ = LoadMindIR(flag->modelFile); - if (func_graph_ == nullptr) { - MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file."; - return RET_GRAPH_FILE_ERR; - } - return RET_OK; - } - FuncGraphPtr dstGraph = std::make_shared(); - if (dstGraph == nullptr) { - MS_LOG(ERROR) << "funcgraph is nullptr"; - return RET_NULL_PTR; - } - int status = ParseModelConfigureInfo(*onnx_model_); - if (status != RET_OK) { - MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; - return status; - } - auto quantType = flag->quantType; - const onnx::GraphProto &graphBuild = onnx_model_->graph(); - status = BuildFuncGraph(dstGraph, graphBuild, quantType); - if (status != RET_OK) { - MS_LOG(ERROR) << "Build funcgraph failed!"; - func_graph_ = nullptr; - return status; - } - func_graph_ = dstGraph; - MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; - return RET_OK; -} - -onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) { - auto onnx_model = new (std::nothrow) onnx::ModelProto; - if (onnx_model == nullptr) { - MS_LOG(ERROR) << "New onnx ModelProto failed!"; - return nullptr; - } - if (RET_OK != ValidateFileStr(model_path, ".mindir")) { - MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; - return nullptr; - } - if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { - MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model"; - return nullptr; - } - return onnx_model; -} - -FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; } -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/import_from_mindir.h b/mindspore/lite/tools/anf_importer/import_from_mindir.h deleted file mode 100644 index f743f473ab..0000000000 --- a/mindspore/lite/tools/anf_importer/import_from_mindir.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ -#define MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ - -#include -#include -#include -#include - -#include "include/errorcode.h" -#include "proto/onnx.pb.h" -#include "tools/converter/converter_context.h" -#include "tools/anf_importer/anf_importer.h" -#include "abstract/abstract_value.h" - -namespace mindspore::lite { -class AnfImporterFromMindir : public AnfImporter { - public: - AnfImporterFromMindir() = default; - - ~AnfImporterFromMindir() override { delete onnx_model_; } - - static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); - - FuncGraphPtr GetResult() override; - - int Import(const converter::Flags *flag) override; - - private: - int ConverterConstTensor() override { return RET_ERROR; }; - int ConverterCNode() override { return RET_ERROR; }; - int AddReturnCNode() override { return RET_ERROR; }; - int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); - int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType); - int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); - int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType); - int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); - CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, - const schema::QuantType &quantType); - bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const CNodePtr &cnode_ptr); - static bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); - static bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor); - static ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); - static bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor); - bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); - bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto); - bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - static std::unordered_map GetAbstractForCNode( - const onnx::AttributeProto &attr_proto); - - private: - std::string producer_name_; - int model_version_{}; - int ir_version_{}; - std::unordered_map anfnode_build_map_; - std::map default_para_map_; - onnx::ModelProto *onnx_model_; - FuncGraphPtr func_graph_; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h index 71c199b9a4..2bd51ee530 100644 --- a/mindspore/lite/tools/common/tensor_util.h +++ b/mindspore/lite/tools/common/tensor_util.h @@ -37,7 +37,6 @@ using schema::QuantParamT; using schema::TensorT; using schema::Format::Format_NCHW; using schema::Format::Format_NHWC; -using STATUS = int; std::unique_ptr GetTensorQuantParam(const std::unique_ptr &tensor); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index f6a8f610bc..03e27dfc70 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -64,7 +64,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/primitive_adjust_pass.cc ) -add_subdirectory(../anf_importer anf_importer) add_subdirectory(../anf_exporter anf_exporter) add_subdirectory(parser/caffe) add_subdirectory(parser/tflite) @@ -158,7 +157,6 @@ target_link_libraries(converter_lite PRIVATE tf_parser_mid caffe_parser_mid onnx_parser_mid - anf_importer_mid anf_exporter_mid graph_pass_mid fusion_mid diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index c3a70a755c..cc8598ff90 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -17,12 +17,7 @@ #include "tools/converter/converter.h" #include #include -#include #include "tools/converter/converter_flags.h" -#include "src/common/common.h" -#include "src/common/file_utils.h" -#include "ir/func_graph.h" - #include "src/common/log_adapter.h" #include "tools/common/storage.h" #include "parser/caffe/caffe_converter.h" @@ -30,68 +25,48 @@ #include "parser/onnx/onnx_converter.h" #include "parser/tf/tf_converter.h" #include "tools/anf_exporter/anf_exporter.h" -#include "tools/anf_importer/import_from_mindir.h" -#include "proto/onnx.pb.h" -#include "tools/converter/quantizer/post_training_quantizer.h" -#include "tools/converter/quantizer/quant_cast.h" #include "include/version.h" namespace mindspore { namespace lite { using FmkType = converter::FmkType; -static const char *DELIM_SLASH = "/"; -Converter::Converter() { - this->transform = new GraphDefTransform; - this->anfTransform = new AnfTransform; -} - -Converter::~Converter() { - delete modelParser; - delete modelImporter; - delete transform; - delete anfTransform; -} - -class MindsporeImporter : public Converter { - public: - MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); } - - ~MindsporeImporter() override = default; -}; - -MetaGraphT *Converter::Convert(const converter::Flags *flag) { - // parse the model and weight file to generate inference data structure - FuncGraphPtr graph = nullptr; - if (flag->fmk == converter::FmkType_MS) { - MS_ASSERT(nullptr != modelImporter); - int status = modelImporter->Import(flag); - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - graph = modelImporter->GetResult(); - if (graph == nullptr) { +std::unique_ptr Converter::CreateConverter(converter::FmkType fmk) { + switch (fmk) { + case FmkType::FmkType_MS: + return std::make_unique(); + case FmkType::FmkType_CAFFE: + return std::make_unique(); + case FmkType::FmkType_TFLITE: + return std::make_unique(); + case FmkType::FmkType_ONNX: + return std::make_unique(); + case FmkType::FmkType_TF: + return std::make_unique(); + default: { return nullptr; } - graph->set_attr("graph_name", MakeValue("main_graph")); - graph->set_attr("fmk", MakeValue(static_cast(converter::FmkType_MS))); - } else { - MS_ASSERT(nullptr != modelParser); - const std::string modelFile = flag->modelFile; - const std::string weightFile = flag->weightFile; - graph = modelParser->Parse(modelFile, weightFile, flag->quantType); } +} + +MetaGraphT *Converter::Convert(const std::unique_ptr &flag) { + if (flag == nullptr) { + MS_LOG(ERROR) << "Input flag is nullptr"; + return nullptr; + } + auto graph = BuildFuncGraph(flag->modelFile, flag->weightFile, flag->quantType); if (graph == nullptr) { MS_LOG(ERROR) << "Parser/Import model return nullptr"; return nullptr; } - MS_LOG(INFO) << "import success"; - - graph = anfTransform->Transform(graph, flag); + // funcgraph compile + graph = funcgraph_transform_->Transform(graph, flag.get()); if (graph == nullptr) { MS_LOG(ERROR) << "Transform anf graph return nullptr"; return nullptr; } MS_LOG(INFO) << "Run anfTransform success"; - // anf -- fb + // protobuf -> flatbuf auto meta_graph = Export(graph); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta graph return nullptr"; @@ -99,91 +74,75 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { } MS_LOG(INFO) << "export success"; - // transform - transform->SetGraphDef(meta_graph); - auto status = transform->Transform(*flag); + // metagraph compile + metagraph_transform_->SetGraphDef(meta_graph); + auto status = metagraph_transform_->Transform(*flag); if (status != RET_OK) { MS_LOG(ERROR) << "Transform meta graph failed " << status; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - MS_LOG(INFO) << "run fbTransform success"; - return meta_graph; } int RunConverter(int argc, const char **argv) { + std::ostringstream oss; std::unique_ptr flags(new (std::nothrow) converter::Flags); if (flags == nullptr) { - MS_LOG(ERROR) << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED); - std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED) << std::endl; + oss.clear(); + oss << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; return RET_MEMORY_FAILED; } auto status = flags->Init(argc, argv); if (status != RET_OK) { if (status != RET_SUCCESS_EXIT) { - MS_LOG(ERROR) << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status) << std::endl; - std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status) << std::endl; + oss.clear(); + oss << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; } - std::cout << GetErrorInfo(status) << std::endl; return status; } // Load graph - std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1); - MS_LOG(INFO) << "start reading model file"; - - auto fb_graph = new (std::nothrow) MetaGraphT; - switch (flags->fmk) { - case FmkType::FmkType_MS: { - MindsporeImporter mindsporeImporter; - fb_graph = mindsporeImporter.Convert(flags.get()); - break; - } - case FmkType::FmkType_CAFFE: { - CaffeConverter caffeConverter; - fb_graph = caffeConverter.Convert(flags.get()); - } break; - case FmkType::FmkType_TFLITE: { - TfliteConverter tfLiteConverter; - fb_graph = tfLiteConverter.Convert(flags.get()); - } break; - case FmkType::FmkType_ONNX: { - OnnxConverter onnxConverter; - fb_graph = onnxConverter.Convert(flags.get()); - } break; - case FmkType::FmkType_TF: { - TFConverter tfConverter; - fb_graph = tfConverter.Convert(flags.get()); - } break; - default: { - MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " - << GetErrorInfo(RET_INPUT_PARAM_INVALID); - std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " - << GetErrorInfo(RET_INPUT_PARAM_INVALID) << std::endl; - return RET_INPUT_PARAM_INVALID; - } + MS_LOG(DEBUG) << "start reading model file"; + auto converter = Converter::CreateConverter(flags->fmk); + if (converter == nullptr) { + oss.clear(); + oss << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " + << GetErrorInfo(RET_INPUT_PARAM_INVALID); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; + return RET_INPUT_PARAM_INVALID; } + auto meta_graph = converter->Convert(flags); NoSupportOp::GetInstance()->PrintOps(); status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); - if (fb_graph == nullptr) { - MS_LOG(ERROR) << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); - std::cout << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status) << std::endl; + if (meta_graph == nullptr) { + oss.clear(); + oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; return status; } // save graph to file - Storage storage; - fb_graph->version = Version(); - status = storage.Save(*fb_graph, flags->outputFile); + meta_graph->version = Version(); + status = Storage::Save(*meta_graph, flags->outputFile); if (status != RET_OK) { - MS_LOG(ERROR) << "SAVE GRAPH FAILED:" << status << " " << GetErrorInfo(status); - std::cout << "SAVE GRAPH FAILED:" << status << " " << GetErrorInfo(status) << std::endl; + oss.clear(); + oss << "SAVE GRAPH FAILED:" << status << " " << GetErrorInfo(status); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; return status; } - delete fb_graph; - MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status; - std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl; + delete meta_graph; + oss.clear(); + oss << "CONVERT RESULT SUCCESS:" << status; + MS_LOG(INFO) << oss.str(); + std::cout << oss.str() << std::endl; return status; } } // namespace lite diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index 472bd67a5e..dea11d9cf1 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -22,24 +22,41 @@ #include "schema/inner/model_generated.h" #include "tools/converter/graphdef_transform.h" #include "tools/converter/model_parser.h" -#include "tools/anf_importer/anf_importer.h" #include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" #include "tools/converter/converter_context.h" +#include "load_mindir/load_model.h" namespace mindspore { namespace lite { class Converter { public: - Converter(); - virtual ~Converter(); - virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags); + static std::unique_ptr CreateConverter(converter::FmkType fmk); + + virtual ~Converter() = default; + + virtual schema::MetaGraphT *Convert(const std::unique_ptr &flag); + + virtual FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, + schema::QuantType quant_type) = 0; protected: - ModelParser *modelParser = nullptr; - AnfImporter *modelImporter = nullptr; - GraphDefTransform *transform = nullptr; - AnfTransform *anfTransform = nullptr; + Converter() = default; + + std::unique_ptr metagraph_transform_ = std::make_unique(); + std::unique_ptr funcgraph_transform_ = std::make_unique(); +}; + +class MindsporeImporter : public Converter { + public: + MindsporeImporter() = default; + + ~MindsporeImporter() override = default; + + FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, + schema::QuantType quant_type) override { + return LoadMindIR(model_file); + } }; int RunConverter(int argc, const char **argv); diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 18b914b597..d7eb1c4163 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -35,13 +35,7 @@ class ModelParser { virtual ~ModelParser() = default; virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { - return nullptr; - } - - protected: - virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type = QuantType_QUANT_NONE) = 0; + const QuantType &quant_type) = 0; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc deleted file mode 100644 index a63d5602f4..0000000000 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/parser/caffe/caffe_converter.h" -#include "tools/converter/parser/caffe/caffe_model_parser.h" - -namespace mindspore::lite { -CaffeConverter::CaffeConverter() { modelParser = new CaffeModelParser(); } -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h index 0c0367b32c..11b823609e 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h @@ -21,13 +21,20 @@ #include #include "tools/converter/converter.h" #include "tools/converter/graphdef_transform.h" +#include "tools/converter/parser/caffe/caffe_model_parser.h" namespace mindspore::lite { class CaffeConverter : public Converter { public: - CaffeConverter(); + CaffeConverter() = default; ~CaffeConverter() override = default; + + FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, + schema::QuantType quant_type) override { + CaffeModelParser parser; + return parser.Parse(model_file, weight_file, quant_type); + } }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 789f854052..09339c36b5 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -26,6 +26,8 @@ #include "ops/return.h" #include "ops/make_tuple.h" #include "ops/tuple_get_item.h" +#include "ir/func_graph.h" +#include "tools/converter/converter_flags.h" namespace mindspore::lite { CaffeModelParser::CaffeModelParser() = default; @@ -426,10 +428,4 @@ bool CaffeModelParser::IsSkipedLayer(const caffe::LayerParameter &layer) { } return layer.include_size() == 1 && layer.include(0).phase() == caffe::TRAIN; } - -MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { - return nullptr; -} - } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index 23fc0da334..2098e4f004 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -35,9 +35,6 @@ class CaffeModelParser : public ModelParser { FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) override; - MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; - private: STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc index dc186084f3..5ee5fab618 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -20,6 +20,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "ops/constant.h" +#include "src/param_value_lite.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc deleted file mode 100644 index c25c778964..0000000000 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/parser/onnx/onnx_converter.h" -#include "tools/converter/parser/onnx/onnx_model_parser.h" - -namespace mindspore { -namespace lite { -OnnxConverter::OnnxConverter() { modelParser = new OnnxModelParser(); } - -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h index 3344ba7b13..4e253ed308 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h @@ -21,14 +21,21 @@ #include #include "tools/converter/converter.h" #include "tools/converter/graphdef_transform.h" +#include "tools/converter/parser/onnx/onnx_model_parser.h" namespace mindspore { namespace lite { class OnnxConverter : public Converter { public: - OnnxConverter(); + OnnxConverter() = default; ~OnnxConverter() override = default; + + FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, + schema::QuantType quant_type) override { + OnnxModelParser parser; + return parser.Parse(model_file, weight_file, quant_type); + } }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 1ca9550e91..8c53e31719 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -26,6 +26,9 @@ #include "ops/return.h" #include "ops/make_tuple.h" #include "ops/tuple_get_item.h" +#include "ir/func_graph.h" +#include "src/param_value_lite.h" +#include "tools/converter/converter_flags.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index c1243d4602..c3eecaa934 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -41,16 +41,10 @@ class OnnxModelParser : public ModelParser { ~OnnxModelParser() override = default; - MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override { - return nullptr; - } - FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) override; static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); - static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, - const ParamValueLitePtr ¶m_value_lite); + static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, const ParamValueLitePtr ¶m_value); private: STATUS InitOriginModel(const std::string &model_file); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_converter.cc b/mindspore/lite/tools/converter/parser/tf/tf_converter.cc deleted file mode 100644 index 13aff62310..0000000000 --- a/mindspore/lite/tools/converter/parser/tf/tf_converter.cc +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/converter/parser/tf/tf_converter.h" -#include "tools/converter/parser/tf/tf_model_parser.h" -namespace mindspore { -namespace lite { -TFConverter::TFConverter() { modelParser = new TFModelParser(); } -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_converter.h b/mindspore/lite/tools/converter/parser/tf/tf_converter.h index 6e1a685c05..108d65f6ed 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_converter.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_converter.h @@ -18,13 +18,21 @@ #include #include #include "tools/converter/converter.h" +#include "tools/converter/parser/tf/tf_model_parser.h" + namespace mindspore { namespace lite { class TFConverter : public Converter { public: - TFConverter(); + TFConverter() = default; + + ~TFConverter() override = default; - ~TFConverter() = default; + FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, + schema::QuantType quant_type) override { + TFModelParser parser; + return parser.Parse(model_file, weight_file, quant_type); + } }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 18e534f225..54e76b77e6 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -30,6 +30,7 @@ #include "ops/tuple_get_item.h" #include "ops/while.h" #include "ir/anf.h" +#include "tools/converter/converter_flags.h" namespace mindspore { namespace lite { @@ -635,12 +636,6 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map &input_names, const std::map &tf_node_map, diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index a779c6ca2e..41e40778a1 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -28,7 +28,7 @@ #include "securec/include/securec.h" #include "tools/common/tensor_util.h" #include "tools/converter/model_parser.h" -#include "mindspore/lite/src/param_value_lite.h" +#include "src/param_value_lite.h" namespace mindspore { namespace lite { @@ -39,10 +39,6 @@ class TFModelParser : public ModelParser { FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); - protected: - schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType = QuantType_QUANT_NONE) override; - private: STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, const ParamValueLitePtr ¶m_value); STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc deleted file mode 100644 index 34f28d7306..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/parser/tflite/tflite_converter.h" -#include "tools/converter/parser/tflite/tflite_model_parser.h" - -namespace mindspore::lite { -TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); } -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h index eba2150e7c..cb87ea9cc4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h @@ -22,13 +22,20 @@ #include #include "tools/converter/converter.h" #include "tools/converter/graphdef_transform.h" +#include "tools/converter/parser/tflite/tflite_model_parser.h" namespace mindspore::lite { class TfliteConverter : public Converter { public: - TfliteConverter(); + TfliteConverter() = default; ~TfliteConverter() override = default; + + FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, + schema::QuantType quant_type) override { + TfliteModelParser parser; + return parser.Parse(model_file, weight_file, quant_type); + } }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 724e995e5c..272fdbb939 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -19,15 +19,16 @@ #include #include #include +#include "tools/converter/converter_flags.h" #include "src/param_value_lite.h" #include "src/common/file_utils.h" #include "ops/return.h" #include "ops/make_tuple.h" #include "ops/tuple_get_item.h" #include "ops/primitive_c.h" +#include "ir/func_graph.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *model_path) { size_t size = 0; tflite_model_buf_ = ReadFile(model_path, &size); @@ -453,10 +454,4 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const } return RET_OK; } - -MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { - return nullptr; -} -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 303e656a11..1d800b3489 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -34,8 +34,6 @@ class TfliteModelParser : public ModelParser { FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) override; - MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; private: std::unordered_map nodes_;