From 4ed8fea351004b379bd296ca17065f9645d0505c Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Sat, 30 Jan 2021 11:05:54 +0800 Subject: [PATCH] fix onnx constant parser --- .../converter/parser/onnx/onnx_node_parser.cc | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc index 5443d9ca9a..7da40b6f54 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -48,20 +48,38 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso switch (onnx_tensor.data_type()) { case onnx::TensorProto_DataType_FLOAT: *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT); - for (size_t i = 0; i < data_count; i++) { - value->push_back(reinterpret_cast(onnx_tensor.raw_data().data())[i]); + if (onnx_tensor.float_data_size() > 0) { + for (int i = 0; i < onnx_tensor.float_data_size(); i++) { + value->push_back(onnx_tensor.float_data(i)); + } + } else { + for (size_t i = 0; i < data_count; i++) { + value->push_back(reinterpret_cast(onnx_tensor.raw_data().data())[i]); + } } break; case onnx::TensorProto_DataType_INT32: *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); - for (size_t i = 0; i < data_count; i++) { - value->push_back(static_cast(reinterpret_cast(onnx_tensor.raw_data().data())[i])); + if (onnx_tensor.int32_data_size() > 0) { + for (int i = 0; i < onnx_tensor.int32_data_size(); i++) { + value->push_back(onnx_tensor.int32_data(i)); + } + } else { + for (size_t i = 0; i < data_count; i++) { + value->push_back(static_cast(reinterpret_cast(onnx_tensor.raw_data().data())[i])); + } } break; case onnx::TensorProto_DataType_INT64: *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); - for (size_t i = 0; i < data_count; i++) { - value->push_back(static_cast(reinterpret_cast(onnx_tensor.raw_data().data())[i])); + if (onnx_tensor.int64_data_size() > 0) { + for (int i = 0; i < onnx_tensor.int64_data_size(); i++) { + value->push_back(onnx_tensor.int64_data(i)); + } + } else { + for (size_t i = 0; i < data_count; i++) { + value->push_back(static_cast(reinterpret_cast(onnx_tensor.raw_data().data())[i])); + } } break; default: