diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index fe0a725..7010fae 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -44,6 +44,7 @@ #include "parser/caffe/caffe_op_parser.h" #include "parser/common/op_parser_factory.h" #include "parser/common/pre_checker.h" +#include "parser/common/prototype_pass_manager.h" #include "framework/omg/parser/parser_types.h" #include "parser/common/model_saver.h" #include "parser/common/acl_graph_parser_util.h" @@ -1503,6 +1504,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), + "Run ProtoType Pass Failed"); // Set network name GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name())); @@ -1710,6 +1713,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap ErrorManager::GetInstance().ATCReportErrMessage("E11022"); return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), + "Run ProtoType Pass Failed"); // Set network name GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name())); diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt index 6dc8b6d..9a45d34 100644 --- a/parser/common/CMakeLists.txt +++ b/parser/common/CMakeLists.txt @@ -26,6 +26,7 @@ set(SRC_LIST "thread_pool.cc" "parser_utils.cc" "auto_mapping_subgraph_io_index_func.cc" + "prototype_pass_manager.cc" ) ############ libparser_common.so ############ diff --git a/parser/common/prototype_pass_manager.cc b/parser/common/prototype_pass_manager.cc new file mode 100644 index 0000000..ceca35c --- /dev/null +++ b/parser/common/prototype_pass_manager.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "prototype_pass_manager.h" + +#include "framework/common/debug/ge_log.h" + +namespace ge { +ProtoTypePassManager &ProtoTypePassManager::Instance() { + static ProtoTypePassManager instance; + return instance; +} + +Status ProtoTypePassManager::Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type) { + const auto &pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(fmk_type); + for (const auto &pass_item : pass_vec) { + std::string pass_name = pass_item.first; + std::unique_ptr pass = std::unique_ptr(pass_item.second()); + Status ret = pass->Run(message); + if (ret != SUCCESS) { + GELOGE(FAILED, "Run ProtoType pass:%s failed", pass_name.c_str()); + return ret; + } + GELOGD("Run ProtoType pass:%s success", pass_name.c_str()); + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/parser/common/prototype_pass_manager.h b/parser/common/prototype_pass_manager.h new file mode 100644 index 0000000..d9c2f9f --- /dev/null +++ b/parser/common/prototype_pass_manager.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_PROTOTYPE_PASS_MANAGER_H +#define PARSER_PROTOTYPE_PASS_MANAGER_H + +#include "register/prototype_pass_registry.h" + +namespace ge { +class ProtoTypePassManager { + public: + static ProtoTypePassManager &Instance(); + + Status Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type); + + ~ProtoTypePassManager() = default; + + private: + ProtoTypePassManager() = default; +}; +} // namespace ge + +#endif // PARSER_PROTOTYPE_PASS_MANAGER_H diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 2de88c5..abddc76 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -35,6 +35,7 @@ #include "parser/common/acl_graph_parser_util.h" #include "parser/common/model_saver.h" #include "parser/common/parser_utils.h" +#include "parser/common/prototype_pass_manager.h" #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" #include "register/register_fmk_types.h" @@ -838,6 +839,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP ClearMembers(); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), + "Run ProtoType Pass Failed"); // 2. Get all inializer. std::map initializer_name_tensor; for (int i = 0; i < onnx_graph.initializer_size(); i++) { diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 4e7c59f..74ce961 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -41,6 +41,7 @@ #include "parser/common/parser_fp16_t.h" #include "parser/common/pass_manager.h" #include "parser/common/pre_checker.h" +#include "parser/common/prototype_pass_manager.h" #include "parser/common/thread_pool.h" #include "parser/common/parser_utils.h" #include "parser/common/util.h" @@ -1146,6 +1147,9 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); } + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), + "Run ProtoType Pass Failed"); + shared_ptr scope_graph = nullptr; Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); if (ret != SUCCESS) { @@ -1374,6 +1378,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro // Make a copy for operation without modifying the original graph def. domi::tensorflow::GraphDef graph_def = *ori_graph; + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), + "Run ProtoType Pass Failed"); + shared_ptr scope_graph = nullptr; Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); if (ret != SUCCESS) { @@ -2216,6 +2223,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, domi::tensorflow::GraphDef *graph_def = &graph_def_operation; GELOGI("[TF Parser] graph def version:%d", graph_def->version()); + GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW), + "Run ProtoType Pass Failed"); + shared_ptr scope_graph = nullptr; Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph); if (ret != SUCCESS) { diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 9f0578d..7130e20 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -177,6 +177,7 @@ set(REGISTER_SRC_FILES "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" "${PARSER_DIR}/metadef/register/scope/scope_util.cc" "${PARSER_DIR}/metadef/register/tensor_assign.cpp" + "${PARSER_DIR}/metadef/register/prototype_pass_registry.cc" ) # include directories @@ -246,6 +247,7 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/common/pass_manager.cc" "${PARSER_DIR}/parser/common/pre_checker.cc" "${PARSER_DIR}/parser/common/proto_file_parser.cc" + "${PARSER_DIR}/parser/common/prototype_pass_manager.cc" "${PARSER_DIR}/parser/common/register_tbe.cc" "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" "${PARSER_DIR}/parser/common/thread_pool.cc"