| @@ -143,9 +143,11 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| std::map<std::string, std::string> kOnnxOpMap = { | |||||
| const std::map<std::string, std::string> kOnnxOpMap = { | |||||
| {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, | {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, | ||||
| }; | }; | ||||
| const char* const MATMULV2 = "MatMulV2"; | |||||
| const std::vector<std::string> kNoNeedUpdateFormat = {MATMULV2}; | |||||
| } | } | ||||
| Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | ||||
| @@ -419,6 +421,11 @@ void OnnxModelParser::UpdateFormat(ge::Graph &graph) { | |||||
| ge::Operator op; | ge::Operator op; | ||||
| graph.FindOpByName(name, op); | graph.FindOpByName(name, op); | ||||
| auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op); | auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op); | ||||
| if (std::find(kNoNeedUpdateFormat.begin(), kNoNeedUpdateFormat.end(), op_dsc->GetType()) | |||||
| != kNoNeedUpdateFormat.end()) { | |||||
| GELOGW("Op %s:%s no need update format.", op_dsc->GetName().c_str(), op_dsc->GetType().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto input_size = op_dsc->GetAllInputsSize(); | auto input_size = op_dsc->GetAllInputsSize(); | ||||
| for (size_t i = 0; i < input_size; i++) { | for (size_t i = 0; i < input_size; i++) { | ||||
| auto input = op_dsc->MutableInputDesc(static_cast<uint32_t>(i)); | auto input = op_dsc->MutableInputDesc(static_cast<uint32_t>(i)); | ||||