Merge pull request !4505 from yangruoqi713/lite (tag: v0.7.0-beta)
| @@ -16,6 +16,7 @@ | |||
| #include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| @@ -177,10 +178,22 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() { | |||
| } | |||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||
| auto input_addr = reinterpret_cast<float *>(input_tensor->Data()); | |||
| float16_t *input_addr; | |||
| if (input_tensor->data_type() == kNumberTypeFloat32) { | |||
| input_addr = | |||
| reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); | |||
| if (input_addr == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| Float32ToFloat16(reinterpret_cast<float *>(input_tensor->Data()), input_addr, input_tensor->ElementsNum()); | |||
| } else { | |||
| input_addr = reinterpret_cast<float16_t *>(input_tensor->Data()); | |||
| } | |||
| // pack input: to nhwc8 | |||
| PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_, | |||
| conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); | |||
| PackNHWCToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_, | |||
| conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); | |||
| ret = LiteBackendParallelLaunch(ConvDwFp16Run, this, conv_param_->thread_num_); | |||
| if (ret != RET_OK) { | |||
| @@ -188,10 +201,13 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data()); | |||
| PackNHWC8Fp16ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, | |||
| conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); | |||
| auto output_addr = reinterpret_cast<float16_t *>(out_tensors_.at(kOutputIndex)->Data()); | |||
| PackNHWC8ToNHWCFp16(packed_output_, output_addr, conv_param_->output_batch_, | |||
| conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); | |||
| if (input_tensor->data_type() == kNumberTypeFloat32) { | |||
| context_->allocator->Free(input_addr); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -334,31 +334,57 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, | |||
| } | |||
| void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { | |||
| int c8 = UP_DIV(channel, C8NUM); | |||
| int nhwc8_batch_unit_offset = c8 * C8NUM * plane; | |||
| int nhwc8_batch_offset = 0; | |||
| int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; | |||
| for (int b = 0; b < batch; b++) { | |||
| int batch_offset = b * channel * plane; | |||
| float16_t *dst_batch = dst + b * plane * c8_channel; | |||
| float *src_batch = src + b * plane * channel; | |||
| for (int i = 0; i < plane; i++) { | |||
| float16_t *dst_plane = dst_batch + i * c8_channel; | |||
| float *src_plane = src_batch + i * channel; | |||
| for (int c = 0; c < channel; c++) { | |||
| (dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c]; | |||
| dst_plane[c] = (float16_t)(src_plane[c]); | |||
| } | |||
| } | |||
| nhwc8_batch_offset += nhwc8_batch_unit_offset; | |||
| } | |||
| } | |||
| void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) { | |||
| int c8 = UP_DIV(channel, C8NUM); | |||
| int nhwc_batch_unit_offset = channel * plane; | |||
| int nhwc_batch_offset = 0; | |||
| int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; | |||
| for (int b = 0; b < batch; b++) { | |||
| int batch_offset = b * c8 * C8NUM * plane; | |||
| float16_t *src_batch = src + b * plane * c8_channel; | |||
| float *dst_batch = dst + b * plane * channel; | |||
| for (int i = 0; i < plane; i++) { | |||
| float16_t *src_plane = src_batch + i * c8_channel; | |||
| float *dst_plane = dst_batch + i * channel; | |||
| for (int c = 0; c < channel; c++) { | |||
| (dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c]; | |||
| dst_plane[c] = (float16_t)(src_plane[c]); | |||
| } | |||
| } | |||
| nhwc_batch_offset += nhwc_batch_unit_offset; | |||
| } | |||
| } | |||
| void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) { | |||
| int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; | |||
| for (int b = 0; b < batch; b++) { | |||
| float16_t *dst_batch = dst + b * plane * c8_channel; | |||
| float16_t *src_batch = src + b * plane * channel; | |||
| for (int i = 0; i < plane; i++) { | |||
| float16_t *dst_plane = dst_batch + i * c8_channel; | |||
| float16_t *src_plane = src_batch + i * channel; | |||
| memcpy(dst_plane, src_batch, channel * sizeof(float16_t)); | |||
| } | |||
| } | |||
| } | |||
| void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) { | |||
| int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; | |||
| for (int b = 0; b < batch; b++) { | |||
| float16_t *src_batch = src + b * plane * c8_channel; | |||
| float16_t *dst_batch = dst + b * plane * channel; | |||
| for (int i = 0; i < plane; i++) { | |||
| float16_t *src_plane = src_batch + i * c8_channel; | |||
| float16_t *dst_plane = dst_batch + i * channel; | |||
| memcpy(dst_plane, src_batch, channel * sizeof(float16_t)); | |||
| } | |||
| } | |||
| } | |||
| @@ -58,6 +58,10 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, | |||
| void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); | |||
| void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); | |||
| void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||