From 1196f14a49072cfc5cc9fe72006ea073a4af462d Mon Sep 17 00:00:00 2001 From: y00500818 Date: Mon, 14 Dec 2020 19:17:30 +0800 Subject: [PATCH] add validation of fmk type for plugin load. --- parser/common/acl_graph_parser_util.cc | 14 ++++++++++++-- parser/common/parser_api.cc | 11 ++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index a2b90da..c4ea63d 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -27,6 +27,7 @@ #include "ge/ge_api_types.h" #include "graph/opsproto_manager.h" +#include "graph/utils/type_utils.h" #include "omg/parser/parser_inner_ctx.h" #include "framework/common/debug/ge_log.h" #include "parser/common/register_tbe.h" @@ -206,11 +207,20 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::mapsecond; + GELOGD("frameworkType is %s", fmk_type.c_str()); std::vector registrationDatas = op_registry->registrationDatas; GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); for (OpRegistrationData ®_data : registrationDatas) { - (void)OpRegistrationTbe::Instance()->Finalize(reg_data, false); - domi::OpRegistry::Instance()->Register(reg_data); + if (ge::TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()) == fmk_type) { + (void)OpRegistrationTbe::Instance()->Finalize(reg_data, false); + (void)domi::OpRegistry::Instance()->Register(reg_data); + } } // set init status diff --git a/parser/common/parser_api.cc b/parser/common/parser_api.cc index 65f6061..94b5999 100644 --- a/parser/common/parser_api.cc +++ b/parser/common/parser_api.cc @@ -19,6 +19,7 @@ #include "common/ge/tbe_plugin_manager.h" #include "framework/common/debug/ge_log.h" +#include "graph/utils/type_utils.h" #include "parser/common/register_tbe.h" #include "framework/omg/parser/parser_inner_ctx.h" #include "external/ge/ge_api_types.h" @@ -38,10 +39,18 @@ Status ParserInitialize(const std::map &options) { // load custom op plugin TBEPluginManager::Instance().LoadPluginSo(options); + std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(domi::TENSORFLOW); + auto it = options.find(ge::FRAMEWORK_TYPE); + if (it != options.end()) { + fmk_type = it->second; + } + GELOGD("frameworkType is %s", fmk_type.c_str()); std::vector registrationDatas = domi::OpRegistry::Instance()->registrationDatas; GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); for (OpRegistrationData ®_data : registrationDatas) { - (void)OpRegistrationTbe::Instance()->Finalize(reg_data, true); + if (ge::TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()) == fmk_type) { + (void)OpRegistrationTbe::Instance()->Finalize(reg_data, true); + } } auto iter = options.find(ge::OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES);