From a4e0b609a6aeb927b72de03be211b511a8563749 Mon Sep 17 00:00:00 2001 From: wjm Date: Sun, 25 Apr 2021 11:08:52 +0800 Subject: [PATCH 1/3] run prototype pass --- parser/caffe/caffe_parser.cc | 5 +++ parser/common/CMakeLists.txt | 1 + parser/common/prototype_pass_manager.cc | 41 +++++++++++++++++++++++++ parser/common/prototype_pass_manager.h | 36 ++++++++++++++++++++++ parser/onnx/onnx_parser.cc | 3 ++ parser/tensorflow/tensorflow_parser.cc | 10 ++++++ tests/ut/parser/CMakeLists.txt | 2 ++ 7 files changed, 98 insertions(+) create mode 100644 parser/common/prototype_pass_manager.cc create mode 100644 parser/common/prototype_pass_manager.h diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index fe0a725..7010fae 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -44,6 +44,7 @@ #include "parser/caffe/caffe_op_parser.h" #include "parser/common/op_parser_factory.h" #include "parser/common/pre_checker.h" +#include "parser/common/prototype_pass_manager.h" #include "framework/omg/parser/parser_types.h" #include "parser/common/model_saver.h" #include "parser/common/acl_graph_parser_util.h" @@ -1503,6 +1504,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), + "Run ProtoType Pass Failed"); // Set network name GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name())); @@ -1710,6 +1713,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap ErrorManager::GetInstance().ATCReportErrMessage("E11022"); return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), + "Run ProtoType Pass Failed"); // Set network name GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name())); diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt index 6dc8b6d..9a45d34 100644 --- a/parser/common/CMakeLists.txt +++ b/parser/common/CMakeLists.txt @@ -26,6 +26,7 @@ set(SRC_LIST "thread_pool.cc" "parser_utils.cc" "auto_mapping_subgraph_io_index_func.cc" + "prototype_pass_manager.cc" ) ############ libparser_common.so ############ diff --git a/parser/common/prototype_pass_manager.cc b/parser/common/prototype_pass_manager.cc new file mode 100644 index 0000000..ceca35c --- /dev/null +++ b/parser/common/prototype_pass_manager.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "prototype_pass_manager.h" + +#include "framework/common/debug/ge_log.h" + +namespace ge { +ProtoTypePassManager &ProtoTypePassManager::Instance() { + static ProtoTypePassManager instance; + return instance; +} + +Status ProtoTypePassManager::Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type) { + const auto &pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(fmk_type); + for (const auto &pass_item : pass_vec) { + std::string pass_name = pass_item.first; + std::unique_ptr pass = std::unique_ptr(pass_item.second()); + Status ret = pass->Run(message); + if (ret != SUCCESS) { + GELOGE(FAILED, "Run ProtoType pass:%s failed", pass_name.c_str()); + return ret; + } + GELOGD("Run ProtoType pass:%s success", pass_name.c_str()); + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/parser/common/prototype_pass_manager.h b/parser/common/prototype_pass_manager.h new file mode 100644 index 0000000..d9c2f9f --- /dev/null +++ b/parser/common/prototype_pass_manager.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_PROTOTYPE_PASS_MANAGER_H +#define PARSER_PROTOTYPE_PASS_MANAGER_H + +#include "register/prototype_pass_registry.h" + +namespace ge { +class ProtoTypePassManager { + public: + static ProtoTypePassManager &Instance(); + + Status Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type); + + ~ProtoTypePassManager() = default; + + private: + ProtoTypePassManager() = default; +}; +} // namespace ge + +#endif // PARSER_PROTOTYPE_PASS_MANAGER_H diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 2de88c5..abddc76 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -35,6 +35,7 @@ #include "parser/common/acl_graph_parser_util.h" #include "parser/common/model_saver.h" #include "parser/common/parser_utils.h" +#include "parser/common/prototype_pass_manager.h" #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" #include "register/register_fmk_types.h" @@ -838,6 +839,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP ClearMembers(); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), + "Run ProtoType Pass Failed"); // 2. Get all inializer. std::map initializer_name_tensor; for (int i = 0; i < onnx_graph.initializer_size(); i++) { diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 4e7c59f..74ce961 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -41,6 +41,7 @@ #include "parser/common/parser_fp16_t.h" #include "parser/common/pass_manager.h" #include "parser/common/pre_checker.h" +#include "parser/common/prototype_pass_manager.h" #include "parser/common/thread_pool.h" #include "parser/common/parser_utils.h" #include "parser/common/util.h" @@ -1146,6 +1147,9 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); } + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), + "Run ProtoType Pass Failed"); + shared_ptr scope_graph = nullptr; Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); if (ret != SUCCESS) { @@ -1374,6 +1378,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro // Make a copy for operation without modifying the original graph def. domi::tensorflow::GraphDef graph_def = *ori_graph; + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), + "Run ProtoType Pass Failed"); + shared_ptr scope_graph = nullptr; Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); if (ret != SUCCESS) { @@ -2216,6 +2223,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, domi::tensorflow::GraphDef *graph_def = &graph_def_operation; GELOGI("[TF Parser] graph def version:%d", graph_def->version()); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW), + "Run ProtoType Pass Failed"); + shared_ptr scope_graph = nullptr; Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph); if (ret != SUCCESS) { diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 9f0578d..7130e20 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -177,6 +177,7 @@ set(REGISTER_SRC_FILES "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" "${PARSER_DIR}/metadef/register/scope/scope_util.cc" "${PARSER_DIR}/metadef/register/tensor_assign.cpp" + "${PARSER_DIR}/metadef/register/prototype_pass_registry.cc" ) # include directories @@ -246,6 +247,7 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/common/pass_manager.cc" "${PARSER_DIR}/parser/common/pre_checker.cc" "${PARSER_DIR}/parser/common/proto_file_parser.cc" + "${PARSER_DIR}/parser/common/prototype_pass_manager.cc" "${PARSER_DIR}/parser/common/register_tbe.cc" "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" "${PARSER_DIR}/parser/common/thread_pool.cc" From bd6863c11257df2a22b3469ffd58793ef67c6489 Mon Sep 17 00:00:00 2001 From: wjm Date: Mon, 26 Apr 2021 20:49:20 +0800 Subject: [PATCH 2/3] modify --- parser/common/prototype_pass_manager.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/parser/common/prototype_pass_manager.cc b/parser/common/prototype_pass_manager.cc index ceca35c..ce4a099 100644 --- a/parser/common/prototype_pass_manager.cc +++ b/parser/common/prototype_pass_manager.cc @@ -16,6 +16,7 @@ #include "prototype_pass_manager.h" +#include "common/util.h" #include "framework/common/debug/ge_log.h" namespace ge { @@ -25,10 +26,14 @@ ProtoTypePassManager &ProtoTypePassManager::Instance() { } Status ProtoTypePassManager::Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type) { + GE_CHECK_NOTNULL(message); const auto &pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(fmk_type); for (const auto &pass_item : pass_vec) { std::string pass_name = pass_item.first; - std::unique_ptr pass = std::unique_ptr(pass_item.second()); + const auto &func = pass_item.second; + GE_CHECK_NOTNULL(func); + std::unique_ptr pass = std::unique_ptr(func()); + GE_CHECK_NOTNULL(pass); Status ret = pass->Run(message); if (ret != SUCCESS) { GELOGE(FAILED, "Run ProtoType pass:%s failed", pass_name.c_str()); From 5aaa9eb388c70749de094cf48f9ffed21128c3f7 Mon Sep 17 00:00:00 2001 From: wjm Date: Mon, 26 Apr 2021 20:53:38 +0800 Subject: [PATCH 3/3] update submodule --- metadef | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadef b/metadef index d7c67f0..1c41e02 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit d7c67f063b8452462ce59596234aeab6f4f5f7d8 +Subproject commit 1c41e02f73b6e8f95369e052ee4de285145fb34f