Browse Source

!637 解析onnx模型中的孤立const节点

Merge pull request !637 from 韩健/hanjian
pull/639/head
韩健 i-robot 3 years ago
parent
commit
069eb8b23c
2 changed files with 25 additions and 7 deletions
  1. +6
    -7
      parser/onnx/onnx_parser.cc
  2. +19
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 6
- 7
parser/onnx/onnx_parser.cc View File

@@ -671,13 +671,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
}

Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) {
if (input_node_names_.empty()) {
// subgraph might not have input, we use constant nodes as the start nodes of graph
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
if (node->op_type() == kOpTypeConstant) {
input_node_names_.emplace_back(node->name());
}
// subgraph might not have input, or isolated const nodes exist in the graph,
// we use constant nodes as the start nodes of graph
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
if (node->op_type() == kOpTypeConstant) {
input_node_names_.emplace_back(node->name());
}
}
for (auto in_name : input_node_names_) {


+ 19
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -703,4 +703,23 @@ TEST_F(UtestOnnxParser, onnx_test_TransNodeToOperator_SetTensorData)
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestOnnxParser, onnx_test_const_input_op)
{
ge::onnx::ModelProto model_proto;
ge::onnx::GraphProto* graph = model_proto.mutable_graph();
ge::onnx::NodeProto *node_proto = graph->add_node();
node_proto->set_op_type("Constant");
node_proto->set_domain("const.onnx");
node_proto->set_name("const_11");
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Constant", "const.onnx");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);
std::string op_type = "Constant";

OnnxModelParser onnx_parser;
std::vector<ge::Operator> input_ops;
onnx_parser.name_operator_["const_11"] = op;
Status ret = onnx_parser.GetGraphInputs(*graph, input_ops);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(input_ops.size() > 0, true);
}
} // namespace ge

Loading…
Cancel
Save