|
|
|
@@ -71,13 +71,18 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
auto input_data_type = input->data_type(); |
|
|
|
auto output_data_type = output->data_type(); |
|
|
|
if (output_data_type != kNumberTypeFloat32) { |
|
|
|
if (input_data_type == kNumberTypeFloat32 && |
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { |
|
|
|
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { |
|
|
|
Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
reinterpret_cast<int64_t *>(output_data) + offset, data_num); |
|
|
|
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { |
|
|
|
Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
reinterpret_cast<int32_t *>(output_data) + offset, data_num); |
|
|
|
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) { |
|
|
|
Float32ToFp16(reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
reinterpret_cast<uint16_t *>(output_data) + offset, data_num); |
|
|
|
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { |
|
|
|
Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset, |
|
|
|
reinterpret_cast<int64_t *>(output_data) + offset, data_num); |
|
|
|
} else if (input_data_type == kNumberTypeInt32 && |
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { |
|
|
|
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset, |
|
|
|
|