| @@ -44,6 +44,7 @@ | |||
| #include "parser/caffe/caffe_op_parser.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/common/pre_checker.h" | |||
| #include "parser/common/prototype_pass_manager.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "parser/common/model_saver.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| @@ -1503,6 +1504,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||
| GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED, | |||
| "[Check][Size]net layer num is zero, prototxt file may be invalid."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | |||
| "Run ProtoType Pass Failed"); | |||
| // Set network name | |||
| GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name())); | |||
| @@ -1710,6 +1713,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | |||
| return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | |||
| "Run ProtoType Pass Failed"); | |||
| // Set network name | |||
| GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name())); | |||
| @@ -26,6 +26,7 @@ set(SRC_LIST | |||
| "thread_pool.cc" | |||
| "parser_utils.cc" | |||
| "auto_mapping_subgraph_io_index_func.cc" | |||
| "prototype_pass_manager.cc" | |||
| ) | |||
| ############ libparser_common.so ############ | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "prototype_pass_manager.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace ge { | |||
| ProtoTypePassManager &ProtoTypePassManager::Instance() { | |||
| static ProtoTypePassManager instance; | |||
| return instance; | |||
| } | |||
| Status ProtoTypePassManager::Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type) { | |||
| const auto &pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(fmk_type); | |||
| for (const auto &pass_item : pass_vec) { | |||
| std::string pass_name = pass_item.first; | |||
| std::unique_ptr<ProtoTypeBasePass> pass = std::unique_ptr<ProtoTypeBasePass>(pass_item.second()); | |||
| Status ret = pass->Run(message); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "Run ProtoType pass:%s failed", pass_name.c_str()); | |||
| return ret; | |||
| } | |||
| GELOGD("Run ProtoType pass:%s success", pass_name.c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef PARSER_PROTOTYPE_PASS_MANAGER_H | |||
| #define PARSER_PROTOTYPE_PASS_MANAGER_H | |||
| #include "register/prototype_pass_registry.h" | |||
| namespace ge { | |||
| class ProtoTypePassManager { | |||
| public: | |||
| static ProtoTypePassManager &Instance(); | |||
| Status Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type); | |||
| ~ProtoTypePassManager() = default; | |||
| private: | |||
| ProtoTypePassManager() = default; | |||
| }; | |||
| } // namespace ge | |||
| #endif // PARSER_PROTOTYPE_PASS_MANAGER_H | |||
| @@ -35,6 +35,7 @@ | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #include "parser/common/model_saver.h" | |||
| #include "parser/common/parser_utils.h" | |||
| #include "parser/common/prototype_pass_manager.h" | |||
| #include "parser/onnx/onnx_util.h" | |||
| #include "register/op_registry.h" | |||
| #include "register/register_fmk_types.h" | |||
| @@ -838,6 +839,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
| ClearMembers(); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | |||
| "Run ProtoType Pass Failed"); | |||
| // 2. Get all inializer. | |||
| std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | |||
| for (int i = 0; i < onnx_graph.initializer_size(); i++) { | |||
| @@ -41,6 +41,7 @@ | |||
| #include "parser/common/parser_fp16_t.h" | |||
| #include "parser/common/pass_manager.h" | |||
| #include "parser/common/pre_checker.h" | |||
| #include "parser/common/prototype_pass_manager.h" | |||
| #include "parser/common/thread_pool.h" | |||
| #include "parser/common/parser_utils.h" | |||
| #include "parser/common/util.h" | |||
| @@ -1146,6 +1147,9 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||
| GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); | |||
| } | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), | |||
| "Run ProtoType Pass Failed"); | |||
| shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | |||
| Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); | |||
| if (ret != SUCCESS) { | |||
| @@ -1374,6 +1378,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||
| // Make a copy for operation without modifying the original graph def. | |||
| domi::tensorflow::GraphDef graph_def = *ori_graph; | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), | |||
| "Run ProtoType Pass Failed"); | |||
| shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | |||
| Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); | |||
| if (ret != SUCCESS) { | |||
| @@ -2216,6 +2223,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||
| domi::tensorflow::GraphDef *graph_def = &graph_def_operation; | |||
| GELOGI("[TF Parser] graph def version:%d", graph_def->version()); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW), | |||
| "Run ProtoType Pass Failed"); | |||
| shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | |||
| Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph); | |||
| if (ret != SUCCESS) { | |||
| @@ -177,6 +177,7 @@ set(REGISTER_SRC_FILES | |||
| "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" | |||
| "${PARSER_DIR}/metadef/register/scope/scope_util.cc" | |||
| "${PARSER_DIR}/metadef/register/tensor_assign.cpp" | |||
| "${PARSER_DIR}/metadef/register/prototype_pass_registry.cc" | |||
| ) | |||
| # include directories | |||
| @@ -246,6 +247,7 @@ set(PARSER_SRC_FILES | |||
| "${PARSER_DIR}/parser/common/pass_manager.cc" | |||
| "${PARSER_DIR}/parser/common/pre_checker.cc" | |||
| "${PARSER_DIR}/parser/common/proto_file_parser.cc" | |||
| "${PARSER_DIR}/parser/common/prototype_pass_manager.cc" | |||
| "${PARSER_DIR}/parser/common/register_tbe.cc" | |||
| "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" | |||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | |||