diff --git a/parser/onnx/onnx_constant_parser.cc b/parser/onnx/onnx_constant_parser.cc index f373a50..9b9b0f1 100644 --- a/parser/onnx/onnx_constant_parser.cc +++ b/parser/onnx/onnx_constant_parser.cc @@ -142,7 +142,6 @@ Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tenso }; TensorDesc tensor_desc = tensor.GetTensorDesc(); tensor_desc.SetShape(ge::Shape(tmp_shape)); - tensor_desc.SetFormat(static_cast(GetParserContext().format)); tensor.SetTensorDesc(tensor_desc); // set data @@ -190,9 +189,6 @@ Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src } op_def.SetAttr(ge::kAttrNameValue, tensor); - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_def); - op_def.UpdateOutputDesc(op_desc->GetOutputNameByIndex(0), tensor.GetTensorDesc()); - return SUCCESS; } diff --git a/parser/onnx/onnx_data_parser.cc b/parser/onnx/onnx_data_parser.cc index f35b9d4..29f966a 100644 --- a/parser/onnx/onnx_data_parser.cc +++ b/parser/onnx/onnx_data_parser.cc @@ -42,7 +42,6 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) } ge::TensorDesc tensor_desc; - tensor_desc.SetFormat(static_cast(GetParserContext().format)); tensor_desc.SetShape(ge::Shape(user_input_dims_v_)); int64_t type = 1; (void)op_def.GetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, type); diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index f6f1394..a2dcaf3 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -414,39 +414,6 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { return SUCCESS; } -void OnnxModelParser::UpdateFormat(ge::Graph &graph) { - std::vector vec_op_name; - graph.GetAllOpName(vec_op_name); - ge::Format format = ge::FORMAT_NCHW; - for (string name: vec_op_name) { - ge::Operator op; - graph.FindOpByName(name, op); - auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op); - if (std::find(kNoNeedUpdateFormat.begin(), kNoNeedUpdateFormat.end(), op_dsc->GetType()) - != kNoNeedUpdateFormat.end()) { - GELOGW("Op %s:%s no need update format.", op_dsc->GetName().c_str(), op_dsc->GetType().c_str()); - continue; - } - auto input_size = op_dsc->GetAllInputsSize(); - for (size_t i = 0; i < input_size; i++) { - auto input = op_dsc->MutableInputDesc(static_cast(i)); - if (input == nullptr) { - continue; - } - input->SetFormat(format); - input->SetOriginFormat(format); - } - - auto output_size = op_dsc->GetOutputsSize(); - for (size_t i = 0; i < output_size; i++) { - auto output = op_dsc->GetOutputDesc(i); - output.SetFormat(format); - output.SetOriginFormat(format); - op_dsc->UpdateOutputDesc(i, output); - } - } -} - Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { for (int i = 0; i < onnx_graph.node_size(); i++) { ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); @@ -633,8 +600,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model graph.SetInputs(input_ops); GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); - - UpdateFormat(graph); + UpdateDataFormat(graph); GELOGI("Onnx model parser success."); return SUCCESS; @@ -695,6 +661,29 @@ ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { return ge::OnnxUtil::ConvertOnnxDataType(type); } +void OnnxModelParser::UpdateDataFormat(ge::Graph &graph) { + for (GNode &gn : graph.GetDirectNode()) { + AscendString type; + (void)gn.GetType(type); + if (type != parser::DATA) { + continue; + } + TensorDesc in_desc; + gn.GetInputDesc(0, in_desc); + in_desc.SetOriginFormat(static_cast(GetParserContext().format)); + in_desc.SetFormat(static_cast(GetParserContext().format)); + gn.UpdateInputDesc(0, in_desc); + + TensorDesc out_desc; + gn.GetOutputDesc(0, out_desc); + out_desc.SetOriginFormat(static_cast(GetParserContext().format)); + out_desc.SetFormat(static_cast(GetParserContext().format)); + gn.UpdateOutputDesc(0, out_desc); + } + GELOGD("Update data format success."); + return; +} + } // namespace domi namespace domi { diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index 9c9eeee..45adf7c 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -100,7 +100,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); - void UpdateFormat(ge::Graph &graph); + void UpdateDataFormat(ge::Graph &graph); std::map ori_to_om_type_;