| @@ -44,6 +44,7 @@ | |||||
| #include "parser/caffe/caffe_op_parser.h" | #include "parser/caffe/caffe_op_parser.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/common/pre_checker.h" | #include "parser/common/pre_checker.h" | ||||
| #include "parser/common/prototype_pass_manager.h" | |||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| #include "parser/common/model_saver.h" | #include "parser/common/model_saver.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
| @@ -1503,6 +1504,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
| GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED, | GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED, | ||||
| "[Check][Size]net layer num is zero, prototxt file may be invalid."); | "[Check][Size]net layer num is zero, prototxt file may be invalid."); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | |||||
| "Run ProtoType Pass Failed"); | |||||
| // Set network name | // Set network name | ||||
| GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name())); | GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name())); | ||||
| @@ -1710,6 +1713,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | ||||
| return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | |||||
| "Run ProtoType Pass Failed"); | |||||
| // Set network name | // Set network name | ||||
| GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name())); | GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name())); | ||||
| @@ -26,6 +26,7 @@ set(SRC_LIST | |||||
| "thread_pool.cc" | "thread_pool.cc" | ||||
| "parser_utils.cc" | "parser_utils.cc" | ||||
| "auto_mapping_subgraph_io_index_func.cc" | "auto_mapping_subgraph_io_index_func.cc" | ||||
| "prototype_pass_manager.cc" | |||||
| ) | ) | ||||
| ############ libparser_common.so ############ | ############ libparser_common.so ############ | ||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "prototype_pass_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| ProtoTypePassManager &ProtoTypePassManager::Instance() { | |||||
| static ProtoTypePassManager instance; | |||||
| return instance; | |||||
| } | |||||
| Status ProtoTypePassManager::Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type) { | |||||
| const auto &pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(fmk_type); | |||||
| for (const auto &pass_item : pass_vec) { | |||||
| std::string pass_name = pass_item.first; | |||||
| std::unique_ptr<ProtoTypeBasePass> pass = std::unique_ptr<ProtoTypeBasePass>(pass_item.second()); | |||||
| Status ret = pass->Run(message); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Run ProtoType pass:%s failed", pass_name.c_str()); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("Run ProtoType pass:%s success", pass_name.c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_PROTOTYPE_PASS_MANAGER_H | |||||
| #define PARSER_PROTOTYPE_PASS_MANAGER_H | |||||
| #include "register/prototype_pass_registry.h" | |||||
| namespace ge { | |||||
| class ProtoTypePassManager { | |||||
| public: | |||||
| static ProtoTypePassManager &Instance(); | |||||
| Status Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type); | |||||
| ~ProtoTypePassManager() = default; | |||||
| private: | |||||
| ProtoTypePassManager() = default; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_PROTOTYPE_PASS_MANAGER_H | |||||
| @@ -35,6 +35,7 @@ | |||||
| #include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
| #include "parser/common/model_saver.h" | #include "parser/common/model_saver.h" | ||||
| #include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
| #include "parser/common/prototype_pass_manager.h" | |||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
| @@ -838,6 +839,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
| ClearMembers(); | ClearMembers(); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | |||||
| "Run ProtoType Pass Failed"); | |||||
| // 2. Get all inializer. | // 2. Get all inializer. | ||||
| std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | ||||
| for (int i = 0; i < onnx_graph.initializer_size(); i++) { | for (int i = 0; i < onnx_graph.initializer_size(); i++) { | ||||
| @@ -41,6 +41,7 @@ | |||||
| #include "parser/common/parser_fp16_t.h" | #include "parser/common/parser_fp16_t.h" | ||||
| #include "parser/common/pass_manager.h" | #include "parser/common/pass_manager.h" | ||||
| #include "parser/common/pre_checker.h" | #include "parser/common/pre_checker.h" | ||||
| #include "parser/common/prototype_pass_manager.h" | |||||
| #include "parser/common/thread_pool.h" | #include "parser/common/thread_pool.h" | ||||
| #include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
| #include "parser/common/util.h" | #include "parser/common/util.h" | ||||
| @@ -1146,6 +1147,9 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
| GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); | GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); | ||||
| } | } | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), | |||||
| "Run ProtoType Pass Failed"); | |||||
| shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | ||||
| Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); | Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -1374,6 +1378,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
| // Make a copy for operation without modifying the original graph def. | // Make a copy for operation without modifying the original graph def. | ||||
| domi::tensorflow::GraphDef graph_def = *ori_graph; | domi::tensorflow::GraphDef graph_def = *ori_graph; | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW), | |||||
| "Run ProtoType Pass Failed"); | |||||
| shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | ||||
| Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); | Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -2216,6 +2223,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||||
| domi::tensorflow::GraphDef *graph_def = &graph_def_operation; | domi::tensorflow::GraphDef *graph_def = &graph_def_operation; | ||||
| GELOGI("[TF Parser] graph def version:%d", graph_def->version()); | GELOGI("[TF Parser] graph def version:%d", graph_def->version()); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW), | |||||
| "Run ProtoType Pass Failed"); | |||||
| shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | shared_ptr<ge::ScopeGraph> scope_graph = nullptr; | ||||
| Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph); | Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -177,6 +177,7 @@ set(REGISTER_SRC_FILES | |||||
| "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" | "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" | ||||
| "${PARSER_DIR}/metadef/register/scope/scope_util.cc" | "${PARSER_DIR}/metadef/register/scope/scope_util.cc" | ||||
| "${PARSER_DIR}/metadef/register/tensor_assign.cpp" | "${PARSER_DIR}/metadef/register/tensor_assign.cpp" | ||||
| "${PARSER_DIR}/metadef/register/prototype_pass_registry.cc" | |||||
| ) | ) | ||||
| # include directories | # include directories | ||||
| @@ -246,6 +247,7 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/common/pass_manager.cc" | "${PARSER_DIR}/parser/common/pass_manager.cc" | ||||
| "${PARSER_DIR}/parser/common/pre_checker.cc" | "${PARSER_DIR}/parser/common/pre_checker.cc" | ||||
| "${PARSER_DIR}/parser/common/proto_file_parser.cc" | "${PARSER_DIR}/parser/common/proto_file_parser.cc" | ||||
| "${PARSER_DIR}/parser/common/prototype_pass_manager.cc" | |||||
| "${PARSER_DIR}/parser/common/register_tbe.cc" | "${PARSER_DIR}/parser/common/register_tbe.cc" | ||||
| "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" | "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" | ||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||