|
|
|
@@ -65,17 +65,32 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
} |
|
|
|
|
|
|
|
auto offset = thread_id * stride_; |
|
|
|
auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->Data()); |
|
|
|
switch (input->data_type()) { |
|
|
|
auto output = out_tensors_.at(0); |
|
|
|
auto output_data = output->Data(); |
|
|
|
auto input_data_type = input->data_type(); |
|
|
|
auto output_data_type = output->data_type(); |
|
|
|
if (output_data_type != kNumberTypeFloat32) { |
|
|
|
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { |
|
|
|
Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset, |
|
|
|
reinterpret_cast<int32_t *>(output_data) + offset, data_num); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} else { |
|
|
|
switch (input_data_type) { |
|
|
|
case kNumberTypeUInt8: |
|
|
|
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset, output_data + offset, data_num); |
|
|
|
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset, |
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num); |
|
|
|
break; |
|
|
|
case kNumberTypeInt32: |
|
|
|
Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, output_data + offset, data_num); |
|
|
|
Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, |
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num); |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); |
|
|
|
MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|