| @@ -79,7 +79,12 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| reinterpret_cast<uint16_t *>(output_data) + offset, data_num); | reinterpret_cast<uint16_t *>(output_data) + offset, data_num); | ||||
| } else if (input_data_type == kNumberTypeInt32 && | } else if (input_data_type == kNumberTypeInt32 && | ||||
| (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { | (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { | ||||
| memcpy(output_data, input->data_c(), data_num * sizeof(int32_t)); | |||||
| memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset, | |||||
| data_num * sizeof(int32_t)); | |||||
| } else if (input_data_type == kNumberTypeFloat32 && | |||||
| (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { | |||||
| memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, | |||||
| data_num * sizeof(float)); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -89,6 +94,7 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| case kNumberTypeBool: | case kNumberTypeBool: | ||||
| BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | ||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| break; | |||||
| case kNumberTypeUInt8: | case kNumberTypeUInt8: | ||||
| Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | ||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| @@ -101,6 +107,10 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | ||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| break; | break; | ||||
| case kNumberTypeFloat32: | |||||
| memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, | |||||
| data_num * sizeof(float)); | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||