Browse Source

run prototype pass

pull/297/head
wjm 4 years ago
parent
commit
a4e0b609a6
7 changed files with 98 additions and 0 deletions
  1. +5
    -0
      parser/caffe/caffe_parser.cc
  2. +1
    -0
      parser/common/CMakeLists.txt
  3. +41
    -0
      parser/common/prototype_pass_manager.cc
  4. +36
    -0
      parser/common/prototype_pass_manager.h
  5. +3
    -0
      parser/onnx/onnx_parser.cc
  6. +10
    -0
      parser/tensorflow/tensorflow_parser.cc
  7. +2
    -0
      tests/ut/parser/CMakeLists.txt

+ 5
- 0
parser/caffe/caffe_parser.cc View File

@@ -44,6 +44,7 @@
#include "parser/caffe/caffe_op_parser.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/prototype_pass_manager.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/common/model_saver.h"
#include "parser/common/acl_graph_parser_util.h"
@@ -1503,6 +1504,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED,
"[Check][Size]net layer num is zero, prototxt file may be invalid.");

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE),
"Run ProtoType Pass Failed");
// Set network name
GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name()));

@@ -1710,6 +1713,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
ErrorManager::GetInstance().ATCReportErrMessage("E11022");
return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid.");

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE),
"Run ProtoType Pass Failed");
// Set network name
GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name()));



+ 1
- 0
parser/common/CMakeLists.txt View File

@@ -26,6 +26,7 @@ set(SRC_LIST
"thread_pool.cc"
"parser_utils.cc"
"auto_mapping_subgraph_io_index_func.cc"
"prototype_pass_manager.cc"
)

############ libparser_common.so ############


+ 41
- 0
parser/common/prototype_pass_manager.cc View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "prototype_pass_manager.h"

#include "framework/common/debug/ge_log.h"

namespace ge {
ProtoTypePassManager &ProtoTypePassManager::Instance() {
static ProtoTypePassManager instance;
return instance;
}

Status ProtoTypePassManager::Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type) {
const auto &pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(fmk_type);
for (const auto &pass_item : pass_vec) {
std::string pass_name = pass_item.first;
std::unique_ptr<ProtoTypeBasePass> pass = std::unique_ptr<ProtoTypeBasePass>(pass_item.second());
Status ret = pass->Run(message);
if (ret != SUCCESS) {
GELOGE(FAILED, "Run ProtoType pass:%s failed", pass_name.c_str());
return ret;
}
GELOGD("Run ProtoType pass:%s success", pass_name.c_str());
}
return SUCCESS;
}
} // namespace ge

+ 36
- 0
parser/common/prototype_pass_manager.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_PROTOTYPE_PASS_MANAGER_H
#define PARSER_PROTOTYPE_PASS_MANAGER_H

#include "register/prototype_pass_registry.h"

namespace ge {
class ProtoTypePassManager {
public:
static ProtoTypePassManager &Instance();

Status Run(google::protobuf::Message *message, const domi::FrameworkType &fmk_type);

~ProtoTypePassManager() = default;

private:
ProtoTypePassManager() = default;
};
} // namespace ge

#endif // PARSER_PROTOTYPE_PASS_MANAGER_H

+ 3
- 0
parser/onnx/onnx_parser.cc View File

@@ -35,6 +35,7 @@
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/model_saver.h"
#include "parser/common/parser_utils.h"
#include "parser/common/prototype_pass_manager.h"
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"
#include "register/register_fmk_types.h"
@@ -838,6 +839,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP

ClearMembers();

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX),
"Run ProtoType Pass Failed");
// 2. Get all inializer.
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor;
for (int i = 0; i < onnx_graph.initializer_size(); i++) {


+ 10
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -41,6 +41,7 @@
#include "parser/common/parser_fp16_t.h"
#include "parser/common/pass_manager.h"
#include "parser/common/pre_checker.h"
#include "parser/common/prototype_pass_manager.h"
#include "parser/common/thread_pool.h"
#include "parser/common/parser_utils.h"
#include "parser/common/util.h"
@@ -1146,6 +1147,9 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size());
}

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW),
"Run ProtoType Pass Failed");

shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph);
if (ret != SUCCESS) {
@@ -1374,6 +1378,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
// Make a copy for operation without modifying the original graph def.
domi::tensorflow::GraphDef graph_def = *ori_graph;

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW),
"Run ProtoType Pass Failed");

shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph);
if (ret != SUCCESS) {
@@ -2216,6 +2223,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
domi::tensorflow::GraphDef *graph_def = &graph_def_operation;
GELOGI("[TF Parser] graph def version:%d", graph_def->version());

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW),
"Run ProtoType Pass Failed");

shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph);
if (ret != SUCCESS) {


+ 2
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -177,6 +177,7 @@ set(REGISTER_SRC_FILES
"${PARSER_DIR}/metadef/register/scope/scope_pattern.cc"
"${PARSER_DIR}/metadef/register/scope/scope_util.cc"
"${PARSER_DIR}/metadef/register/tensor_assign.cpp"
"${PARSER_DIR}/metadef/register/prototype_pass_registry.cc"
)

# include directories
@@ -246,6 +247,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/pass_manager.cc"
"${PARSER_DIR}/parser/common/pre_checker.cc"
"${PARSER_DIR}/parser/common/proto_file_parser.cc"
"${PARSER_DIR}/parser/common/prototype_pass_manager.cc"
"${PARSER_DIR}/parser/common/register_tbe.cc"
"${PARSER_DIR}/parser/common/tbe_plugin_loader.cc"
"${PARSER_DIR}/parser/common/thread_pool.cc"


Loading…
Cancel
Save