From 4aaab1b8d511a8a3298a60594f5ce083b36432b0 Mon Sep 17 00:00:00 2001 From: y00500818 Date: Wed, 27 Jan 2021 14:20:56 +0800 Subject: [PATCH] handle onnx input with same name as initializer --- parser/onnx/onnx_parser.cc | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index ae717e2..03e9af7 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -150,19 +150,19 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, GELOGE(FAILED, "Onnx graph has zero input"); return FAILED; } + // get input value info map - std::map input_name_tensor; for (int i = 0; i < onnx_graph.input_size(); i++) { ge::onnx::ValueInfoProto value_info = onnx_graph.input(i); GELOGI("The index of %d input name : %s.", i, value_info.name().c_str()); - // The input are possibly initialized by a default value found in ‘initializer.’ + /// if the input is initialized by a default value found in ‘initializer’, + /// it will be considered as a const node. auto initializer_iter = initializer_name_tensor.find(value_info.name()); if (initializer_iter != initializer_name_tensor.end()) { - input_name_tensor[value_info.name()] = initializer_iter->second; - initializer_name_tensor.erase(initializer_iter); continue; } + ge::onnx::TensorProto tensor_tmp; if (value_info.has_type()) { const ge::onnx::TypeProto type = value_info.type(); @@ -181,26 +181,21 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, } } } - input_name_tensor[value_info.name()] = tensor_tmp; - } - // Construct node for input - int64_t index = 0; - for (auto it : input_name_tensor) { + // Construct node for input ge::onnx::NodeProto *input_node = onnx_graph.add_node(); - input_node->set_name(it.first); + input_node->set_name(value_info.name()); input_node->set_op_type(ge::kOpTypeInput); - input_node->add_output(it.first); + input_node->add_output(value_info.name()); // add tensor ge::onnx::AttributeProto *attribute = input_node->add_attribute(); attribute->set_name(ge::kAttrNameInput); ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); - *attribute_tensor = it.second; + *attribute_tensor = tensor_tmp; // add index ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); attribute_index->set_name(ge::kAttrNameIndex); - attribute_index->set_i(index++); - - input_node_names_.emplace_back(it.first); + attribute_index->set_i(i); + input_node_names_.emplace_back(value_info.name()); } return SUCCESS; }