From 275c4cb59ca8aaa412c697e3188e1f3d564e833d Mon Sep 17 00:00:00 2001 From: y00500818 Date: Tue, 6 Apr 2021 20:10:57 +0800 Subject: [PATCH] perfect onnx data type values --- parser/onnx/onnx_constant_parser.cc | 46 ++++++++++++++++++++++------- parser/onnx/onnx_constant_parser.h | 39 ++++++++++++++++++++---- 2 files changed, 69 insertions(+), 16 deletions(-) diff --git a/parser/onnx/onnx_constant_parser.cc b/parser/onnx/onnx_constant_parser.cc index 9b9b0f1..0bc8d97 100644 --- a/parser/onnx/onnx_constant_parser.cc +++ b/parser/onnx/onnx_constant_parser.cc @@ -46,22 +46,27 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ } std::map datatype_val_size_map = { + // for int32, uint8, int8, uint16, int16, bool, and float16 values {OnnxDataType::INT32, tensor_proto.int32_data_size()}, + {OnnxDataType::UINT8, tensor_proto.int32_data_size()}, + {OnnxDataType::INT8, tensor_proto.int32_data_size()}, + {OnnxDataType::UINT16, tensor_proto.int32_data_size()}, + {OnnxDataType::INT16, tensor_proto.int32_data_size()}, + {OnnxDataType::BOOL, tensor_proto.int32_data_size()}, + {OnnxDataType::FLOAT16, tensor_proto.int32_data_size()}, + // for int64 values {OnnxDataType::INT64, tensor_proto.int64_data_size()}, + // for string values {OnnxDataType::STRING, tensor_proto.string_data_size()}, + // for float and complex64 values {OnnxDataType::FLOAT, tensor_proto.float_data_size()}, + {OnnxDataType::COMPLEX64, tensor_proto.float_data_size()}, + // for double and complex128 values {OnnxDataType::DOUBLE, tensor_proto.double_data_size()}, + {OnnxDataType::COMPLEX128, tensor_proto.double_data_size()}, + // for uint64 and uint32 values {OnnxDataType::UINT64, tensor_proto.uint64_data_size()}, - {OnnxDataType::UINT8, 0}, - {OnnxDataType::INT8, 0}, - {OnnxDataType::UINT16, 0}, - {OnnxDataType::INT16, 0}, - {OnnxDataType::BOOL, 0}, - {OnnxDataType::FLOAT16, 0}, - {OnnxDataType::UINT32, 0}, - {OnnxDataType::COMPLEX64, 0}, - {OnnxDataType::COMPLEX128, 0}, - {OnnxDataType::BFLOAT16, 0}, + {OnnxDataType::UINT32, tensor_proto.uint64_data_size()}, }; int32_t datatype_val_size = 0; @@ -98,12 +103,21 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count, int64_t data_type) { switch (data_type) { + // for int32, uint8, int8, uint16, int16, bool, and float16 values case OnnxDataType::INT32: + case OnnxDataType::UINT8: + case OnnxDataType::INT8: + case OnnxDataType::UINT16: + case OnnxDataType::INT16: + case OnnxDataType::BOOL: + case OnnxDataType::FLOAT16: (void)SetTensorData(tensor_proto.int32_data_size(), tensor_proto.int32_data(), count, tensor); break; + // for int64 values case OnnxDataType::INT64: (void)SetTensorData(tensor_proto.int64_data_size(), tensor_proto.int64_data(), count, tensor); break; + // for string values case OnnxDataType::STRING: { std::vector data; for (auto str_data : tensor_proto.string_data()) { @@ -112,13 +126,25 @@ void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &t tensor.SetData(data); break; } + // for float and complex64 values case OnnxDataType::FLOAT: (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(), count, tensor); break; + case OnnxDataType::COMPLEX64: + (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(), + tensor_proto.float_data_size(), tensor); + break; + // for double and complex128 values case OnnxDataType::DOUBLE: (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(), count, tensor); break; + case OnnxDataType::COMPLEX128: + (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(), + tensor_proto.double_data_size(), tensor); + break; + // for uint64 and uint32 values case OnnxDataType::UINT64: + case OnnxDataType::UINT32: (void)SetTensorData(tensor_proto.uint64_data_size(), tensor_proto.uint64_data(), count, tensor); break; default: diff --git a/parser/onnx/onnx_constant_parser.h b/parser/onnx/onnx_constant_parser.h index 571fae2..1d44fa6 100644 --- a/parser/onnx/onnx_constant_parser.h +++ b/parser/onnx/onnx_constant_parser.h @@ -41,23 +41,50 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { static Status SetTensorData(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, Tensor &tensor) { bool zeros_like = (count != val_size && val_size == 1); - T *addr = new (std::nothrow) T[count](); + unique_ptr addr(new(std::nothrow) T[count]()); GE_CHECK_NOTNULL(addr); int minCount = (count > val_size) ? val_size : count; if (!zeros_like) { for (int32_t i = 0; i < minCount; i++) { - *(addr + i) = val_vector.Get(i); + *(addr.get() + i) = val_vector.Get(i); } for (int32_t i = minCount; i < count; i++) { - *(addr + i) = val_vector.Get(minCount - 1); + *(addr.get() + i) = val_vector.Get(minCount - 1); } } else { for (int32_t i = 0; i < count; i++) { - *(addr + i) = val_vector.Get(0); + *(addr.get() + i) = val_vector.Get(0); + } + } + + DataType data_type = tensor.GetTensorDesc().GetDataType(); + switch (data_type) { +#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ + case dt_type: \ + { \ + unique_ptr addr_trans(new(std::nothrow) value_type[count]()); \ + GE_CHECK_NOTNULL(addr_trans); \ + for (int32_t i = 0; i < count; i++) { \ + *(addr_trans.get() + i) = static_cast(*(addr.get() + i)); \ + } \ + tensor.SetData(reinterpret_cast(addr_trans.get()), count * sizeof(value_type)); \ + break; \ + } \ + + CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) + CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) + CASE_SET_DATA(DT_INT8, int8_t, addr, count, tensor) + CASE_SET_DATA(DT_UINT16, uint16_t, addr, count, tensor) + CASE_SET_DATA(DT_UINT8, uint8_t, addr, count, tensor) + CASE_SET_DATA(DT_BOOL, bool, addr, count, tensor) + CASE_SET_DATA(DT_UINT32, uint32_t, addr, count, tensor) +#undef CASE_SET_DATA + default: + { + tensor.SetData(reinterpret_cast(addr.get()), count * sizeof(T)); + break; } } - tensor.SetData(reinterpret_cast(addr), count * sizeof(T)); - GE_DELETE_NEW_ARRAY(addr); return SUCCESS; } };