From 085b3065c33530e7897e62762c2d00c8a6279483 Mon Sep 17 00:00:00 2001 From: y00500818 Date: Tue, 15 Dec 2020 20:25:07 +0800 Subject: [PATCH] bugfix for shape op parser and parser code review --- parser/caffe/caffe_parser.cc | 14 ++++++++++++-- parser/common/acl_graph_parser_util.cc | 1 + .../tensorflow_auto_mapping_parser_adapter.cc | 10 ++++++++++ parser/tensorflow/tensorflow_parser.cc | 12 ++++++++++-- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 18b5d9e..375c1c4 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -83,7 +83,11 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, // load custom plugin so and proto AclGrphParseUtil acl_graph_parse_util; - (void)acl_graph_parse_util.AclParserInitialize(options); + domi::Status status = acl_graph_parse_util.AclParserInitialize(options); + if (status != domi::SUCCESS) { + GELOGE(GRAPH_FAILED, "Parser Initialize failed."); + return GRAPH_FAILED; + } // Create an empty computegraph ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); @@ -102,6 +106,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, GELOGI("Parser graph %s success.", graph.GetName().c_str()); auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE); + GE_CHECK_NOTNULL(weights_parser); ret = weights_parser->Parse(weights_file, graph); if (ret != ge::SUCCESS) { GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str()); @@ -125,7 +130,11 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, // load custom plugin so and proto AclGrphParseUtil acl_graph_parse_util; - (void)acl_graph_parse_util.AclParserInitialize(options); + domi::Status status = acl_graph_parse_util.AclParserInitialize(options); + if (status != domi::SUCCESS) { + GELOGE(GRAPH_FAILED, "Parser Initialize failed."); + return GRAPH_FAILED; + } string output_name; if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { @@ -155,6 +164,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, } auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE); + GE_CHECK_NOTNULL(weights_parser); ret = weights_parser->Parse(weights_file, graph); if (ret != ge::SUCCESS) { GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str()); diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index 0b683bd..2ff22dd 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -1092,6 +1092,7 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, const std::map &parser_params) { // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); string input_fp16_nodes; GetAclParams(parser_params, ge::ir_option::INPUT_FP16_NODES, input_fp16_nodes); diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc index 47ab56b..4d40e10 100644 --- a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc @@ -32,6 +32,8 @@ using ge::parser::PLACEHOLDERWITHDEFAULT; namespace ge { namespace { const char *const kTfAttrT = "T"; +const char *const kShapeAttrOutType = "out_type"; +const char *const kShapeAttrDtype = "dtype"; } // namespace Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { @@ -71,6 +73,14 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge // add nodedef for shape insert by adapter when online_infer_dynamic if (op_dest->GetType() == SHAPE) { + ge::DataType out_type = DT_INT32; + if (AttrUtils::GetDataType(op_dest, kShapeAttrOutType, out_type)) { + if (!AttrUtils::SetInt(op_dest, kShapeAttrDtype, static_cast(out_type))) { + GELOGE(FAILED, "Set attr dtype for op:%s failed.", op_dest->GetName().c_str()); + return FAILED; + } + } + std::shared_ptr pkg_node = ge::parser::MakeShared(); GE_CHECK_NOTNULL(pkg_node); pkg_node->CopyFrom(*node); diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 25f6734..3791ace 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -97,7 +97,11 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { // load custom plugin so and proto AclGrphParseUtil acl_graph_parse_util; - (void)acl_graph_parse_util.AclParserInitialize(options); + domi::Status status = acl_graph_parse_util.AclParserInitialize(options); + if (status != domi::SUCCESS) { + GELOGE(GRAPH_FAILED, "Parser Initialize failed."); + return GRAPH_FAILED; + } // Create an empty computegraph ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); @@ -132,7 +136,11 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map