Browse Source

!536 decouple internal header file of parser

Merge pull request !536 from 王涛/ge_dev
pull/538/head
i-robot Gitee 3 years ago
parent
commit
3cdfef85d8
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
19 changed files with 100 additions and 77 deletions
  1. +1
    -1
      metadef
  2. +0
    -1
      parser/caffe/caffe_parser.cc
  3. +24
    -0
      parser/caffe/caffe_parser.h
  4. +1
    -1
      parser/common/auto_mapping_subgraph_io_index_func.cc
  5. +0
    -61
      parser/common/op_types.h
  6. +10
    -0
      parser/common/parser_factory.cc
  7. +1
    -1
      parser/common/pre_checker.cc
  8. +1
    -1
      parser/common/pre_checker.h
  9. +0
    -1
      parser/onnx/onnx_parser.cc
  10. +25
    -0
      parser/onnx/onnx_parser.h
  11. +1
    -1
      parser/tensorflow/graph_optimizer.cc
  12. +0
    -1
      parser/tensorflow/tensorflow_parser.cc
  13. +25
    -0
      parser/tensorflow/tensorflow_parser.h
  14. +1
    -1
      tests/st/testcase/test_caffe_parser.cc
  15. +2
    -1
      tests/st/testcase/test_onnx_parser.cc
  16. +3
    -2
      tests/st/testcase/test_tensorflow_parser.cc
  17. +1
    -1
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc
  18. +2
    -1
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc
  19. +2
    -2
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 7f1f5c49e3802219a1d6c4b874b0b553a7370220
Subproject commit bddfaec360c4a5a64a8fccd5fb30fee521b99304

+ 0
- 1
parser/caffe/caffe_parser.cc View File

@@ -45,7 +45,6 @@
#include "parser/caffe/caffe_custom_parser_adapter.h"
#include "parser/caffe/caffe_op_parser.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/prototype_pass_manager.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/common/model_saver.h"


+ 24
- 0
parser/caffe/caffe_parser.h View File

@@ -40,6 +40,7 @@
#include "omg/parser/op_parser.h"
#include "omg/parser/model_parser.h"
#include "omg/parser/weights_parser.h"
#include "common/pre_checker.h"
#include "proto/caffe/caffe.pb.h"
#include "proto/om.pb.h"

@@ -123,6 +124,17 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
return domi::SUCCESS;
}

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
private:
Status Parse(const char *model_path, ge::ComputeGraphPtr &graph);

@@ -346,6 +358,18 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {

Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}

private:
Status CheckNodes(ge::ComputeGraphPtr &graph);
/**


+ 1
- 1
parser/common/auto_mapping_subgraph_io_index_func.cc View File

@@ -21,11 +21,11 @@
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "register/register_fmk_types.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"

namespace ge {
namespace {


+ 0
- 61
parser/common/op_types.h View File

@@ -1,61 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_OP_TYPES_H_
#define PARSER_COMMON_OP_TYPES_H_

#include <set>
#include <string>

namespace ge {
class GE_FUNC_VISIBILITY OpTypeContainer {
public:
static OpTypeContainer *Instance() {
static OpTypeContainer instance;
return &instance;
}
~OpTypeContainer() = default;

void Register(const std::string &op_type) { op_type_list_.insert(op_type); }

bool IsExisting(const std::string &op_type) {
return op_type_list_.count(op_type) > 0UL;
}

protected:
OpTypeContainer() {}

private:
std::set<std::string> op_type_list_;
};

class GE_FUNC_VISIBILITY OpTypeRegistrar {
public:
explicit OpTypeRegistrar(const std::string &op_type) { OpTypeContainer::Instance()->Register(op_type); }
~OpTypeRegistrar() {}
};

#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *var_name;

#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \
const char *var_name = str_name; \
const OpTypeRegistrar g_##var_name##_reg(str_name);

#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name))
} // namespace ge

#endif // PARSER_COMMON_OP_TYPES_H_

+ 10
- 0
parser/common/parser_factory.cc View File

@@ -16,6 +16,7 @@

#include "omg/parser/parser_factory.h"
#include "framework/common/debug/ge_log.h"
#include "common/register_tbe.h"

namespace domi {
FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() {
@@ -77,4 +78,13 @@ FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::Fr
ModelParserFactory::~ModelParserFactory() {
creator_map_.clear();
}

FMK_FUNC_HOST_VISIBILITY OpRegTbeParserFactory *OpRegTbeParserFactory::Instance() {
static OpRegTbeParserFactory instance;
return &instance;
}

void OpRegTbeParserFactory::Finalize(const domi::OpRegistrationData &reg_data) {
(void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data);
}
} // namespace domi

+ 1
- 1
parser/common/pre_checker.cc View File

@@ -200,7 +200,7 @@ FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() {
return false;
}

Status PreChecker::Save(string file) {
Status PreChecker::Save(const string &file) {
uint32_t fail_num = 0;
for (auto id : ops_) {
if (HasError(id)) {


+ 1
- 1
parser/common/pre_checker.h View File

@@ -142,7 +142,7 @@ class PreChecker {
* @ingroup domi_omg
* @brief Save inspection results(JSON)
*/
Status Save(string file);
Status Save(const string &file);

