Merge pull request !4575 from fuzhiye/tmptags/v0.7.0-beta
| @@ -247,36 +247,17 @@ int Convolution3x3FP16CPUKernel::Run() { | |||||
| } | } | ||||
| // get real output | // get real output | ||||
| // todo | |||||
| int out_w_block = UP_DIV(conv_param_->output_w_, C4NUM); | |||||
| int out_h_block = UP_DIV(conv_param_->output_h_, C4NUM); | |||||
| int oc8 = UP_DIV(conv_param_->output_channel_, C8NUM); | |||||
| bool relu = conv_param_->is_relu_; | bool relu = conv_param_->is_relu_; | ||||
| bool relu6 = conv_param_->is_relu6_; | bool relu6 = conv_param_->is_relu6_; | ||||
| for (int batch = 0; batch < conv_param_->output_batch_; batch++) { | |||||
| int tmp_out_batch_offset = | |||||
| batch * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_; | |||||
| const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset; | |||||
| float16_t *batch_out = execute_output_ + ro_batch_size; | |||||
| for (int h = 0; h < conv_param_->output_h_; h++) { | |||||
| for (int w = 0; w < conv_param_->output_w_; w++) { | |||||
| for (int c = 0; c < conv_param_->output_channel_; c++) { | |||||
| int oc8_block = c / C8NUM; | |||||
| int oc8_res = c % C8NUM; | |||||
| int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + | |||||
| C8NUM * (h * out_w_block * C4NUM + w) + oc8_res; | |||||
| int dst_offset = (h * conv_param_->output_w_ + w) * conv_param_->output_channel_ + c; | |||||
| (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0]; | |||||
| if (relu) { | |||||
| (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0]; | |||||
| } else if (relu6) { | |||||
| (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0]; | |||||
| (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (relu) { | |||||
| UnPack3x3ReluOutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, | |||||
| conv_param_->output_w_, conv_param_->output_channel_); | |||||
| } else if (relu6) { | |||||
| UnPack3x3Relu6OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, | |||||
| conv_param_->output_w_, conv_param_->output_channel_); | |||||
| } else { | |||||
| UnPack3x3OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, | |||||
| conv_param_->output_w_, conv_param_->output_channel_); | |||||
| } | } | ||||
| ConvolutionBaseFP16CPUKernel::IfCastOutput(); | ConvolutionBaseFP16CPUKernel::IfCastOutput(); | ||||
| @@ -31,8 +31,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| public: | public: | ||||
| ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, | const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, | ||||
| const lite::Primitive *primitive) | |||||
| : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| const lite::Primitive *primitive, int out_unit) | |||||
| : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive), output_unit_(out_unit) {} | |||||
| ~ConvolutionWinogradFP16CPUKernel() override { | ~ConvolutionWinogradFP16CPUKernel() override { | ||||
| if (fp16_weight_ != nullptr) { | if (fp16_weight_ != nullptr) { | ||||
| free(fp16_weight_); | free(fp16_weight_); | ||||
| @@ -42,13 +42,13 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||||
| int ic4 = UP_DIV(in_channel, C4NUM); | int ic4 = UP_DIV(in_channel, C4NUM); | ||||
| int kernel_plane = kernel_h * kernel_w; | int kernel_plane = kernel_h * kernel_w; | ||||
| int oc_block, oc_block_num; | int oc_block, oc_block_num; | ||||
| // #ifdef ENABLE_ARM32 | |||||
| // oc_block = C4NUM; | |||||
| // oc_block_num = UP_DIV(out_channel, C4NUM); | |||||
| // #else | |||||
| // #ifdef ENABLE_ARM32 | |||||
| // oc_block = C4NUM; | |||||
| // oc_block_num = UP_DIV(out_channel, C4NUM); | |||||
| // #else | |||||
| oc_block = C8NUM; | oc_block = C8NUM; | ||||
| oc_block_num = UP_DIV(out_channel, C8NUM); | oc_block_num = UP_DIV(out_channel, C8NUM); | ||||
| // #endif | |||||
| // #endif | |||||
| int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane; | int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane; | ||||
| // init weight | // init weight | ||||
| @@ -123,18 +123,11 @@ void ConvolutionCPUKernel::ConfigInputOutput() { | |||||
| auto output_tensor = out_tensors_.at(kOutputIndex); | auto output_tensor = out_tensors_.at(kOutputIndex); | ||||
| output_tensor->SetFormat(schema::Format_NHWC); | output_tensor->SetFormat(schema::Format_NHWC); | ||||
| // select trans func for input | |||||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||||
| auto ret = CheckLayout(input_tensor); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Check layout failed."; | |||||
| return; | |||||
| } | |||||
| // #ifdef ENABLE_ARM32 | |||||
| // gemm_func_ = IndirectGemmFp32_8x4; | |||||
| // #else | |||||
| // #ifdef ENABLE_ARM32 | |||||
| // gemm_func_ = IndirectGemmFp32_8x4; | |||||
| // #else | |||||
| gemm_func_ = IndirectGemmFp32_8x8; | gemm_func_ = IndirectGemmFp32_8x8; | ||||
| // #endif | |||||
| // #endif | |||||
| } | } | ||||
| int ConvolutionCPUKernel::Init() { | int ConvolutionCPUKernel::Init() { | ||||
| @@ -221,7 +214,7 @@ int ConvolutionCPUKernel::Run() { | |||||
| int in_h = conv_param_->input_h_; | int in_h = conv_param_->input_h_; | ||||
| int in_w = conv_param_->input_w_; | int in_w = conv_param_->input_w_; | ||||
| int in_channel = conv_param_->input_channel_; | int in_channel = conv_param_->input_channel_; | ||||
| convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| int error_code = LiteBackendParallelLaunch(ConvolutionImpl, this, thread_count_); | int error_code = LiteBackendParallelLaunch(ConvolutionImpl, this, thread_count_); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| @@ -54,13 +54,13 @@ int Convolution3x3CPUKernel::InitWeightBias() { | |||||
| int iC4 = UP_DIV(input_channel, C4NUM); | int iC4 = UP_DIV(input_channel, C4NUM); | ||||
| int oC4 = UP_DIV(output_channel, C4NUM); | int oC4 = UP_DIV(output_channel, C4NUM); | ||||
| int oc_block, oc_block_num; | int oc_block, oc_block_num; | ||||
| // #ifdef ENABLE_ARM32 | |||||
| // oc_block = C4NUM; | |||||
| // oc_block_num = UP_DIV(output_channel, C4NUM); | |||||
| // #else | |||||
| // #ifdef ENABLE_ARM32 | |||||
| // oc_block = C4NUM; | |||||
| // oc_block_num = UP_DIV(output_channel, C4NUM); | |||||
| // #else | |||||
| oc_block = C8NUM; | oc_block = C8NUM; | ||||
| oc_block_num = UP_DIV(output_channel, C8NUM); | oc_block_num = UP_DIV(output_channel, C8NUM); | ||||
| // #endif | |||||
| // #endif | |||||
| const int k_plane = 16; | const int k_plane = 16; | ||||
| // init weight | // init weight | ||||
| size_t transformed_size = iC4 * C4NUM * oc_block_num * oc_block * k_plane * sizeof(float); | size_t transformed_size = iC4 * C4NUM * oc_block_num * oc_block * k_plane * sizeof(float); | ||||
| @@ -151,18 +151,11 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| void Convolution3x3CPUKernel::ConfigInputOutput() { | void Convolution3x3CPUKernel::ConfigInputOutput() { | ||||
| auto output_tensor = out_tensors_.at(kOutputIndex); | auto output_tensor = out_tensors_.at(kOutputIndex); | ||||
| output_tensor->SetFormat(schema::Format_NHWC); | output_tensor->SetFormat(schema::Format_NHWC); | ||||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||||
| auto ret = CheckLayout(input_tensor); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Check layout failed."; | |||||
| return; | |||||
| } | |||||
| // #ifdef ENABLE_ARM32 | |||||
| // gemm_func_ = IndirectGemmFp32_8x4; | |||||
| // #else | |||||
| // #ifdef ENABLE_ARM32 | |||||
| // gemm_func_ = IndirectGemmFp32_8x4; | |||||
| // #else | |||||
| gemm_func_ = IndirectGemmFp32_8x8; | gemm_func_ = IndirectGemmFp32_8x8; | ||||
| // #endif | |||||
| // #endif | |||||
| } | } | ||||
| int Convolution3x3CPUKernel::Init() { | int Convolution3x3CPUKernel::Init() { | ||||
| @@ -252,7 +245,7 @@ int Convolution3x3CPUKernel::Run() { | |||||
| int in_h = conv_param_->input_h_; | int in_h = conv_param_->input_h_; | ||||
| int in_w = conv_param_->input_w_; | int in_w = conv_param_->input_w_; | ||||
| int in_channel = conv_param_->input_channel_; | int in_channel = conv_param_->input_channel_; | ||||
| convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| int error_code = LiteBackendParallelLaunch(Convolution3x3Impl, this, thread_count_); | int error_code = LiteBackendParallelLaunch(Convolution3x3Impl, this, thread_count_); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| @@ -104,14 +104,6 @@ void ConvolutionSWCPUKernel::ConfigInputOutput() { | |||||
| // set output format | // set output format | ||||
| auto output_tensor = out_tensors_.at(kOutputIndex); | auto output_tensor = out_tensors_.at(kOutputIndex); | ||||
| output_tensor->SetFormat(schema::Format_NHWC); | output_tensor->SetFormat(schema::Format_NHWC); | ||||
| // select trans func for input | |||||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||||
| auto ret = CheckLayout(input_tensor); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Check layout failed."; | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| int ConvolutionSWCPUKernel::Init() { | int ConvolutionSWCPUKernel::Init() { | ||||
| @@ -199,7 +191,7 @@ int ConvolutionSWCPUKernel::Run() { | |||||
| int in_h = conv_param_->input_h_; | int in_h = conv_param_->input_h_; | ||||
| int in_w = conv_param_->input_w_; | int in_w = conv_param_->input_w_; | ||||
| int in_channel = conv_param_->input_channel_; | int in_channel = conv_param_->input_channel_; | ||||
| convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| int error_code = LiteBackendParallelLaunch(ConvolutionSWImpl, this, thread_count_); | int error_code = LiteBackendParallelLaunch(ConvolutionSWImpl, this, thread_count_); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| @@ -222,12 +222,6 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| int ConvolutionWinogradCPUKernel::ConfigInputOutput() { | int ConvolutionWinogradCPUKernel::ConfigInputOutput() { | ||||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||||
| auto ret = CheckLayout(input_tensor); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Check layout failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto output_tensor = out_tensors_.at(kOutputIndex); | auto output_tensor = out_tensors_.at(kOutputIndex); | ||||
| output_tensor->SetFormat(schema::Format_NHWC); | output_tensor->SetFormat(schema::Format_NHWC); | ||||
| @@ -357,7 +351,7 @@ int ConvolutionWinogradCPUKernel::Run() { | |||||
| int in_h = conv_param_->input_h_; | int in_h = conv_param_->input_h_; | ||||
| int in_w = conv_param_->input_w_; | int in_w = conv_param_->input_w_; | ||||
| int in_channel = conv_param_->input_channel_; | int in_channel = conv_param_->input_channel_; | ||||
| convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||||
| int error_code = LiteBackendParallelLaunch(ConvolutionWinogradImpl, this, thread_count_); | int error_code = LiteBackendParallelLaunch(ConvolutionWinogradImpl, this, thread_count_); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| @@ -35,14 +35,14 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||||
| size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, | size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, | ||||
| size_t relu6) { | size_t relu6) { | ||||
| if (!(mode && writeC8)) { | if (!(mode && writeC8)) { | ||||
| IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, output, offset, relu, relu6); | |||||
| IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, out_channel, offset, relu, relu6); | |||||
| } else { | } else { | ||||
| IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, output, offset, mode, writeC8, relu, relu6); | |||||
| IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, out_channel, offset, mode, writeC8, relu, relu6); | |||||
| } | } | ||||
| } | } | ||||
| void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | ||||
| size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6) { | |||||
| size_t ic4, size_t out_channel, size_t offset, size_t relu, size_t relu6) { | |||||
| const int tile_n = 16; | const int tile_n = 16; | ||||
| for (int i = 0; i < out_channel; i++) { | for (int i = 0; i < out_channel; i++) { | ||||
| int oc8_block = i / C8NUM; | int oc8_block = i / C8NUM; | ||||
| @@ -74,7 +74,7 @@ void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t | |||||
| if (relu) { | if (relu) { | ||||
| tmp[0] = tmp[0] < 0 ? 0 : tmp[0]; | tmp[0] = tmp[0] < 0 ? 0 : tmp[0]; | ||||
| } else if (relu6) { | } else if (relu6) { | ||||
| mp[0] = tmp[0] < 0 ? 0 : tmp[0]; | |||||
| tmp[0] = tmp[0] < 0 ? 0 : tmp[0]; | |||||
| tmp[0] = tmp[0] > 6 ? 6 : tmp[0]; | tmp[0] = tmp[0] > 6 ? 6 : tmp[0]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -415,6 +415,124 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||||
| } | } | ||||
| } | } | ||||
| void UnPack3x3OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) { | |||||
| int out_w_block = UP_DIV(width, C4NUM); | |||||
| int out_h_block = UP_DIV(height, C4NUM); | |||||
| int oc8 = UP_DIV(channel, C8NUM); | |||||
| for (int b = 0; b < batch; b++) { | |||||
| int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int ro_batch_size = b * channel * height * width; | |||||
| const float16_t *batch_tmp_out = src + tmp_out_batch_offset; | |||||
| float16_t *batch_out = dst + ro_batch_size; | |||||
| for (int h = 0; h < height; h++) { | |||||
| int src_h_offset = h * out_w_block * C4NUM * C8NUM; | |||||
| int dst_h_offset = h * width * channel; | |||||
| for (int w = 0; w < width; w++) { | |||||
| int src_w_offset = src_h_offset + w * C8NUM; | |||||
| int dst_w_offset = dst_h_offset + w * channel; | |||||
| for (int c = 0; c < oc8 - 1; ++c) { | |||||
| int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset; | |||||
| int dst_offset = dst_w_offset + c * C8NUM; | |||||
| vst1q_f16(batch_out + dst_offset, vld1q_f16(batch_tmp_out + src_offset)); | |||||
| } | |||||
| int c_res = channel - (oc8 - 1) * C8NUM; | |||||
| int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM; | |||||
| for (int c = 0; c < c_res; c++) { | |||||
| int src_offset = src_c_res_offset + c; | |||||
| int dst_offset = dst_c_res_offset + c; | |||||
| (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void UnPack3x3ReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) { | |||||
| int out_w_block = UP_DIV(width, C4NUM); | |||||
| int out_h_block = UP_DIV(height, C4NUM); | |||||
| int oc8 = UP_DIV(channel, C8NUM); | |||||
| for (int b = 0; b < batch; b++) { | |||||
| int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int ro_batch_size = b * channel * height * width; | |||||
| const float16_t *batch_tmp_out = src + tmp_out_batch_offset; | |||||
| float16_t *batch_out = dst + ro_batch_size; | |||||
| for (int h = 0; h < height; h++) { | |||||
| int src_h_offset = h * out_w_block * C4NUM * C8NUM; | |||||
| int dst_h_offset = h * width * channel; | |||||
| for (int w = 0; w < width; w++) { | |||||
| int src_w_offset = src_h_offset + w * C8NUM; | |||||
| int dst_w_offset = dst_h_offset + w * channel; | |||||
| for (int c = 0; c < oc8 - 1; ++c) { | |||||
| int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset; | |||||
| int dst_offset = dst_w_offset + c * C8NUM; | |||||
| float16x8_t input_ptr = vld1q_f16(batch_tmp_out + src_offset); | |||||
| float16x8_t zero = vdupq_n_f16(0); | |||||
| input_ptr = vmaxq_f16(zero, input_ptr); | |||||
| vst1q_f16(batch_out + dst_offset, input_ptr); | |||||
| } | |||||
| int c_res = channel - (oc8 - 1) * C8NUM; | |||||
| int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM; | |||||
| for (int c = 0; c < c_res; c++) { | |||||
| int src_offset = src_c_res_offset + c; | |||||
| int dst_offset = dst_c_res_offset + c; | |||||
| float16_t input_data = (batch_tmp_out + src_offset)[0]; | |||||
| input_data = input_data < 0 ? 0 : input_data; | |||||
| (batch_out + dst_offset)[0] = input_data; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) { | |||||
| int out_w_block = UP_DIV(width, C4NUM); | |||||
| int out_h_block = UP_DIV(height, C4NUM); | |||||
| int oc8 = UP_DIV(channel, C8NUM); | |||||
| for (int b = 0; b < batch; b++) { | |||||
| int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int ro_batch_size = b * channel * height * width; | |||||
| const float16_t *batch_tmp_out = src + tmp_out_batch_offset; | |||||
| float16_t *batch_out = dst + ro_batch_size; | |||||
| for (int h = 0; h < height; h++) { | |||||
| int src_h_offset = h * out_w_block * C4NUM * C8NUM; | |||||
| int dst_h_offset = h * width * channel; | |||||
| for (int w = 0; w < width; w++) { | |||||
| int src_w_offset = src_h_offset + w * C8NUM; | |||||
| int dst_w_offset = dst_h_offset + w * channel; | |||||
| for (int c = 0; c < oc8 - 1; ++c) { | |||||
| int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset; | |||||
| int dst_offset = dst_w_offset + c * C8NUM; | |||||
| float16x8_t input_ptr = vld1q_f16(batch_tmp_out + src_offset); | |||||
| float16x8_t zero = vdupq_n_f16(0); | |||||
| float16x8_t six = vdupq_n_f16(6); | |||||
| input_ptr = vmaxq_f16(zero, input_ptr); | |||||
| input_ptr = vminq_f16(six, input_ptr); | |||||
| vst1q_f16(batch_out + dst_offset, input_ptr); | |||||
| } | |||||
| int c_res = channel - (oc8 - 1) * C8NUM; | |||||
| int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; | |||||
| int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM; | |||||
| for (int c = 0; c < c_res; c++) { | |||||
| int src_offset = src_c_res_offset + c; | |||||
| int dst_offset = dst_c_res_offset + c; | |||||
| float16_t input_data = (batch_tmp_out + src_offset)[0]; | |||||
| input_data = input_data < 0 ? 0 : input_data; | |||||
| input_data = input_data > 6 ? 6 : input_data; | |||||
| (batch_out + dst_offset)[0] = input_data; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // fp16 convolution winograd | // fp16 convolution winograd | ||||
| void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | ||||
| TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | ||||
| @@ -60,6 +60,12 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||||
| float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, | float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, | ||||
| int task_id, ConvParameter *conv_param); | int task_id, ConvParameter *conv_param); | ||||
| void UnPack3x3OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel); | |||||
| void UnPack3x3ReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel); | |||||
| void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel); | |||||
| // fp16 convolution winograd | // fp16 convolution winograd | ||||
| void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | ||||
| TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | ||||