Browse Source

!11873 fix onnx const parser

From: @cjh9368
Reviewed-by: @zhanghaibo5,@hangangqiang
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
e19505cdf2
1 changed files with 24 additions and 6 deletions
  1. +24
    -6
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc

+ 24
- 6
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc View File

@@ -48,20 +48,38 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso
switch (onnx_tensor.data_type()) { switch (onnx_tensor.data_type()) {
case onnx::TensorProto_DataType_FLOAT: case onnx::TensorProto_DataType_FLOAT:
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT); *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
for (size_t i = 0; i < data_count; i++) {
value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
if (onnx_tensor.float_data_size() > 0) {
for (int i = 0; i < onnx_tensor.float_data_size(); i++) {
value->push_back(onnx_tensor.float_data(i));
}
} else {
for (size_t i = 0; i < data_count; i++) {
value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
}
} }
break; break;
case onnx::TensorProto_DataType_INT32: case onnx::TensorProto_DataType_INT32:
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
for (size_t i = 0; i < data_count; i++) {
value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
if (onnx_tensor.int32_data_size() > 0) {
for (int i = 0; i < onnx_tensor.int32_data_size(); i++) {
value->push_back(onnx_tensor.int32_data(i));
}
} else {
for (size_t i = 0; i < data_count; i++) {
value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
}
} }
break; break;
case onnx::TensorProto_DataType_INT64: case onnx::TensorProto_DataType_INT64:
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
for (size_t i = 0; i < data_count; i++) {
value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
if (onnx_tensor.int64_data_size() > 0) {
for (int i = 0; i < onnx_tensor.int64_data_size(); i++) {
value->push_back(onnx_tensor.int64_data(i));
}
} else {
for (size_t i = 0; i < data_count; i++) {
value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
}
} }
break; break;
default: default:


Loading…
Cancel
Save