|
|
|
@@ -79,7 +79,12 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
reinterpret_cast<uint16_t *>(output_data) + offset, data_num); |
|
|
|
} else if (input_data_type == kNumberTypeInt32 && |
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { |
|
|
|
memcpy(output_data, input->data_c(), data_num * sizeof(int32_t)); |
|
|
|
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset, |
|
|
|
data_num * sizeof(int32_t)); |
|
|
|
} else if (input_data_type == kNumberTypeFloat32 && |
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { |
|
|
|
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
data_num * sizeof(float)); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -89,6 +94,7 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
case kNumberTypeBool: |
|
|
|
BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, |
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num); |
|
|
|
break; |
|
|
|
case kNumberTypeUInt8: |
|
|
|
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, |
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num); |
|
|
|
@@ -101,6 +107,10 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, |
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num); |
|
|
|
break; |
|
|
|
case kNumberTypeFloat32: |
|
|
|
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
data_num * sizeof(float)); |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; |
|
|
|
return RET_ERROR; |
|
|
|
|