Browse Source

Pre Merge pull request !654 from likun104/br_parser_test

pull/654/MERGE
likun104 Gitee 3 years ago
parent
commit
b27d816808
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 91 additions and 11 deletions
  1. +82
    -8
      parser/common/tbe_plugin_loader.cc
  2. +4
    -0
      parser/common/tbe_plugin_loader.h
  3. +1
    -0
      parser/tensorflow/tensorflow_parser.cc
  4. +4
    -3
      tests/st/testcase/test_tensorflow_parser.cc

+ 82
- 8
parser/common/tbe_plugin_loader.cc View File

@@ -29,6 +29,7 @@
#include <map>
#include <memory>
#include <string>
#include <regex>

#include "external/ge/ge_api_types.h"
#include "common/util/error_manager/error_manager.h"
@@ -37,8 +38,14 @@
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/utils/type_utils.h"
#include "parser/common/acl_graph_parser_util.h"
// #include "mmpa/mmpa_api.h"
// #include "common/checker.h"

namespace ge {
namespace {
const char_t *const kVendors = "vendors"; // opp vendors directory name
const char_t *const kConfig = "config.ini"; // opp vendors config file name
} // namespace
std::map<string, string> TBEPluginLoader::options_ = {};

// Get Singleton Instance
@@ -101,6 +108,65 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPlugi
}
}

Status TBEPluginLoader::GetOppPluginVendors(const std::string &opp_path, std::vector<std::string> &vendors) {
GELOGI("Enter get opp plugin config file schedule");
// GE_ASSERT_TRUE(!opp_path.empty(), "[Check]Value of opp_path should not be empty!");
const std::string config = opp_path + kVendors + '/' + kConfig;
std::ifstream infile(config);
if (!infile.good()) {
GELOGW("File '%s' open failed!", config.c_str());
return FAILED;
}
std::string content;
std::getline(infile, content);
infile.close();
if (content.empty()) {
GELOGW("Content of file '%s' is empty!", config.c_str());
return FAILED;
}
content = std::regex_replace(content, std::regex("^\\w+="), "");
const std::regex sep("\\s*,\\s*");
std::sregex_token_iterator stpos(content.begin(), content.end(), sep, -1);
std::sregex_token_iterator end;
for(; stpos != end; ++stpos){
vendors.push_back(stpos->str());
}
return SUCCESS;
}

Status TBEPluginLoader::GetOpsProtoPath(const std::string &path, const std::string &subdir, std::string &opsproto_path,
bool keep_subdir) {
// GELOGI("Enter get ops proto path schedule");
// GE_ASSERT_TRUE(!path.empty(), "[Check]Value of path should not be empty!");
std::string opp_path = path;
// if (opp_path.back() != '/') {
// opp_path += '/';
// }
const std::string fmt = std::regex_search(subdir, std::regex("%s")) ? subdir : (subdir + "/%s");
const std::string fmtCustom = keep_subdir ? fmt : std::regex_replace(fmt, std::regex("%s/.*"), "%s");
// if (mmIsDir((opp_path + kVendors).c_str()) != EN_OK) {
if (true) {
// GELOGI("Opp plugin path is old version!");
opsproto_path = (opp_path + std::regex_replace(fmtCustom, std::regex("%s"), "custom") + "/:")
+ (opp_path + std::regex_replace(fmt, std::regex("%s"), "built-in"));
} else {
// GELOGI("Opp plugin path is new version!");
std::vector<std::string> vendors;
if (GetOppPluginVendors(opp_path, vendors) != SUCCESS) {
// GELOGW("Failed to get opp plugin vendors!");
return FAILED;
}
const std::regex reg("^(.+?)/(%s)");
const std::string fmtNew = std::regex_replace(fmt, reg, "$2/$1");
const std::string fmtNewCustom = keep_subdir ? fmtNew : std::regex_replace(fmtNew, std::regex("(%s/.+?)/.*"), "$1");
for (const auto &vendor : vendors) {
opsproto_path += opp_path + kVendors + "/" + std::regex_replace(fmtNewCustom, std::regex("%s"), vendor) + "/:";
}
opsproto_path += opp_path + std::regex_replace(fmtNew, std::regex("%s"), "built-in");
}
return SUCCESS;
}

void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) {
GELOGI("Enter get custom op path schedule");
std::string fmk_type;
@@ -112,18 +178,26 @@ void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) {
fmk_type = ge::TypeUtils::FmkTypeToSerialString(type);
GELOGI("Framework type is %s.", fmk_type.c_str());

std::string path;
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
std::string path = path_env;
customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
path = path_env;
GELOGI("Get custom so path from env : %s", path_env);
return;
} else {
path = GetPath();
GELOGI("path is %s", path.c_str());
path = path.substr(0, path.rfind('/'));
path = path.substr(0, path.rfind('/') + 1);
path += "ops";
}
std::string path_base = GetPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type);
// customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
Status ret = GetOpsProtoPath(path, "framework/%s/" + fmk_type, customop_path, false);
std::cout << "parser/common/tbe_plugin_loader.cc:customop_path.1=" << customop_path << std::endl;
// if (ret != SUCCESS) {
// GELOGW("Failed to get opp proto path!");
// }
customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
std::cout << "parser/common/tbe_plugin_loader.cc:customop_path.2=" << customop_path << std::endl;
}

string TBEPluginLoader::GetPath() {


+ 4
- 0
parser/common/tbe_plugin_loader.h View File

@@ -40,12 +40,16 @@ public:

static string GetPath();

static Status GetOpsProtoPath(const std::string &path, const std::string &subdir, std::string &opsproto_path,
bool keep_subdir = true);

private:
TBEPluginLoader() = default;
~TBEPluginLoader() = default;
Status ClearHandles_();
static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
const string &caffe_parser_so_suff);
static Status GetOppPluginVendors(const std::string &opp_path, std::vector<std::string> &vendors);
static void GetCustomOpPath(std::string &customop_path);
static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);


+ 1
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -1223,6 +1223,7 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
// Store objects parsed from pb files
domi::tensorflow::GraphDef OriDef;

std::cout << ":data:=" << data << std::endl;
bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef);
if (!read) {
REPORT_INNER_ERROR("E19999", "read graph proto from binary failed");


+ 4
- 3
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -1139,24 +1139,25 @@ TEST_F(STestTensorflowParser, tensorflow_parser_to_json)

TEST_F(STestTensorflowParser, tensorflow_parserfrommemory_failed)
{
dlog_setlevel(0, 0, 0);
TensorFlowModelParser modelParser;
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);
std::string modelFile = caseDir + "/origin_models/tf_add.pb";
const char *data = modelFile.c_str();
// const char *data = modelFile.c_str();
uint32_t size = 1;
ge::Graph graph;
std::map<ge::AscendString, ge::AscendString> parser_params;
Status ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
ASSERT_EQ(ret, SUCCESS);

modelFile = caseDir + "/origin_models/tf_add.pb";
parser_params = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
ret = modelParser.ParseFromMemory(data, size, compute_graph);
ret = modelParser.ParseFromMemory(modelFile.c_str(), size, compute_graph);
EXPECT_NE(ret, SUCCESS);
dlog_setlevel(0, 3, 0);
}

TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success)


Loading…
Cancel
Save