|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
#include <vector> |
|
|
|
#include "schema/model_generated.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
#include "src/tensor.h" |
|
|
|
#include "nnacl/fp32/cast.h" |
|
|
|
#include "nnacl/op_base.h" |
|
|
|
#include "src/runtime/runtime_api.h" |
|
|
|
@@ -70,6 +71,12 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
MS_ASSERT(output_data != nullptr); |
|
|
|
auto input_data_type = input->data_type(); |
|
|
|
auto output_data_type = output->data_type(); |
|
|
|
if (input_data_type == output_data_type) { |
|
|
|
auto datalen = lite::DataTypeSize(input_data_type); |
|
|
|
memcpy(reinterpret_cast<char *>(output_data) + offset * datalen, |
|
|
|
reinterpret_cast<char *>(input->data_c()) + offset * datalen, data_num * datalen); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
if (output_data_type != kNumberTypeFloat32) { |
|
|
|
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { |
|
|
|
Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
@@ -83,9 +90,6 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { |
|
|
|
Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset, |
|
|
|
reinterpret_cast<int64_t *>(output_data) + offset, data_num); |
|
|
|
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt32) { |
|
|
|
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset, |
|
|
|
data_num * sizeof(int32_t)); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -108,10 +112,6 @@ int CastCPUKernel::DoCast(int thread_id) { |
|
|
|
Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, |
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num); |
|
|
|
break; |
|
|
|
case kNumberTypeFloat32: |
|
|
|
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, |
|
|
|
data_num * sizeof(float)); |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; |
|
|
|
return RET_ERROR; |
|
|
|
|