Browse Source

!259 update onnx format

Merge pull request !259 from yangyongqiang/onnx_format
pull/259/MERGE
i-robot Gitee 4 years ago
parent
commit
185be26953
4 changed files with 25 additions and 41 deletions
  1. +0
    -4
      parser/onnx/onnx_constant_parser.cc
  2. +0
    -1
      parser/onnx/onnx_data_parser.cc
  3. +24
    -35
      parser/onnx/onnx_parser.cc
  4. +1
    -1
      parser/onnx/onnx_parser.h

+ 0
- 4
parser/onnx/onnx_constant_parser.cc View File

@@ -142,7 +142,6 @@ Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tenso
};
TensorDesc tensor_desc = tensor.GetTensorDesc();
tensor_desc.SetShape(ge::Shape(tmp_shape));
tensor_desc.SetFormat(static_cast<ge::Format>(GetParserContext().format));
tensor.SetTensorDesc(tensor_desc);

// set data
@@ -190,9 +189,6 @@ Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src
}

op_def.SetAttr(ge::kAttrNameValue, tensor);
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_def);
op_def.UpdateOutputDesc(op_desc->GetOutputNameByIndex(0), tensor.GetTensorDesc());

return SUCCESS;
}



+ 0
- 1
parser/onnx/onnx_data_parser.cc View File

@@ -42,7 +42,6 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def)
}

ge::TensorDesc tensor_desc;
tensor_desc.SetFormat(static_cast<ge::Format>(GetParserContext().format));
tensor_desc.SetShape(ge::Shape(user_input_dims_v_));
int64_t type = 1;
(void)op_def.GetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, type);


+ 24
- 35
parser/onnx/onnx_parser.cc View File

@@ -414,39 +414,6 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) {
return SUCCESS;
}

void OnnxModelParser::UpdateFormat(ge::Graph &graph) {
std::vector<string> vec_op_name;
graph.GetAllOpName(vec_op_name);
ge::Format format = ge::FORMAT_NCHW;
for (string name: vec_op_name) {
ge::Operator op;
graph.FindOpByName(name, op);
auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op);
if (std::find(kNoNeedUpdateFormat.begin(), kNoNeedUpdateFormat.end(), op_dsc->GetType())
!= kNoNeedUpdateFormat.end()) {
GELOGW("Op %s:%s no need update format.", op_dsc->GetName().c_str(), op_dsc->GetType().c_str());
continue;
}
auto input_size = op_dsc->GetAllInputsSize();
for (size_t i = 0; i < input_size; i++) {
auto input = op_dsc->MutableInputDesc(static_cast<uint32_t>(i));
if (input == nullptr) {
continue;
}
input->SetFormat(format);
input->SetOriginFormat(format);
}

auto output_size = op_dsc->GetOutputsSize();
for (size_t i = 0; i < output_size; i++) {
auto output = op_dsc->GetOutputDesc(i);
output.SetFormat(format);
output.SetOriginFormat(format);
op_dsc->UpdateOutputDesc(i, output);
}
}
}

Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
@@ -633,8 +600,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
graph.SetInputs(input_ops);

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));

UpdateFormat(graph);
UpdateDataFormat(graph);

GELOGI("Onnx model parser success.");
return SUCCESS;
@@ -695,6 +661,29 @@ ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) {
return ge::OnnxUtil::ConvertOnnxDataType(type);
}

void OnnxModelParser::UpdateDataFormat(ge::Graph &graph) {
for (GNode &gn : graph.GetDirectNode()) {
AscendString type;
(void)gn.GetType(type);
if (type != parser::DATA) {
continue;
}
TensorDesc in_desc;
gn.GetInputDesc(0, in_desc);
in_desc.SetOriginFormat(static_cast<ge::Format>(GetParserContext().format));
in_desc.SetFormat(static_cast<ge::Format>(GetParserContext().format));
gn.UpdateInputDesc(0, in_desc);

TensorDesc out_desc;
gn.GetOutputDesc(0, out_desc);
out_desc.SetOriginFormat(static_cast<ge::Format>(GetParserContext().format));
out_desc.SetFormat(static_cast<ge::Format>(GetParserContext().format));
gn.UpdateOutputDesc(0, out_desc);
}
GELOGD("Update data format success.");
return;
}

} // namespace domi

namespace domi {


+ 1
- 1
parser/onnx/onnx_parser.h View File

@@ -100,7 +100,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph);

void UpdateFormat(ge::Graph &graph);
void UpdateDataFormat(ge::Graph &graph);

std::map<std::string, std::string> ori_to_om_type_;



Loading…
Cancel
Save