private:
/**


+ 0
- 1
parser/onnx/onnx_parser.cc View File

@@ -32,7 +32,6 @@
#include "onnx_op_parser.h"
#include "onnx_util.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/model_saver.h"
#include "parser/common/parser_utils.h"


+ 25
- 0
parser/onnx/onnx_parser.h View File

@@ -38,6 +38,7 @@
#include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h"
#include "common/parser_utils.h"
#include "common/pre_checker.h"
#include "proto/onnx/ge_onnx.pb.h"

namespace ge {
@@ -81,6 +82,18 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
return domi::SUCCESS;
}

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}

private:
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph);

@@ -161,6 +174,18 @@ class PARSER_FUNC_VISIBILITY OnnxWeightsParser : public domi::WeightsParser {
(void)graph;
return domi::SUCCESS;
}

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
};
} // namespace domi
#endif // PARSER_ONNX_ONNX_PARSER_H_

+ 1
- 1
parser/tensorflow/graph_optimizer.cc View File

@@ -15,7 +15,7 @@
*/

#include "graph_optimizer.h"
#include "common/op_types.h"
#include "graph/op_types.h"
#include "common/types_map.h"
#include "common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h"


+ 0
- 1
parser/tensorflow/tensorflow_parser.cc View File

@@ -40,7 +40,6 @@
#include "parser/common/op_parser_factory.h"
#include "parser/common/parser_fp16_t.h"
#include "parser/common/pass_manager.h"
#include "parser/common/pre_checker.h"
#include "parser/common/prototype_pass_manager.h"
#include "parser/common/thread_pool.h"
#include "parser/common/parser_utils.h"


+ 25
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -35,6 +35,7 @@
#include "omg/parser/model_parser.h"
#include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h"
#include "common/pre_checker.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_util.h"
#include "proto/om.pb.h"
@@ -154,6 +155,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
*/
Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback,
ge::ComputeGraphPtr &root_graph) override;

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
private:
Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph);

@@ -686,6 +699,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParse
Status Parse(const char *file, ge::Graph &graph) override;

Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
};
} // namespace domi
#endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_

+ 1
- 1
tests/st/testcase/test_caffe_parser.cc View File

@@ -174,7 +174,7 @@ void STestCaffeParser::RegisterCustomOp() {

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();


+ 2
- 1
tests/st/testcase/test_onnx_parser.cc View File

@@ -24,6 +24,7 @@
#include "st/parser_st_utils.h"
#include "external/ge/ge_api_types.h"
#include "tests/depends/ops_stub/ops_stub.h"
#include "framework/omg/parser/parser_factory.h"
#include "parser/onnx/onnx_parser.h"

namespace ge {
@@ -96,7 +97,7 @@ void STestOnnxParser::RegisterCustomOp() {

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();


+ 3
- 2
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -64,6 +64,7 @@
#include "parser/common/data_op_parser.h"
#include "parser/common/model_saver.h"
#include "framework/omg/parser/parser_api.h"
#include "framework/omg/parser/parser_factory.h"
#include "parser/common/parser_fp16_t.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/prototype_pass_manager.h"
@@ -151,7 +152,7 @@ void STestTensorflowParser::RegisterCustomOp() {
.ParseParamsFn(ParseParams);
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
@@ -584,7 +585,7 @@ namespace {
void register_tbe_op() {
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
for (OpRegistrationData reg_data : registrationDatas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
OpRegistry::Instance()->Register(reg_data);
}
OpRegistry::Instance()->registrationDatas.clear();


+ 1
- 1
tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc View File

@@ -163,7 +163,7 @@ static ge::NodePtr GenNodeFromOpDesc(ge::OpDescPtr opDesc){
void UtestCaffeParser::RegisterCustomOp() {
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();


+ 2
- 1
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -24,6 +24,7 @@
#include "external/parser/onnx_parser.h"
#include "ut/parser/parser_ut_utils.h"
#include "external/ge/ge_api_types.h"
#include "framework/omg/parser/parser_factory.h"
#include "tests/depends/ops_stub/ops_stub.h"

#define protected public
@@ -103,7 +104,7 @@ void UtestOnnxParser::RegisterCustomOp() {
.ParseParamsFn(ParseParams);
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();


+ 2
- 2
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -176,7 +176,7 @@ void UtestTensorflowParser::RegisterCustomOp() {
.ParseParamsFn(ParseParams);
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
@@ -599,7 +599,7 @@ namespace {
void register_tbe_op() {
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
for (OpRegistrationData reg_data : registrationDatas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
OpRegistry::Instance()->Register(reg_data);
}
OpRegistry::Instance()->registrationDatas.clear();


Loading…
Cancel
Save