@@ -46,22 +46,27 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_
   }
 
   std::map<uint32_t, int32_t> datatype_val_size_map = {
+      // for int32, uint8, int8, uint16, int16, bool, and float16 values
       {OnnxDataType::INT32, tensor_proto.int32_data_size()},
+      {OnnxDataType::UINT8, tensor_proto.int32_data_size()},
+      {OnnxDataType::INT8, tensor_proto.int32_data_size()},
+      {OnnxDataType::UINT16, tensor_proto.int32_data_size()},
+      {OnnxDataType::INT16, tensor_proto.int32_data_size()},
+      {OnnxDataType::BOOL, tensor_proto.int32_data_size()},
+      {OnnxDataType::FLOAT16, tensor_proto.int32_data_size()},
+      // for int64 values
       {OnnxDataType::INT64, tensor_proto.int64_data_size()},
+      // for string values
       {OnnxDataType::STRING, tensor_proto.string_data_size()},
+      // for float and complex64 values
       {OnnxDataType::FLOAT, tensor_proto.float_data_size()},
+      {OnnxDataType::COMPLEX64, tensor_proto.float_data_size()},
+      // for double and complex128 values
       {OnnxDataType::DOUBLE, tensor_proto.double_data_size()},
+      {OnnxDataType::COMPLEX128, tensor_proto.double_data_size()},
+      // for uint64 and uint32 values
       {OnnxDataType::UINT64, tensor_proto.uint64_data_size()},
-      {OnnxDataType::UINT8, 0},
-      {OnnxDataType::INT8, 0},
-      {OnnxDataType::UINT16, 0},
-      {OnnxDataType::INT16, 0},
-      {OnnxDataType::BOOL, 0},
-      {OnnxDataType::FLOAT16, 0},
-      {OnnxDataType::UINT32, 0},
-      {OnnxDataType::COMPLEX64, 0},
-      {OnnxDataType::COMPLEX128, 0},
-      {OnnxDataType::BFLOAT16, 0},
+      {OnnxDataType::UINT32, tensor_proto.uint64_data_size()},
   };
 
   int32_t datatype_val_size = 0;
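
Background on the mapping above: onnx.proto packs several logical element types into a few shared repeated fields, so int32, uint8, int8, uint16, int16, bool, and float16 values all arrive widened into int32_data, complex64 values ride in float_data, complex128 in double_data, and uint32 in uint64_data. With the old zero placeholders those types always reported zero stored elements even when the constant carried typed field data. The code that consumes the map sits outside this hunk; a minimal sketch of the lookup it presumably performs (an assumption about the surrounding body, not the parser's actual code):

    // Hypothetical sketch: how many typed elements does the proto carry?
    const auto onnx_data_type = tensor_proto.data_type();
    int32_t datatype_val_size = 0;
    const auto iter = datatype_val_size_map.find(static_cast<uint32_t>(onnx_data_type));
    if (iter != datatype_val_size_map.end()) {
      datatype_val_size = iter->second;
    }
    // A result of 0 now means the values live in raw_data (or the type is
    // unsupported), rather than "this type never has a real stored size".
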
@@ -98,12 +103,21 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_
 void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor,
                                                   int count, int64_t data_type) {
   switch (data_type) {
+    // for int32, uint8, int8, uint16, int16, bool, and float16 values
     case OnnxDataType::INT32:
+    case OnnxDataType::UINT8:
+    case OnnxDataType::INT8:
+    case OnnxDataType::UINT16:
+    case OnnxDataType::INT16:
+    case OnnxDataType::BOOL:
+    case OnnxDataType::FLOAT16:
       (void)SetTensorData(tensor_proto.int32_data_size(), tensor_proto.int32_data(), count, tensor);
       break;
+    // for int64 values
     case OnnxDataType::INT64:
       (void)SetTensorData(tensor_proto.int64_data_size(), tensor_proto.int64_data(), count, tensor);
       break;
+    // for string values
     case OnnxDataType::STRING: {
       std::vector<std::string> data;
       for (auto str_data : tensor_proto.string_data()) {
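
For the FLOAT16 case above, each int32_data element holds a raw IEEE-754 half-precision bit pattern widened to 32 bits (the onnx.proto convention for float16 stored in int32_data), so the copy into the tensor has to narrow every element back to 16 bits. A standalone sketch of that unpacking, using a hypothetical helper for illustration rather than the parser's SetTensorData:

    #include <cstdint>
    #include <vector>

    // Hypothetical helper: recover fp16 bit patterns widened into int32_data.
    std::vector<uint16_t> UnpackFloat16Bits(const ge::onnx::TensorProto &tp) {
      std::vector<uint16_t> out;
      out.reserve(tp.int32_data_size());
      for (const int32_t value : tp.int32_data()) {
        out.push_back(static_cast<uint16_t>(value));  // keep the low 16 bits
      }
      return out;
    }
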
@@ -112,13 +126,25 @@ void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &t
       tensor.SetData(data);
       break;
     }
+    // for float and complex64 values
     case OnnxDataType::FLOAT:
       (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(), count, tensor);
       break;
+    case OnnxDataType::COMPLEX64:
+      (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(),
+                          tensor_proto.float_data_size(), tensor);
+      break;
+    // for double and complex128 values
     case OnnxDataType::DOUBLE:
       (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(), count, tensor);
       break;
+    case OnnxDataType::COMPLEX128:
+      (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(),
+                          tensor_proto.double_data_size(), tensor);
+      break;
+    // for uint64 and uint32 values
     case OnnxDataType::UINT64:
+    case OnnxDataType::UINT32:
       (void)SetTensorData(tensor_proto.uint64_data_size(), tensor_proto.uint64_data(), count, tensor);
       break;
     default: