From 069eb8b23c1f6bb600fdebc2403069e4e8a2461c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E5=81=A5?= Date: Mon, 29 Aug 2022 12:40:53 +0000 Subject: [PATCH] =?UTF-8?q?!637=20=E8=A7=A3=E6=9E=90onnx=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=AD=A4=E7=AB=8Bconst=E8=8A=82=E7=82=B9=20M?= =?UTF-8?q?erge=20pull=20request=20!637=20from=20=E9=9F=A9=E5=81=A5/hanjia?= =?UTF-8?q?n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- parser/onnx/onnx_parser.cc | 13 ++++++------- .../onnx_parser_unittest.cc | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 3a88315..75472a2 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -671,13 +671,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: } Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector &input_ops) { - if (input_node_names_.empty()) { - // subgraph might not have input, we use constant nodes as the start nodes of graph - for (int i = 0; i < onnx_graph.node_size(); i++) { - ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); - if (node->op_type() == kOpTypeConstant) { - input_node_names_.emplace_back(node->name()); - } + // subgraph might not have input, or isolated const nodes exist in the graph, + // we use constant nodes as the start nodes of graph + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + if (node->op_type() == kOpTypeConstant) { + input_node_names_.emplace_back(node->name()); } } for (auto in_name : input_node_names_) { diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 64669cc..6fdee9e 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -703,4 +703,23 @@ TEST_F(UtestOnnxParser, onnx_test_TransNodeToOperator_SetTensorData) EXPECT_EQ(ret, SUCCESS); } +TEST_F(UtestOnnxParser, onnx_test_const_input_op) +{ + ge::onnx::ModelProto model_proto; + ge::onnx::GraphProto* graph = model_proto.mutable_graph(); + ge::onnx::NodeProto *node_proto = graph->add_node(); + node_proto->set_op_type("Constant"); + node_proto->set_domain("const.onnx"); + node_proto->set_name("const_11"); + ge::OpDescPtr op_desc_src = std::make_shared("Constant", "const.onnx"); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); + std::string op_type = "Constant"; + + OnnxModelParser onnx_parser; + std::vector input_ops; + onnx_parser.name_operator_["const_11"] = op; + Status ret = onnx_parser.GetGraphInputs(*graph, input_ops); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(input_ops.size() > 0, true); +} } // namespace ge