Browse Source

fix cast datatype bugs

tags/v1.1.0
zengxianglong 5 years ago
parent
commit
7d24a0b0bc
1 changed files with 11 additions and 1 deletions
  1. +11
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc

+ 11
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc View File

@@ -79,7 +79,12 @@ int CastCPUKernel::DoCast(int thread_id) {
reinterpret_cast<uint16_t *>(output_data) + offset, data_num); reinterpret_cast<uint16_t *>(output_data) + offset, data_num);
} else if (input_data_type == kNumberTypeInt32 && } else if (input_data_type == kNumberTypeInt32 &&
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
memcpy(output_data, input->data_c(), data_num * sizeof(int32_t));
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset,
data_num * sizeof(int32_t));
} else if (input_data_type == kNumberTypeFloat32 &&
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset,
data_num * sizeof(float));
} else { } else {
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR; return RET_ERROR;
@@ -89,6 +94,7 @@ int CastCPUKernel::DoCast(int thread_id) {
case kNumberTypeBool: case kNumberTypeBool:
BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num); reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeUInt8: case kNumberTypeUInt8:
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num); reinterpret_cast<float *>(output_data) + offset, data_num);
@@ -101,6 +107,10 @@ int CastCPUKernel::DoCast(int thread_id) {
Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num); reinterpret_cast<float *>(output_data) + offset, data_num);
break; break;
case kNumberTypeFloat32:
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset,
data_num * sizeof(float));
break;
default: default:
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR; return RET_ERROR;


Loading…
Cancel
Save