Browse Source

!142 bugfix for shape op parser and parser code review

Merge pull request !142 from yangyongqiang/dev_parser_bugfix
pull/142/MERGE
i-robot Gitee 5 years ago
parent
commit
e45faf78e1
4 changed files with 33 additions and 4 deletions
  1. +12
    -2
      parser/caffe/caffe_parser.cc
  2. +1
    -0
      parser/common/acl_graph_parser_util.cc
  3. +10
    -0
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
  4. +10
    -2
      parser/tensorflow/tensorflow_parser.cc

+ 12
- 2
parser/caffe/caffe_parser.cc View File

@@ -83,7 +83,11 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
return GRAPH_FAILED;
}

// Create an empty computegraph
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
@@ -102,6 +106,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
GELOGI("Parser graph %s success.", graph.GetName().c_str());

auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE);
GE_CHECK_NOTNULL(weights_parser);
ret = weights_parser->Parse(weights_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str());
@@ -125,7 +130,11 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
return GRAPH_FAILED;
}

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
@@ -155,6 +164,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
}

auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE);
GE_CHECK_NOTNULL(weights_parser);
ret = weights_parser->Parse(weights_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str());


+ 1
- 0
parser/common/acl_graph_parser_util.cc View File

@@ -1092,6 +1092,7 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
const std::map<AscendString, AscendString> &parser_params) {
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf,
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);

string input_fp16_nodes;
GetAclParams(parser_params, ge::ir_option::INPUT_FP16_NODES, input_fp16_nodes);


+ 10
- 0
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -32,6 +32,8 @@ using ge::parser::PLACEHOLDERWITHDEFAULT;
namespace ge {
namespace {
const char *const kTfAttrT = "T";
const char *const kShapeAttrOutType = "out_type";
const char *const kShapeAttrDtype = "dtype";
} // namespace

Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
@@ -71,6 +73,14 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge

// add nodedef for shape insert by adapter when online_infer_dynamic
if (op_dest->GetType() == SHAPE) {
ge::DataType out_type = DT_INT32;
if (AttrUtils::GetDataType(op_dest, kShapeAttrOutType, out_type)) {
if (!AttrUtils::SetInt(op_dest, kShapeAttrDtype, static_cast<int64_t>(out_type))) {
GELOGE(FAILED, "Set attr dtype for op:%s failed.", op_dest->GetName().c_str());
return FAILED;
}
}

std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>();
GE_CHECK_NOTNULL(pkg_node);
pkg_node->CopyFrom(*node);


+ 10
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -97,7 +97,11 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
return GRAPH_FAILED;
}

// Create an empty computegraph
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
@@ -132,7 +136,11 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
return GRAPH_FAILED;
}

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {


Loading…
Cancel
Save