|
|
|
@@ -48,20 +48,38 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso |
|
|
|
switch (onnx_tensor.data_type()) { |
|
|
|
case onnx::TensorProto_DataType_FLOAT: |
|
|
|
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT); |
|
|
|
for (size_t i = 0; i < data_count; i++) { |
|
|
|
value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]); |
|
|
|
if (onnx_tensor.float_data_size() > 0) { |
|
|
|
for (int i = 0; i < onnx_tensor.float_data_size(); i++) { |
|
|
|
value->push_back(onnx_tensor.float_data(i)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < data_count; i++) { |
|
|
|
value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
break; |
|
|
|
case onnx::TensorProto_DataType_INT32: |
|
|
|
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); |
|
|
|
for (size_t i = 0; i < data_count; i++) { |
|
|
|
value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i])); |
|
|
|
if (onnx_tensor.int32_data_size() > 0) { |
|
|
|
for (int i = 0; i < onnx_tensor.int32_data_size(); i++) { |
|
|
|
value->push_back(onnx_tensor.int32_data(i)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < data_count; i++) { |
|
|
|
value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i])); |
|
|
|
} |
|
|
|
} |
|
|
|
break; |
|
|
|
case onnx::TensorProto_DataType_INT64: |
|
|
|
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32); |
|
|
|
for (size_t i = 0; i < data_count; i++) { |
|
|
|
value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i])); |
|
|
|
if (onnx_tensor.int64_data_size() > 0) { |
|
|
|
for (int i = 0; i < onnx_tensor.int64_data_size(); i++) { |
|
|
|
value->push_back(onnx_tensor.int64_data(i)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < data_count; i++) { |
|
|
|
value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i])); |
|
|
|
} |
|
|
|
} |
|
|
|
break; |
|
|
|
default: |
|
|
|
|