| @@ -17,6 +17,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/tensor.h" | |||||
| #include "nnacl/fp32/cast.h" | #include "nnacl/fp32/cast.h" | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| @@ -70,6 +71,12 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| MS_ASSERT(output_data != nullptr); | MS_ASSERT(output_data != nullptr); | ||||
| auto input_data_type = input->data_type(); | auto input_data_type = input->data_type(); | ||||
| auto output_data_type = output->data_type(); | auto output_data_type = output->data_type(); | ||||
| if (input_data_type == output_data_type) { | |||||
| auto datalen = lite::DataTypeSize(input_data_type); | |||||
| memcpy(reinterpret_cast<char *>(output_data) + offset * datalen, | |||||
| reinterpret_cast<char *>(input->data_c()) + offset * datalen, data_num * datalen); | |||||
| return RET_OK; | |||||
| } | |||||
| if (output_data_type != kNumberTypeFloat32) { | if (output_data_type != kNumberTypeFloat32) { | ||||
| if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { | if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { | ||||
| Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset, | Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset, | ||||
| @@ -83,9 +90,6 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { | } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { | ||||
| Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset, | Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset, | ||||
| reinterpret_cast<int64_t *>(output_data) + offset, data_num); | reinterpret_cast<int64_t *>(output_data) + offset, data_num); | ||||
| } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt32) { | |||||
| memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset, | |||||
| data_num * sizeof(int32_t)); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -108,10 +112,6 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | ||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| break; | break; | ||||
| case kNumberTypeFloat32: | |||||
| memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, | |||||
| data_num * sizeof(float)); | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||