diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 56564a609b..681d2e551e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -79,7 +79,12 @@ int CastCPUKernel::DoCast(int thread_id) { reinterpret_cast(output_data) + offset, data_num); } else if (input_data_type == kNumberTypeInt32 && (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { - memcpy(output_data, input->data_c(), data_num * sizeof(int32_t)); + memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, + data_num * sizeof(int32_t)); + } else if (input_data_type == kNumberTypeFloat32 && + (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { + memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, + data_num * sizeof(float)); } else { MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR; @@ -89,6 +94,7 @@ int CastCPUKernel::DoCast(int thread_id) { case kNumberTypeBool: BoolToFloat32(reinterpret_cast(input->MutableData()) + offset, reinterpret_cast(output_data) + offset, data_num); + break; case kNumberTypeUInt8: Uint8ToFloat32(reinterpret_cast(input->MutableData()) + offset, reinterpret_cast(output_data) + offset, data_num); @@ -101,6 +107,10 @@ int CastCPUKernel::DoCast(int thread_id) { Fp16ToFloat32(reinterpret_cast(input->MutableData()) + offset, reinterpret_cast(output_data) + offset, data_num); break; + case kNumberTypeFloat32: + memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, + data_num * sizeof(float)); + break; default: MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; return RET_ERROR;