Browse Source

Feature:Support user options of aclgrphParse interface

pull/76/head
l00444296 5 years ago
parent
commit
9e2dd62f76
7 changed files with 1130 additions and 72 deletions
  1. +5
    -1
      inc/external/parser/caffe_parser.h
  2. +4
    -1
      inc/external/parser/tensorflow_parser.h
  3. +1
    -1
      metadef
  4. +56
    -1
      parser/caffe/caffe_parser.cc
  5. +983
    -63
      parser/common/acl_graph_parser_util.cc
  6. +34
    -4
      parser/common/acl_graph_parser_util.h
  7. +47
    -1
      parser/tensorflow/tensorflow_parser.cc

+ 5
- 1
inc/external/parser/caffe_parser.h View File

@@ -21,12 +21,16 @@
#include <string>
#include <vector>

#include "graph/ascend_string.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/graph.h"
#include "graph/types.h"

namespace ge {
graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph);

graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph);
} // namespace ge

#endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_

+ 4
- 1
inc/external/parser/tensorflow_parser.h View File

@@ -22,12 +22,15 @@
#include <string>
#include <vector>

#include "graph/ascend_string.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/graph.h"
#include "graph/types.h"

namespace ge {
graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph);
graphStatus aclgrphParseTensorFlow(const char *model_file,
const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph);
} // namespace ge

#endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 9bbce07b846858fa30ef2bd7c662894e20f83ef1
Subproject commit 37465b85d30b67a0edcc6ea4acd2f11a9697c7af

+ 56
- 1
parser/caffe/caffe_parser.cc View File

@@ -108,12 +108,67 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
return ret;
}
GELOGI("Weights parse success. graph: %s", graph.GetName().c_str());
if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) {
std::map<AscendString, AscendString> parser_params;
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
return ge::SUCCESS;
}

graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::CAFFE;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
return ge::FAILED;
}
// Create an empty computegraph
string graph_name = output_name.empty() ? "tmpGraph" : output_name;
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE);
GE_CHECK_NOTNULL(model_parser);

// parse caffe model_file and weights_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
return ge::FAILED;
}

auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE);
ret = weights_parser->Parse(weights_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str());
return ret;
}
GELOGI("Weights parse success. graph: %s", graph.GetName().c_str());

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}
} // namespace ge




+ 983
- 63
parser/common/acl_graph_parser_util.cc
File diff suppressed because it is too large
View File


+ 34
- 4
parser/common/acl_graph_parser_util.h View File

@@ -17,14 +17,18 @@
#ifndef ACL_GRAPH_PARSE_UTIL_
#define ACL_GRAPH_PARSE_UTIL_

#include <map>
#include <string>
#include <google/protobuf/text_format.h>

#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "framework/omg/parser/parser_types.h"
#include "register/register_error_codes.h"
#include "graph/ascend_string.h"
#include "graph/utils/graph_utils.h"
#include "register/register_error_codes.h"

namespace ge {

@@ -37,13 +41,39 @@ class AclGrphParseUtil {
domi::Status LoadOpsProtoLib();
void SaveCustomCaffeProtoPath();
domi::Status AclParserInitialize(const std::map<std::string, std::string> &options);
domi::Status SetDefaultOutputNode(ge::Graph &graph);
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
std::string &graph_name);
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseOutputInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);

private:
bool parser_initialized = false;
domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params);
domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
domi::Status ParseAclLogLevel(const std::string &log);
bool CheckAclInputFormat(string &input_format);
domi::Status ParseAclFormat(std::string &input_format);
bool ParseInputShape(const std::string &input_shape, std::unordered_map<std::string, vector<int64_t>> &shape_map,
vector<pair<std::string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input);
domi::Status ParseAclShape(const std::string &input_shape, bool is_dynamic_input);
domi::Status ParseAclOutputNodes(const std::string &out_nodes);
domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16);
domi::Status ParseAclOpConf(const std::string &op_conf);
domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes);
static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name,
uint32_t index, OpDescPtr &op_desc);
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout);
domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf);
domi::Status ParseAclOutputType(const std::string &output_type,
std::map<std::string, vector<std::string>> &output_node_dt_map);
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input);
domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf);
};

namespace parser {


+ 47
- 1
parser/tensorflow/tensorflow_parser.cc View File

@@ -113,13 +113,59 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
return ge::FAILED;
}

if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) {
std::map<AscendString, AscendString> parser_params;
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}

graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::TENSORFLOW;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::TENSORFLOW)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
return ge::FAILED;
}
// Create an empty computegraph
string graph_name = output_name.empty() ? "tmpGraph" : output_name;
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
GE_CHECK_NOTNULL(model_parser);

// parse tensorflow model_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
return ge::FAILED;
}

if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}
} // namespace ge

namespace ge {


Loading…
Cancel
Save