| @@ -48,20 +48,38 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso | |||||
| switch (onnx_tensor.data_type()) { | switch (onnx_tensor.data_type()) { | ||||
| case onnx::TensorProto_DataType_FLOAT: | case onnx::TensorProto_DataType_FLOAT: | ||||
| *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT); | *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT); | ||||
| for (size_t i = 0; i < data_count; i++) { | |||||
| value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]); | |||||
| if (onnx_tensor.float_data_size() > 0) { | |||||
| for (int i = 0; i < onnx_tensor.float_data_size(); i++) { | |||||
| value->push_back(onnx_tensor.float_data(i)); | |||||
| } | |||||
| } else { | |||||
| for (size_t i = 0; i < data_count; i++) { | |||||
| value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]); | |||||
| } | |||||
| } | } | ||||
| break; | break; | ||||
| case onnx::TensorProto_DataType_INT32: | case onnx::TensorProto_DataType_INT32: | ||||
| *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); | *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); | ||||
| for (size_t i = 0; i < data_count; i++) { | |||||
| value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i])); | |||||
| if (onnx_tensor.int32_data_size() > 0) { | |||||
| for (int i = 0; i < onnx_tensor.int32_data_size(); i++) { | |||||
| value->push_back(onnx_tensor.int32_data(i)); | |||||
| } | |||||
| } else { | |||||
| for (size_t i = 0; i < data_count; i++) { | |||||
| value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i])); | |||||
| } | |||||
| } | } | ||||
| break; | break; | ||||
| case onnx::TensorProto_DataType_INT64: | case onnx::TensorProto_DataType_INT64: | ||||
| *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); | *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); | ||||
| for (size_t i = 0; i < data_count; i++) { | |||||
| value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i])); | |||||
| if (onnx_tensor.int64_data_size() > 0) { | |||||
| for (int i = 0; i < onnx_tensor.int64_data_size(); i++) { | |||||
| value->push_back(onnx_tensor.int64_data(i)); | |||||
| } | |||||
| } else { | |||||
| for (size_t i = 0; i < data_count; i++) { | |||||
| value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i])); | |||||
| } | |||||
| } | } | ||||
| break; | break; | ||||
| default: | default: | ||||