diff --git a/parser/common/tbe_plugin_loader.cc b/parser/common/tbe_plugin_loader.cc index 2389852..22969d2 100644 --- a/parser/common/tbe_plugin_loader.cc +++ b/parser/common/tbe_plugin_loader.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "external/ge/ge_api_types.h" #include "common/util/error_manager/error_manager.h" @@ -37,8 +38,14 @@ #include "framework/omg/parser/parser_inner_ctx.h" #include "graph/utils/type_utils.h" #include "parser/common/acl_graph_parser_util.h" +// #include "mmpa/mmpa_api.h" +// #include "common/checker.h" namespace ge { +namespace { +const char_t *const kVendors = "vendors"; // opp vendors directory name +const char_t *const kConfig = "config.ini"; // opp vendors config file name +} // namespace std::map TBEPluginLoader::options_ = {}; // Get Singleton Instance @@ -101,6 +108,65 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPlugi } } +Status TBEPluginLoader::GetOppPluginVendors(const std::string &opp_path, std::vector &vendors) { + GELOGI("Enter get opp plugin config file schedule"); + // GE_ASSERT_TRUE(!opp_path.empty(), "[Check]Value of opp_path should not be empty!"); + const std::string config = opp_path + kVendors + '/' + kConfig; + std::ifstream infile(config); + if (!infile.good()) { + GELOGW("File '%s' open failed!", config.c_str()); + return FAILED; + } + std::string content; + std::getline(infile, content); + infile.close(); + if (content.empty()) { + GELOGW("Content of file '%s' is empty!", config.c_str()); + return FAILED; + } + content = std::regex_replace(content, std::regex("^\\w+="), ""); + const std::regex sep("\\s*,\\s*"); + std::sregex_token_iterator stpos(content.begin(), content.end(), sep, -1); + std::sregex_token_iterator end; + for(; stpos != end; ++stpos){ + vendors.push_back(stpos->str()); + } + return SUCCESS; +} + +Status TBEPluginLoader::GetOpsProtoPath(const std::string &path, const std::string &subdir, std::string &opsproto_path, + bool keep_subdir) { + // GELOGI("Enter get ops proto path schedule"); + // GE_ASSERT_TRUE(!path.empty(), "[Check]Value of path should not be empty!"); + std::string opp_path = path; + // if (opp_path.back() != '/') { + // opp_path += '/'; + // } + const std::string fmt = std::regex_search(subdir, std::regex("%s")) ? subdir : (subdir + "/%s"); + const std::string fmtCustom = keep_subdir ? fmt : std::regex_replace(fmt, std::regex("%s/.*"), "%s"); + // if (mmIsDir((opp_path + kVendors).c_str()) != EN_OK) { + if (true) { + // GELOGI("Opp plugin path is old version!"); + opsproto_path = (opp_path + std::regex_replace(fmtCustom, std::regex("%s"), "custom") + "/:") + + (opp_path + std::regex_replace(fmt, std::regex("%s"), "built-in")); + } else { + // GELOGI("Opp plugin path is new version!"); + std::vector vendors; + if (GetOppPluginVendors(opp_path, vendors) != SUCCESS) { + // GELOGW("Failed to get opp plugin vendors!"); + return FAILED; + } + const std::regex reg("^(.+?)/(%s)"); + const std::string fmtNew = std::regex_replace(fmt, reg, "$2/$1"); + const std::string fmtNewCustom = keep_subdir ? fmtNew : std::regex_replace(fmtNew, std::regex("(%s/.+?)/.*"), "$1"); + for (const auto &vendor : vendors) { + opsproto_path += opp_path + kVendors + "/" + std::regex_replace(fmtNewCustom, std::regex("%s"), vendor) + "/:"; + } + opsproto_path += opp_path + std::regex_replace(fmtNew, std::regex("%s"), "built-in"); + } + return SUCCESS; +} + void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { GELOGI("Enter get custom op path schedule"); std::string fmk_type; @@ -112,18 +178,26 @@ void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); GELOGI("Framework type is %s.", fmk_type.c_str()); + std::string path; const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { - std::string path = path_env; - customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); + path = path_env; GELOGI("Get custom so path from env : %s", path_env); - return; + } else { + path = GetPath(); + GELOGI("path is %s", path.c_str()); + path = path.substr(0, path.rfind('/')); + path = path.substr(0, path.rfind('/') + 1); + path += "ops"; } - std::string path_base = GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); + // customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); + Status ret = GetOpsProtoPath(path, "framework/%s/" + fmk_type, customop_path, false); + std::cout << "parser/common/tbe_plugin_loader.cc:customop_path.1=" << customop_path << std::endl; + // if (ret != SUCCESS) { + // GELOGW("Failed to get opp proto path!"); + // } + customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); + std::cout << "parser/common/tbe_plugin_loader.cc:customop_path.2=" << customop_path << std::endl; } string TBEPluginLoader::GetPath() { diff --git a/parser/common/tbe_plugin_loader.h b/parser/common/tbe_plugin_loader.h index b5adfb5..36c8500 100644 --- a/parser/common/tbe_plugin_loader.h +++ b/parser/common/tbe_plugin_loader.h @@ -40,12 +40,16 @@ public: static string GetPath(); + static Status GetOpsProtoPath(const std::string &path, const std::string &subdir, std::string &opsproto_path, + bool keep_subdir = true); + private: TBEPluginLoader() = default; ~TBEPluginLoader() = default; Status ClearHandles_(); static void ProcessSoFullName(vector &file_list, string &caffe_parser_path, string &full_name, const string &caffe_parser_so_suff); + static Status GetOppPluginVendors(const std::string &opp_path, std::vector &vendors); static void GetCustomOpPath(std::string &customop_path); static void GetPluginSoFileList(const string &path, vector &file_list, string &caffe_parser_path); static void FindParserSo(const string &path, vector &file_list, string &caffe_parser_path); diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index cf71279..3714121 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -1223,6 +1223,7 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g // Store objects parsed from pb files domi::tensorflow::GraphDef OriDef; + std::cout << ":data:=" << data << std::endl; bool read = ge::parser::ReadProtoFromArray(data, static_cast(size), &OriDef); if (!read) { REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index b5f1908..e593fcd 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -1139,24 +1139,25 @@ TEST_F(STestTensorflowParser, tensorflow_parser_to_json) TEST_F(STestTensorflowParser, tensorflow_parserfrommemory_failed) { + dlog_setlevel(0, 0, 0); TensorFlowModelParser modelParser; std::string caseDir = __FILE__; std::size_t idx = caseDir.find_last_of("/"); caseDir = caseDir.substr(0, idx); std::string modelFile = caseDir + "/origin_models/tf_add.pb"; - const char *data = modelFile.c_str(); + // const char *data = modelFile.c_str(); uint32_t size = 1; ge::Graph graph; std::map parser_params; Status ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); ASSERT_EQ(ret, SUCCESS); - modelFile = caseDir + "/origin_models/tf_add.pb"; parser_params = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); - ret = modelParser.ParseFromMemory(data, size, compute_graph); + ret = modelParser.ParseFromMemory(modelFile.c_str(), size, compute_graph); EXPECT_NE(ret, SUCCESS); + dlog_setlevel(0, 3, 0); } TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success)