|
|
|
@@ -414,39 +414,6 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxModelParser::UpdateFormat(ge::Graph &graph) { |
|
|
|
std::vector<string> vec_op_name; |
|
|
|
graph.GetAllOpName(vec_op_name); |
|
|
|
ge::Format format = ge::FORMAT_NCHW; |
|
|
|
for (string name: vec_op_name) { |
|
|
|
ge::Operator op; |
|
|
|
graph.FindOpByName(name, op); |
|
|
|
auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op); |
|
|
|
if (std::find(kNoNeedUpdateFormat.begin(), kNoNeedUpdateFormat.end(), op_dsc->GetType()) |
|
|
|
!= kNoNeedUpdateFormat.end()) { |
|
|
|
GELOGW("Op %s:%s no need update format.", op_dsc->GetName().c_str(), op_dsc->GetType().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto input_size = op_dsc->GetAllInputsSize(); |
|
|
|
for (size_t i = 0; i < input_size; i++) { |
|
|
|
auto input = op_dsc->MutableInputDesc(static_cast<uint32_t>(i)); |
|
|
|
if (input == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
input->SetFormat(format); |
|
|
|
input->SetOriginFormat(format); |
|
|
|
} |
|
|
|
|
|
|
|
auto output_size = op_dsc->GetOutputsSize(); |
|
|
|
for (size_t i = 0; i < output_size; i++) { |
|
|
|
auto output = op_dsc->GetOutputDesc(i); |
|
|
|
output.SetFormat(format); |
|
|
|
output.SetOriginFormat(format); |
|
|
|
op_dsc->UpdateOutputDesc(i, output); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { |
|
|
|
for (int i = 0; i < onnx_graph.node_size(); i++) { |
|
|
|
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); |
|
|
|
@@ -633,8 +600,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
graph.SetInputs(input_ops); |
|
|
|
|
|
|
|
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); |
|
|
|
|
|
|
|
UpdateFormat(graph); |
|
|
|
UpdateDataFormat(graph); |
|
|
|
|
|
|
|
GELOGI("Onnx model parser success."); |
|
|
|
return SUCCESS; |
|
|
|
@@ -695,6 +661,29 @@ ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { |
|
|
|
return ge::OnnxUtil::ConvertOnnxDataType(type); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxModelParser::UpdateDataFormat(ge::Graph &graph) { |
|
|
|
for (GNode &gn : graph.GetDirectNode()) { |
|
|
|
AscendString type; |
|
|
|
(void)gn.GetType(type); |
|
|
|
if (type != parser::DATA) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
TensorDesc in_desc; |
|
|
|
gn.GetInputDesc(0, in_desc); |
|
|
|
in_desc.SetOriginFormat(static_cast<ge::Format>(GetParserContext().format)); |
|
|
|
in_desc.SetFormat(static_cast<ge::Format>(GetParserContext().format)); |
|
|
|
gn.UpdateInputDesc(0, in_desc); |
|
|
|
|
|
|
|
TensorDesc out_desc; |
|
|
|
gn.GetOutputDesc(0, out_desc); |
|
|
|
out_desc.SetOriginFormat(static_cast<ge::Format>(GetParserContext().format)); |
|
|
|
out_desc.SetFormat(static_cast<ge::Format>(GetParserContext().format)); |
|
|
|
gn.UpdateOutputDesc(0, out_desc); |
|
|
|
} |
|
|
|
GELOGD("Update data format success."); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace domi |
|
|
|
|
|
|
|
namespace domi { |
|
|
|
|