diff --git a/metadef b/metadef index 7f1f5c4..bddfaec 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 7f1f5c49e3802219a1d6c4b874b0b553a7370220 +Subproject commit bddfaec360c4a5a64a8fccd5fb30fee521b99304 diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 4d4a235..500247c 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -45,7 +45,6 @@ #include "parser/caffe/caffe_custom_parser_adapter.h" #include "parser/caffe/caffe_op_parser.h" #include "parser/common/op_parser_factory.h" -#include "parser/common/pre_checker.h" #include "parser/common/prototype_pass_manager.h" #include "framework/omg/parser/parser_types.h" #include "parser/common/model_saver.h" diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index fcfc294..8b08b9a 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -40,6 +40,7 @@ #include "omg/parser/op_parser.h" #include "omg/parser/model_parser.h" #include "omg/parser/weights_parser.h" +#include "common/pre_checker.h" #include "proto/caffe/caffe.pb.h" #include "proto/om.pb.h" @@ -123,6 +124,17 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { return domi::SUCCESS; } + bool HasError() override { + return PreChecker::Instance().HasError(); + } + + Status Save(const string &file) override { + return PreChecker::Instance().Save(file); + } + + void Clear() override { + PreChecker::Instance().Clear(); + } private: Status Parse(const char *model_path, ge::ComputeGraphPtr &graph); @@ -346,6 +358,18 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; + bool HasError() override { + return PreChecker::Instance().HasError(); + } + + Status Save(const string &file) override { + return PreChecker::Instance().Save(file); + } + + void Clear() override { + PreChecker::Instance().Clear(); + } + private: Status CheckNodes(ge::ComputeGraphPtr &graph); /** diff --git a/parser/common/auto_mapping_subgraph_io_index_func.cc b/parser/common/auto_mapping_subgraph_io_index_func.cc index 963c5c0..dafc55d 100644 --- a/parser/common/auto_mapping_subgraph_io_index_func.cc +++ b/parser/common/auto_mapping_subgraph_io_index_func.cc @@ -21,11 +21,11 @@ #include "graph/op_desc.h" #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_util.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" #include "register/register_fmk_types.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" namespace ge { namespace { diff --git a/parser/common/op_types.h b/parser/common/op_types.h deleted file mode 100644 index 6e068fd..0000000 --- a/parser/common/op_types.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_COMMON_OP_TYPES_H_ -#define PARSER_COMMON_OP_TYPES_H_ - -#include -#include - -namespace ge { -class GE_FUNC_VISIBILITY OpTypeContainer { - public: - static OpTypeContainer *Instance() { - static OpTypeContainer instance; - return &instance; - } - ~OpTypeContainer() = default; - - void Register(const std::string &op_type) { op_type_list_.insert(op_type); } - - bool IsExisting(const std::string &op_type) { - return op_type_list_.count(op_type) > 0UL; - } - - protected: - OpTypeContainer() {} - - private: - std::set op_type_list_; -}; - -class GE_FUNC_VISIBILITY OpTypeRegistrar { - public: - explicit OpTypeRegistrar(const std::string &op_type) { OpTypeContainer::Instance()->Register(op_type); } - ~OpTypeRegistrar() {} -}; - -#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *var_name; - -#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \ - const char *var_name = str_name; \ - const OpTypeRegistrar g_##var_name##_reg(str_name); - -#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) -} // namespace ge - -#endif // PARSER_COMMON_OP_TYPES_H_ diff --git a/parser/common/parser_factory.cc b/parser/common/parser_factory.cc index f1cb202..60efe29 100644 --- a/parser/common/parser_factory.cc +++ b/parser/common/parser_factory.cc @@ -16,6 +16,7 @@ #include "omg/parser/parser_factory.h" #include "framework/common/debug/ge_log.h" +#include "common/register_tbe.h" namespace domi { FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { @@ -77,4 +78,13 @@ FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::Fr ModelParserFactory::~ModelParserFactory() { creator_map_.clear(); } + +FMK_FUNC_HOST_VISIBILITY OpRegTbeParserFactory *OpRegTbeParserFactory::Instance() { + static OpRegTbeParserFactory instance; + return &instance; +} + +void OpRegTbeParserFactory::Finalize(const domi::OpRegistrationData ®_data) { + (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); +} } // namespace domi diff --git a/parser/common/pre_checker.cc b/parser/common/pre_checker.cc index 5a9bbe1..af3ac52 100644 --- a/parser/common/pre_checker.cc +++ b/parser/common/pre_checker.cc @@ -200,7 +200,7 @@ FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() { return false; } -Status PreChecker::Save(string file) { +Status PreChecker::Save(const string &file) { uint32_t fail_num = 0; for (auto id : ops_) { if (HasError(id)) { diff --git a/parser/common/pre_checker.h b/parser/common/pre_checker.h index cdf35ad..12a0327 100644 --- a/parser/common/pre_checker.h +++ b/parser/common/pre_checker.h @@ -142,7 +142,7 @@ class PreChecker { * @ingroup domi_omg * @brief Save inspection results(JSON) */ - Status Save(string file); + Status Save(const string &file); private: /** diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 56fc893..7034978 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -32,7 +32,6 @@ #include "onnx_op_parser.h" #include "onnx_util.h" #include "parser/common/op_parser_factory.h" -#include "parser/common/pre_checker.h" #include "parser/common/acl_graph_parser_util.h" #include "parser/common/model_saver.h" #include "parser/common/parser_utils.h" diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index 899a8d0..dfcf344 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -38,6 +38,7 @@ #include "omg/parser/op_parser.h" #include "omg/parser/weights_parser.h" #include "common/parser_utils.h" +#include "common/pre_checker.h" #include "proto/onnx/ge_onnx.pb.h" namespace ge { @@ -81,6 +82,18 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { return domi::SUCCESS; } + bool HasError() override { + return PreChecker::Instance().HasError(); + } + + Status Save(const string &file) override { + return PreChecker::Instance().Save(file); + } + + void Clear() override { + PreChecker::Instance().Clear(); + } + private: Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); @@ -161,6 +174,18 @@ class PARSER_FUNC_VISIBILITY OnnxWeightsParser : public domi::WeightsParser { (void)graph; return domi::SUCCESS; } + + bool HasError() override { + return PreChecker::Instance().HasError(); + } + + Status Save(const string &file) override { + return PreChecker::Instance().Save(file); + } + + void Clear() override { + PreChecker::Instance().Clear(); + } }; } // namespace domi #endif // PARSER_ONNX_ONNX_PARSER_H_ diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index 500420c..4ba272d 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -15,7 +15,7 @@ */ #include "graph_optimizer.h" -#include "common/op_types.h" +#include "graph/op_types.h" #include "common/types_map.h" #include "common/util.h" #include "framework/omg/parser/parser_inner_ctx.h" diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 5c6dd04..c069a07 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -40,7 +40,6 @@ #include "parser/common/op_parser_factory.h" #include "parser/common/parser_fp16_t.h" #include "parser/common/pass_manager.h" -#include "parser/common/pre_checker.h" #include "parser/common/prototype_pass_manager.h" #include "parser/common/thread_pool.h" #include "parser/common/parser_utils.h" diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index f7e9c39..062d51d 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -35,6 +35,7 @@ #include "omg/parser/model_parser.h" #include "omg/parser/op_parser.h" #include "omg/parser/weights_parser.h" +#include "common/pre_checker.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_util.h" #include "proto/om.pb.h" @@ -154,6 +155,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { */ Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, ge::ComputeGraphPtr &root_graph) override; + + bool HasError() override { + return PreChecker::Instance().HasError(); + } + + Status Save(const string &file) override { + return PreChecker::Instance().Save(file); + } + + void Clear() override { + PreChecker::Instance().Clear(); + } private: Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph); @@ -686,6 +699,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParse Status Parse(const char *file, ge::Graph &graph) override; Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; + + bool HasError() override { + return PreChecker::Instance().HasError(); + } + + Status Save(const string &file) override { + return PreChecker::Instance().Save(file); + } + + void Clear() override { + PreChecker::Instance().Clear(); + } }; } // namespace domi #endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ diff --git a/tests/st/testcase/test_caffe_parser.cc b/tests/st/testcase/test_caffe_parser.cc index b57ad90..0d460f3 100644 --- a/tests/st/testcase/test_caffe_parser.cc +++ b/tests/st/testcase/test_caffe_parser.cc @@ -174,7 +174,7 @@ void STestCaffeParser::RegisterCustomOp() { std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); domi::OpRegistry::Instance()->Register(reg_data); } domi::OpRegistry::Instance()->registrationDatas.clear(); diff --git a/tests/st/testcase/test_onnx_parser.cc b/tests/st/testcase/test_onnx_parser.cc index d863be7..459abed 100644 --- a/tests/st/testcase/test_onnx_parser.cc +++ b/tests/st/testcase/test_onnx_parser.cc @@ -24,6 +24,7 @@ #include "st/parser_st_utils.h" #include "external/ge/ge_api_types.h" #include "tests/depends/ops_stub/ops_stub.h" +#include "framework/omg/parser/parser_factory.h" #include "parser/onnx/onnx_parser.h" namespace ge { @@ -96,7 +97,7 @@ void STestOnnxParser::RegisterCustomOp() { std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); domi::OpRegistry::Instance()->Register(reg_data); } domi::OpRegistry::Instance()->registrationDatas.clear(); diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 243502f..968658d 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -64,6 +64,7 @@ #include "parser/common/data_op_parser.h" #include "parser/common/model_saver.h" #include "framework/omg/parser/parser_api.h" +#include "framework/omg/parser/parser_factory.h" #include "parser/common/parser_fp16_t.h" #include "parser/common/op_parser_factory.h" #include "parser/common/prototype_pass_manager.h" @@ -151,7 +152,7 @@ void STestTensorflowParser::RegisterCustomOp() { .ParseParamsFn(ParseParams); std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); domi::OpRegistry::Instance()->Register(reg_data); } domi::OpRegistry::Instance()->registrationDatas.clear(); @@ -584,7 +585,7 @@ namespace { void register_tbe_op() { std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; for (OpRegistrationData reg_data : registrationDatas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); OpRegistry::Instance()->Register(reg_data); } OpRegistry::Instance()->registrationDatas.clear(); diff --git a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc index 4649e3d..846a209 100755 --- a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc +++ b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc @@ -163,7 +163,7 @@ static ge::NodePtr GenNodeFromOpDesc(ge::OpDescPtr opDesc){ void UtestCaffeParser::RegisterCustomOp() { std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); domi::OpRegistry::Instance()->Register(reg_data); } domi::OpRegistry::Instance()->registrationDatas.clear(); diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 0b50fd3..7ffbc04 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -24,6 +24,7 @@ #include "external/parser/onnx_parser.h" #include "ut/parser/parser_ut_utils.h" #include "external/ge/ge_api_types.h" +#include "framework/omg/parser/parser_factory.h" #include "tests/depends/ops_stub/ops_stub.h" #define protected public @@ -103,7 +104,7 @@ void UtestOnnxParser::RegisterCustomOp() { .ParseParamsFn(ParseParams); std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); domi::OpRegistry::Instance()->Register(reg_data); } domi::OpRegistry::Instance()->registrationDatas.clear(); diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index b286546..d241730 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -176,7 +176,7 @@ void UtestTensorflowParser::RegisterCustomOp() { .ParseParamsFn(ParseParams); std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); domi::OpRegistry::Instance()->Register(reg_data); } domi::OpRegistry::Instance()->registrationDatas.clear(); @@ -599,7 +599,7 @@ namespace { void register_tbe_op() { std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; for (OpRegistrationData reg_data : registrationDatas) { - OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); OpRegistry::Instance()->Register(reg_data); } OpRegistry::Instance()->registrationDatas.clear();