Browse Source

!283 perfect onnx data type values

Merge pull request !283 from yangyongqiang/master
pull/283/MERGE
i-robot Gitee 4 years ago
parent
commit
0bde97c8d3
2 changed files with 69 additions and 16 deletions
  1. +36
    -10
      parser/onnx/onnx_constant_parser.cc
  2. +33
    -6
      parser/onnx/onnx_constant_parser.h

+ 36
- 10
parser/onnx/onnx_constant_parser.cc View File

@@ -46,22 +46,27 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_
}

std::map<uint32_t, int32_t> datatype_val_size_map = {
// for int32, uint8, int8, uint16, int16, bool, and float16 values
{OnnxDataType::INT32, tensor_proto.int32_data_size()},
{OnnxDataType::UINT8, tensor_proto.int32_data_size()},
{OnnxDataType::INT8, tensor_proto.int32_data_size()},
{OnnxDataType::UINT16, tensor_proto.int32_data_size()},
{OnnxDataType::INT16, tensor_proto.int32_data_size()},
{OnnxDataType::BOOL, tensor_proto.int32_data_size()},
{OnnxDataType::FLOAT16, tensor_proto.int32_data_size()},
// for int64 values
{OnnxDataType::INT64, tensor_proto.int64_data_size()},
// for string values
{OnnxDataType::STRING, tensor_proto.string_data_size()},
// for float and complex64 values
{OnnxDataType::FLOAT, tensor_proto.float_data_size()},
{OnnxDataType::COMPLEX64, tensor_proto.float_data_size()},
// for double and complex128 values
{OnnxDataType::DOUBLE, tensor_proto.double_data_size()},
{OnnxDataType::COMPLEX128, tensor_proto.double_data_size()},
// for uint64 and uint32 values
{OnnxDataType::UINT64, tensor_proto.uint64_data_size()},
{OnnxDataType::UINT8, 0},
{OnnxDataType::INT8, 0},
{OnnxDataType::UINT16, 0},
{OnnxDataType::INT16, 0},
{OnnxDataType::BOOL, 0},
{OnnxDataType::FLOAT16, 0},
{OnnxDataType::UINT32, 0},
{OnnxDataType::COMPLEX64, 0},
{OnnxDataType::COMPLEX128, 0},
{OnnxDataType::BFLOAT16, 0},
{OnnxDataType::UINT32, tensor_proto.uint64_data_size()},
};

int32_t datatype_val_size = 0;
@@ -98,12 +103,21 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_
void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor,
int count, int64_t data_type) {
switch (data_type) {
// for int32, uint8, int8, uint16, int16, bool, and float16 values
case OnnxDataType::INT32:
case OnnxDataType::UINT8:
case OnnxDataType::INT8:
case OnnxDataType::UINT16:
case OnnxDataType::INT16:
case OnnxDataType::BOOL:
case OnnxDataType::FLOAT16:
(void)SetTensorData(tensor_proto.int32_data_size(), tensor_proto.int32_data(), count, tensor);
break;
// for int64 values
case OnnxDataType::INT64:
(void)SetTensorData(tensor_proto.int64_data_size(), tensor_proto.int64_data(), count, tensor);
break;
// for string values
case OnnxDataType::STRING: {
std::vector<std::string> data;
for (auto str_data : tensor_proto.string_data()) {
@@ -112,13 +126,25 @@ void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &t
tensor.SetData(data);
break;
}
// for float and complex64 values
case OnnxDataType::FLOAT:
(void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(), count, tensor);
break;
case OnnxDataType::COMPLEX64:
(void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(),
tensor_proto.float_data_size(), tensor);
break;
// for double and complex128 values
case OnnxDataType::DOUBLE:
(void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(), count, tensor);
break;
case OnnxDataType::COMPLEX128:
(void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(),
tensor_proto.double_data_size(), tensor);
break;
// for uint64 and uint32 values
case OnnxDataType::UINT64:
case OnnxDataType::UINT32:
(void)SetTensorData(tensor_proto.uint64_data_size(), tensor_proto.uint64_data(), count, tensor);
break;
default:


+ 33
- 6
parser/onnx/onnx_constant_parser.h View File

@@ -41,23 +41,50 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
static Status SetTensorData(int32_t val_size, const google::protobuf::RepeatedField<T> &val_vector, int count,
Tensor &tensor) {
bool zeros_like = (count != val_size && val_size == 1);
T *addr = new (std::nothrow) T[count]();
unique_ptr<T> addr(new(std::nothrow) T[count]());
GE_CHECK_NOTNULL(addr);
int minCount = (count > val_size) ? val_size : count;
if (!zeros_like) {
for (int32_t i = 0; i < minCount; i++) {
*(addr + i) = val_vector.Get(i);
*(addr.get() + i) = val_vector.Get(i);
}
for (int32_t i = minCount; i < count; i++) {
*(addr + i) = val_vector.Get(minCount - 1);
*(addr.get() + i) = val_vector.Get(minCount - 1);
}
} else {
for (int32_t i = 0; i < count; i++) {
*(addr + i) = val_vector.Get(0);
*(addr.get() + i) = val_vector.Get(0);
}
}

DataType data_type = tensor.GetTensorDesc().GetDataType();
switch (data_type) {
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \
case dt_type: \
{ \
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \
GE_CHECK_NOTNULL(addr_trans); \
for (int32_t i = 0; i < count; i++) { \
*(addr_trans.get() + i) = static_cast<value_type>(*(addr.get() + i)); \
} \
tensor.SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), count * sizeof(value_type)); \
break; \
} \

CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor)
CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor)
CASE_SET_DATA(DT_INT8, int8_t, addr, count, tensor)
CASE_SET_DATA(DT_UINT16, uint16_t, addr, count, tensor)
CASE_SET_DATA(DT_UINT8, uint8_t, addr, count, tensor)
CASE_SET_DATA(DT_BOOL, bool, addr, count, tensor)
CASE_SET_DATA(DT_UINT32, uint32_t, addr, count, tensor)
#undef CASE_SET_DATA
default:
{
tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(T));
break;
}
}
tensor.SetData(reinterpret_cast<uint8_t *>(addr), count * sizeof(T));
GE_DELETE_NEW_ARRAY(addr);
return SUCCESS;
}
};


Loading…
Cancel
Save