| @@ -37,15 +37,11 @@ int PoolingBaseCPUKernel::SetQuantParam() { | |||||
| auto in_quant_arg = input_tensor->GetQuantParams(); | auto in_quant_arg = input_tensor->GetQuantParams(); | ||||
| auto *out_tensor = outputs_.at(kOutputIndex); | auto *out_tensor = outputs_.at(kOutputIndex); | ||||
| auto out_quant_arg = out_tensor->GetQuantParams(); | auto out_quant_arg = out_tensor->GetQuantParams(); | ||||
| if (in_quant_arg.front().scale != out_quant_arg.front().scale || | |||||
| in_quant_arg.front().zeroPoint != out_quant_arg.front().zeroPoint) { | |||||
| MS_LOG(ERROR) << "Scale/ZeroPoint of output must be equal to input's"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| pooling_quant_arg_[0][0].scale_ = in_quant_arg.front().scale; | pooling_quant_arg_[0][0].scale_ = in_quant_arg.front().scale; | ||||
| pooling_quant_arg_[0][0].zp_ = in_quant_arg.front().zeroPoint; | pooling_quant_arg_[0][0].zp_ = in_quant_arg.front().zeroPoint; | ||||
| pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale; | pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale; | ||||
| pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint; | pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint; | ||||
| pooling_param_->quant_args_ = pooling_quant_arg_; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -56,7 +56,7 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() { | |||||
| int iC4 = UP_DIV(input_channel, C4NUM); | int iC4 = UP_DIV(input_channel, C4NUM); | ||||
| int oC8 = UP_DIV(output_channel, C8NUM); | int oC8 = UP_DIV(output_channel, C8NUM); | ||||
| // init weight | // init weight | ||||
| size_t transformed_size = iC4 * C8NUM * oC8 * C8NUM * 36 * sizeof(float16_t); | |||||
| size_t transformed_size = iC4 * C4NUM * oC8 * C8NUM * 36 * sizeof(float16_t); | |||||
| transformed_filter_addr_ = reinterpret_cast<float16_t *>(malloc(transformed_size)); | transformed_filter_addr_ = reinterpret_cast<float16_t *>(malloc(transformed_size)); | ||||
| if (transformed_filter_addr_ == nullptr) { | if (transformed_filter_addr_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; | MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; | ||||
| @@ -101,6 +101,8 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| int k_plane = 36; | int k_plane = 36; | ||||
| int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); | int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); | ||||
| int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM); | int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM); | ||||
| /*=============================tile_buffer_============================*/ | |||||
| size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC4 * C4NUM * sizeof(float16_t); | size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC4 * C4NUM * sizeof(float16_t); | ||||
| tile_buffer_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size)); | tile_buffer_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size)); | ||||
| if (tile_buffer_ == nullptr) { | if (tile_buffer_ == nullptr) { | ||||
| @@ -109,6 +111,7 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tile_buffer_, 0, tile_buffer_size); | memset(tile_buffer_, 0, tile_buffer_size); | ||||
| /*=============================block_unit_buffer_============================*/ | |||||
| size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float16_t); | size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float16_t); | ||||
| block_unit_buffer_ = reinterpret_cast<float16_t *>(malloc(block_unit_buffer_size)); | block_unit_buffer_ = reinterpret_cast<float16_t *>(malloc(block_unit_buffer_size)); | ||||
| if (block_unit_buffer_ == nullptr) { | if (block_unit_buffer_ == nullptr) { | ||||
| @@ -117,6 +120,7 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(block_unit_buffer_, 0, block_unit_buffer_size); | memset(block_unit_buffer_, 0, block_unit_buffer_size); | ||||
| /*=============================tmp_dst_buffer_============================*/ | |||||
| size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float16_t); | size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float16_t); | ||||
| tmp_dst_buffer_ = reinterpret_cast<float16_t *>(malloc(tmp_dst_buffer_size)); | tmp_dst_buffer_ = reinterpret_cast<float16_t *>(malloc(tmp_dst_buffer_size)); | ||||
| if (tmp_dst_buffer_ == nullptr) { | if (tmp_dst_buffer_ == nullptr) { | ||||
| @@ -125,6 +129,7 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); | memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); | ||||
| /*=============================tmp_out_============================*/ | |||||
| size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * | size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * | ||||
| tile_num * sizeof(float16_t); | tile_num * sizeof(float16_t); | ||||
| tmp_out_ = reinterpret_cast<float16_t *>(malloc(tmp_out_size)); | tmp_out_ = reinterpret_cast<float16_t *>(malloc(tmp_out_size)); | ||||
| @@ -134,6 +139,7 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tmp_out_, 0, tmp_out_size); | memset(tmp_out_, 0, tmp_out_size); | ||||
| /*=============================fp16_input_============================*/ | |||||
| size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ * | size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ * | ||||
| conv_param_->input_w_ * sizeof(float16_t); | conv_param_->input_w_ * sizeof(float16_t); | ||||
| fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size)); | fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size)); | ||||
| @@ -143,7 +149,7 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(fp16_input_, 0, fp16_input_size); | memset(fp16_input_, 0, fp16_input_size); | ||||
| // init nhwc4 input | |||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = | size_t nhwc4_input_size = | ||||
| iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); | iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| @@ -152,12 +158,19 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| /*=============================fp16_out_============================*/ | |||||
| size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ * | |||||
| conv_param_->output_w_ * sizeof(float16_t); | |||||
| fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size)); | |||||
| if (fp16_out_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc fp16_out_ failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void Convolution3x3FP16CPUKernel::ConfigInputOutput() { | void Convolution3x3FP16CPUKernel::ConfigInputOutput() { | ||||
| auto output_tensor = outputs_.at(kOutputIndex); | |||||
| output_tensor->SetFormat(schema::Format_NHWC); | |||||
| auto input_tensor = inputs_.at(kInputIndex); | auto input_tensor = inputs_.at(kInputIndex); | ||||
| auto input_format = input_tensor->GetFormat(); | auto input_format = input_tensor->GetFormat(); | ||||
| schema::Format execute_format = schema::Format_NHWC4; | schema::Format execute_format = schema::Format_NHWC4; | ||||
| @@ -201,6 +214,15 @@ int Convolution3x3FP16CPUKernel::ReSize() { | |||||
| if (tmp_out_ != nullptr) { | if (tmp_out_ != nullptr) { | ||||
| free(tmp_out_); | free(tmp_out_); | ||||
| } | } | ||||
| if (fp16_out_ != nullptr) { | |||||
| free(fp16_out_); | |||||
| } | |||||
| if (fp16_input_ != nullptr) { | |||||
| free(fp16_input_); | |||||
| } | |||||
| if (nhwc4_input_ != nullptr) { | |||||
| free(nhwc4_input_); | |||||
| } | |||||
| auto ret = ConvolutionBaseCPUKernel::Init(); | auto ret = ConvolutionBaseCPUKernel::Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -216,9 +238,8 @@ int Convolution3x3FP16CPUKernel::ReSize() { | |||||
| } | } | ||||
| int Convolution3x3FP16CPUKernel::RunImpl(int task_id) { | int Convolution3x3FP16CPUKernel::RunImpl(int task_id) { | ||||
| auto output_addr = reinterpret_cast<float16_t *>(outputs_.at(kOutputIndex)->Data()); | |||||
| Conv3x3Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_, | Conv3x3Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_, | ||||
| reinterpret_cast<float16_t *>(bias_data_), output_addr, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, | |||||
| reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, | |||||
| tmp_out_, task_id, conv_param_); | tmp_out_, task_id, conv_param_); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -234,12 +255,13 @@ int Convolution3x3Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) | |||||
| } | } | ||||
| int Convolution3x3FP16CPUKernel::Run() { | int Convolution3x3FP16CPUKernel::Run() { | ||||
| // cast fp32 input data to fp16 | |||||
| auto input_tensor = inputs_.at(kInputIndex); | auto input_tensor = inputs_.at(kInputIndex); | ||||
| auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | ||||
| // cast fp32 input data to fp16 | |||||
| for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | ||||
| fp16_input_[i] = (float16_t)ori_input_data[i]; | fp16_input_[i] = (float16_t)ori_input_data[i]; | ||||
| } | } | ||||
| int in_batch = conv_param_->input_batch_; | int in_batch = conv_param_->input_batch_; | ||||
| int in_h = conv_param_->input_h_; | int in_h = conv_param_->input_h_; | ||||
| int in_w = conv_param_->input_w_; | int in_w = conv_param_->input_w_; | ||||
| @@ -251,6 +273,13 @@ int Convolution3x3FP16CPUKernel::Run() { | |||||
| MS_LOG(ERROR) << "conv3x3 fp16 error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "conv3x3 fp16 error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| // cast fp16 out to fp32 data | |||||
| auto out_tensor = outputs_.at(kOutputIndex); | |||||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||||
| for (int j = 0; j < out_tensor->ElementsNum(); ++j) { | |||||
| output_addr[j] = (float)fp16_out_[j]; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -36,6 +36,9 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| if (fp16_weight_ != nullptr) { | if (fp16_weight_ != nullptr) { | ||||
| free(fp16_weight_); | free(fp16_weight_); | ||||
| } | } | ||||
| if (fp16_out_ != nullptr) { | |||||
| free(fp16_out_); | |||||
| } | |||||
| if (transformed_filter_addr_ != nullptr) { | if (transformed_filter_addr_ != nullptr) { | ||||
| free(transformed_filter_addr_); | free(transformed_filter_addr_); | ||||
| } | } | ||||
| @@ -64,6 +67,7 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| private: | private: | ||||
| float16_t *fp16_input_; | float16_t *fp16_input_; | ||||
| float16_t *fp16_weight_; | float16_t *fp16_weight_; | ||||
| float16_t *fp16_out_; | |||||
| float16_t *transformed_filter_addr_; | float16_t *transformed_filter_addr_; | ||||
| float16_t *tile_buffer_; | float16_t *tile_buffer_; | ||||
| float16_t *block_unit_buffer_; | float16_t *block_unit_buffer_; | ||||
| @@ -37,9 +37,9 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||||
| int in_channel = conv_param_->input_channel_; | int in_channel = conv_param_->input_channel_; | ||||
| int out_channel = conv_param_->output_channel_; | int out_channel = conv_param_->output_channel_; | ||||
| int oc8 = UP_DIV(out_channel, C8NUM); | int oc8 = UP_DIV(out_channel, C8NUM); | ||||
| int channel_block = UP_DIV(in_channel, C4NUM); | |||||
| int ic4 = UP_DIV(in_channel, C4NUM); | |||||
| int kernel_plane = kernel_h * kernel_w; | int kernel_plane = kernel_h * kernel_w; | ||||
| int pack_weight_size = oc8 * channel_block * C8NUM * C4NUM * kernel_plane; | |||||
| int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; | |||||
| // init weight | // init weight | ||||
| float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data()); | float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data()); | ||||
| @@ -49,10 +49,10 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||||
| MS_LOG(ERROR) << "malloc fp16_weight_ failed."; | MS_LOG(ERROR) << "malloc fp16_weight_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(fp16_weight_, 0, fp16_weight_size); | |||||
| for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { | for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { | ||||
| fp16_weight_[i] = (float16_t)origin_weight[i]; | fp16_weight_[i] = (float16_t)origin_weight[i]; | ||||
| } | } | ||||
| packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t))); | packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t))); | ||||
| if (packed_weight_ == nullptr) { | if (packed_weight_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc packed_weight_ failed."; | MS_LOG(ERROR) << "malloc packed_weight_ failed."; | ||||
| @@ -95,6 +95,8 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||||
| int output_tile_count = UP_DIV(output_count, cal_num); | int output_tile_count = UP_DIV(output_count, cal_num); | ||||
| int unit_size = kernel_plane * channel_block * C4NUM; | int unit_size = kernel_plane * channel_block * C4NUM; | ||||
| int packed_input_size = output_tile_count * cal_num * unit_size; | int packed_input_size = output_tile_count * cal_num * unit_size; | ||||
| /*=============================packed_input_============================*/ | |||||
| packed_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * packed_input_size * sizeof(float16_t))); | packed_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * packed_input_size * sizeof(float16_t))); | ||||
| if (packed_input_ == nullptr) { | if (packed_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc packed_input_ failed."; | MS_LOG(ERROR) << "malloc packed_input_ failed."; | ||||
| @@ -102,6 +104,7 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float16_t)); | memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float16_t)); | ||||
| /*=============================fp16_input_============================*/ | |||||
| size_t fp16_input_size = | size_t fp16_input_size = | ||||
| in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); | in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); | ||||
| fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size)); | fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size)); | ||||
| @@ -109,8 +112,8 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||||
| MS_LOG(ERROR) << "malloc fp16_input_ failed."; | MS_LOG(ERROR) << "malloc fp16_input_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(fp16_input_, 0, fp16_input_size); | |||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * | size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * | ||||
| conv_param_->input_w_ * sizeof(float16_t); | conv_param_->input_w_ * sizeof(float16_t); | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| @@ -120,11 +123,21 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| /*=============================tmp_output_block_============================*/ | |||||
| tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(cal_num * out_channel * sizeof(float16_t))); | tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(cal_num * out_channel * sizeof(float16_t))); | ||||
| if (tmp_output_block_ == nullptr) { | if (tmp_output_block_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; | MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| /*=============================fp16_out_============================*/ | |||||
| size_t fp16_output_size = | |||||
| out_channel * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float16_t); | |||||
| fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size)); | |||||
| if (fp16_out_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc fp16_out_ failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -137,8 +150,6 @@ void ConvolutionFP16CPUKernel::ConfigInputOutput() { | |||||
| MS_LOG(ERROR) << "layout convert func is nullptr."; | MS_LOG(ERROR) << "layout convert func is nullptr."; | ||||
| return; | return; | ||||
| } | } | ||||
| auto output_tensor = outputs_.at(kOutputIndex); | |||||
| output_tensor->SetFormat(schema::Format_NHWC); | |||||
| } | } | ||||
| int ConvolutionFP16CPUKernel::Init() { | int ConvolutionFP16CPUKernel::Init() { | ||||
| @@ -171,6 +182,12 @@ int ConvolutionFP16CPUKernel::ReSize() { | |||||
| if (nhwc4_input_ != nullptr) { | if (nhwc4_input_ != nullptr) { | ||||
| free(nhwc4_input_); | free(nhwc4_input_); | ||||
| } | } | ||||
| if (fp16_input_ != nullptr) { | |||||
| free(fp16_input_); | |||||
| } | |||||
| if (fp16_out_ != nullptr) { | |||||
| free(fp16_out_); | |||||
| } | |||||
| auto ret = ConvolutionBaseCPUKernel::Init(); | auto ret = ConvolutionBaseCPUKernel::Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -186,9 +203,8 @@ int ConvolutionFP16CPUKernel::ReSize() { | |||||
| } | } | ||||
| int ConvolutionFP16CPUKernel::RunImpl(int task_id) { | int ConvolutionFP16CPUKernel::RunImpl(int task_id) { | ||||
| auto output_addr = reinterpret_cast<float16_t *>(outputs_.at(kOutputIndex)->Data()); | |||||
| ConvFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_input_, packed_weight_, | ConvFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_input_, packed_weight_, | ||||
| reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, output_addr, task_id, conv_param_); | |||||
| reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, fp16_out_, task_id, conv_param_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -203,12 +219,13 @@ int ConvolutionFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| } | } | ||||
| int ConvolutionFP16CPUKernel::Run() { | int ConvolutionFP16CPUKernel::Run() { | ||||
| // cast fp32 input data to fp16 | |||||
| auto input_tensor = inputs_.at(kInputIndex); | auto input_tensor = inputs_.at(kInputIndex); | ||||
| auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | ||||
| // cast fp32 input data to fp16 | |||||
| for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | ||||
| fp16_input_[i] = (float16_t)ori_input_data[i]; | fp16_input_[i] = (float16_t)ori_input_data[i]; | ||||
| } | } | ||||
| int in_batch = conv_param_->input_batch_; | int in_batch = conv_param_->input_batch_; | ||||
| int in_h = conv_param_->input_h_; | int in_h = conv_param_->input_h_; | ||||
| int in_w = conv_param_->input_w_; | int in_w = conv_param_->input_w_; | ||||
| @@ -220,6 +237,13 @@ int ConvolutionFP16CPUKernel::Run() { | |||||
| MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| // cast fp16 out to fp32 data | |||||
| auto out_tensor = outputs_.at(kOutputIndex); | |||||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||||
| for (int j = 0; j < out_tensor->ElementsNum(); ++j) { | |||||
| output_addr[j] = (float)fp16_out_[j]; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -23,15 +23,11 @@ | |||||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | #include "src/runtime/kernel/arm/base/convolution_base.h" | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| typedef void (*FP16_GEMM_FUNC)(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||||
| size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, | |||||
| size_t relu6); | |||||
| class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { | class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { | ||||
| public: | public: | ||||
| ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx) | const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx) | ||||
| : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} | |||||
| : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~ConvolutionFP16CPUKernel() override { | ~ConvolutionFP16CPUKernel() override { | ||||
| if (fp16_input_ != nullptr) { | if (fp16_input_ != nullptr) { | ||||
| free(fp16_input_); | free(fp16_input_); | ||||
| @@ -39,6 +35,9 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| if (fp16_weight_ != nullptr) { | if (fp16_weight_ != nullptr) { | ||||
| free(fp16_weight_); | free(fp16_weight_); | ||||
| } | } | ||||
| if (fp16_out_ != nullptr) { | |||||
| free(fp16_out_); | |||||
| } | |||||
| if (packed_input_ != nullptr) { | if (packed_input_ != nullptr) { | ||||
| free(packed_input_); | free(packed_input_); | ||||
| } | } | ||||
| @@ -59,15 +58,13 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| void ConfigInputOutput(); | void ConfigInputOutput(); | ||||
| private: | private: | ||||
| bool support_fp16_ = true; | |||||
| float16_t *fp16_input_; | float16_t *fp16_input_; | ||||
| float16_t *fp16_weight_; | float16_t *fp16_weight_; | ||||
| float16_t *fp16_out_; | |||||
| float16_t *packed_input_; | float16_t *packed_input_; | ||||
| float16_t *packed_weight_; | float16_t *packed_weight_; | ||||
| float16_t *tmp_output_block_; | float16_t *tmp_output_block_; | ||||
| FP16_GEMM_FUNC gemm_func_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ | ||||
| @@ -89,6 +89,7 @@ int ConvolutionCPUKernel::InitTmpBuffer() { | |||||
| int output_tile_count = UP_DIV(output_count, TILE_NUM); | int output_tile_count = UP_DIV(output_count, TILE_NUM); | ||||
| int unit_size = kernel_plane * ic4 * C4NUM; | int unit_size = kernel_plane * ic4 * C4NUM; | ||||
| int packed_input_size = output_tile_count * TILE_NUM * unit_size; | int packed_input_size = output_tile_count * TILE_NUM * unit_size; | ||||
| /*=============================packed_input============================*/ | |||||
| packed_input_ = reinterpret_cast<float *>(malloc(in_batch * packed_input_size * sizeof(float))); | packed_input_ = reinterpret_cast<float *>(malloc(in_batch * packed_input_size * sizeof(float))); | ||||
| if (packed_input_ == nullptr) { | if (packed_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc packed input failed."; | MS_LOG(ERROR) << "malloc packed input failed."; | ||||
| @@ -96,6 +97,7 @@ int ConvolutionCPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float)); | memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float)); | ||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = | size_t nhwc4_input_size = | ||||
| ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); | ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| @@ -105,7 +107,7 @@ int ConvolutionCPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| // tmp out | |||||
| /*=============================tmp_output_block_============================*/ | |||||
| tmp_output_block_ = reinterpret_cast<float *>(malloc(TILE_NUM * out_channel * sizeof(float))); | tmp_output_block_ = reinterpret_cast<float *>(malloc(TILE_NUM * out_channel * sizeof(float))); | ||||
| if (tmp_output_block_ == nullptr) { | if (tmp_output_block_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp output block failed."; | MS_LOG(ERROR) << "malloc tmp output block failed."; | ||||
| @@ -94,6 +94,8 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); | int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); | ||||
| int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM); | int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM); | ||||
| int k_plane = 16; | int k_plane = 16; | ||||
| /*=============================tile_buffer_============================*/ | |||||
| size_t tile_buffer_size = thread_count_ * TILE_NUM * k_plane * iC4 * C4NUM * sizeof(float); | size_t tile_buffer_size = thread_count_ * TILE_NUM * k_plane * iC4 * C4NUM * sizeof(float); | ||||
| tile_buffer_ = reinterpret_cast<float *>(malloc(tile_buffer_size)); | tile_buffer_ = reinterpret_cast<float *>(malloc(tile_buffer_size)); | ||||
| if (tile_buffer_ == nullptr) { | if (tile_buffer_ == nullptr) { | ||||
| @@ -102,6 +104,7 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tile_buffer_, 0, tile_buffer_size); | memset(tile_buffer_, 0, tile_buffer_size); | ||||
| /*=============================block_unit_buffer_============================*/ | |||||
| size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float); | size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float); | ||||
| block_unit_buffer_ = reinterpret_cast<float *>(malloc(block_unit_buffer_size)); | block_unit_buffer_ = reinterpret_cast<float *>(malloc(block_unit_buffer_size)); | ||||
| if (block_unit_buffer_ == nullptr) { | if (block_unit_buffer_ == nullptr) { | ||||
| @@ -110,6 +113,7 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(block_unit_buffer_, 0, block_unit_buffer_size); | memset(block_unit_buffer_, 0, block_unit_buffer_size); | ||||
| /*=============================tmp_dst_buffer_============================*/ | |||||
| size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * k_plane * oC4 * C4NUM * sizeof(float); | size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * k_plane * oC4 * C4NUM * sizeof(float); | ||||
| tmp_dst_buffer_ = reinterpret_cast<float *>(malloc(tmp_dst_buffer_size)); | tmp_dst_buffer_ = reinterpret_cast<float *>(malloc(tmp_dst_buffer_size)); | ||||
| if (tmp_dst_buffer_ == nullptr) { | if (tmp_dst_buffer_ == nullptr) { | ||||
| @@ -118,6 +122,7 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); | memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); | ||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = | size_t nhwc4_input_size = | ||||
| iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); | iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| @@ -127,6 +132,7 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| /*=============================nc4hw4_out_============================*/ | |||||
| size_t nc4hw4_out_size = | size_t nc4hw4_out_size = | ||||
| oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float); | oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float); | ||||
| nc4hw4_out_ = reinterpret_cast<float *>(malloc(nc4hw4_out_size)); | nc4hw4_out_ = reinterpret_cast<float *>(malloc(nc4hw4_out_size)); | ||||
| @@ -165,6 +165,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| int ic4 = UP_DIV(channel_in, C4NUM); | int ic4 = UP_DIV(channel_in, C4NUM); | ||||
| int oc4 = UP_DIV(channel_out, C4NUM); | int oc4 = UP_DIV(channel_out, C4NUM); | ||||
| /*=============================trans_input_============================*/ | |||||
| size_t tile_buffer_size = thread_count_ * TILE_NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); | size_t tile_buffer_size = thread_count_ * TILE_NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); | ||||
| trans_input_ = reinterpret_cast<float *>(malloc(tile_buffer_size)); | trans_input_ = reinterpret_cast<float *>(malloc(tile_buffer_size)); | ||||
| if (trans_input_ == nullptr) { | if (trans_input_ == nullptr) { | ||||
| @@ -173,6 +174,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(trans_input_, 0, tile_buffer_size); | memset(trans_input_, 0, tile_buffer_size); | ||||
| /*=============================gemm_out_============================*/ | |||||
| gemm_out_ = reinterpret_cast<float *>( | gemm_out_ = reinterpret_cast<float *>( | ||||
| malloc(thread_count_ * TILE_NUM * input_unit_ * input_unit_ * oc4 * C4NUM * sizeof(float))); | malloc(thread_count_ * TILE_NUM * input_unit_ * input_unit_ * oc4 * C4NUM * sizeof(float))); | ||||
| if (gemm_out_ == nullptr) { | if (gemm_out_ == nullptr) { | ||||
| @@ -180,6 +182,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| /*=============================tmp_out_data_============================*/ | |||||
| int out_w_block = UP_DIV(output_w, output_unit_); | int out_w_block = UP_DIV(output_w, output_unit_); | ||||
| int out_h_block = UP_DIV(output_h, output_unit_); | int out_h_block = UP_DIV(output_h, output_unit_); | ||||
| tmp_out_data_ = reinterpret_cast<float *>( | tmp_out_data_ = reinterpret_cast<float *>( | ||||
| @@ -189,7 +192,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| tmp_data_ = reinterpret_cast<float *>(malloc(C4NUM * input_unit_ * input_unit_ * sizeof(float))); | |||||
| /*=============================tmp_data_============================*/ | |||||
| tmp_data_ = reinterpret_cast<float *>(malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float))); | |||||
| if (tmp_data_ == nullptr) { | if (tmp_data_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp_data_ failed."; | MS_LOG(ERROR) << "malloc tmp_data_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -201,6 +205,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| tmp_buffer_address_list_[2] = tmp_out_data_; | tmp_buffer_address_list_[2] = tmp_out_data_; | ||||
| tmp_buffer_address_list_[3] = tmp_data_; | tmp_buffer_address_list_[3] = tmp_data_; | ||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = | size_t nhwc4_input_size = | ||||
| ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); | ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| @@ -108,6 +108,7 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() { | |||||
| int output_w = conv_param_->output_w_; | int output_w = conv_param_->output_w_; | ||||
| int output_h = conv_param_->output_h_; | int output_h = conv_param_->output_h_; | ||||
| /*=============================tile_buffer_============================*/ | |||||
| size_t tile_buffer_size = thread_count_ * TILE_NUM * 16 * ic8 * C8NUM * sizeof(int16_t); | size_t tile_buffer_size = thread_count_ * TILE_NUM * 16 * ic8 * C8NUM * sizeof(int16_t); | ||||
| tile_buffer_ = reinterpret_cast<int16_t *>(malloc(tile_buffer_size)); | tile_buffer_ = reinterpret_cast<int16_t *>(malloc(tile_buffer_size)); | ||||
| if (tile_buffer_ == nullptr) { | if (tile_buffer_ == nullptr) { | ||||
| @@ -116,6 +117,7 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tile_buffer_, 0, tile_buffer_size); | memset(tile_buffer_, 0, tile_buffer_size); | ||||
| /*=============================block_unit_buffer_============================*/ | |||||
| size_t block_unit_buffer_size = thread_count_ * 4 * 4 * C8NUM * sizeof(int16_t); | size_t block_unit_buffer_size = thread_count_ * 4 * 4 * C8NUM * sizeof(int16_t); | ||||
| block_unit_buffer_ = reinterpret_cast<int16_t *>(malloc(block_unit_buffer_size)); | block_unit_buffer_ = reinterpret_cast<int16_t *>(malloc(block_unit_buffer_size)); | ||||
| if (block_unit_buffer_ == nullptr) { | if (block_unit_buffer_ == nullptr) { | ||||
| @@ -124,6 +126,7 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(block_unit_buffer_, 0, block_unit_buffer_size); | memset(block_unit_buffer_, 0, block_unit_buffer_size); | ||||
| /*=============================tmp_dst_buffer_============================*/ | |||||
| size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * 16 * oc4 * C4NUM * sizeof(int32_t); | size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * 16 * oc4 * C4NUM * sizeof(int32_t); | ||||
| tmp_dst_buffer_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_buffer_size)); | tmp_dst_buffer_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_buffer_size)); | ||||
| if (tmp_dst_buffer_ == nullptr) { | if (tmp_dst_buffer_ == nullptr) { | ||||
| @@ -132,6 +135,7 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); | memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); | ||||
| /*=============================tmp_out_============================*/ | |||||
| size_t tmp_out_size = oc4 * C4NUM * output_batch * output_w * output_h * sizeof(uint8_t); | size_t tmp_out_size = oc4 * C4NUM * output_batch * output_w * output_h * sizeof(uint8_t); | ||||
| tmp_out_ = reinterpret_cast<int8_t *>(malloc(tmp_out_size)); | tmp_out_ = reinterpret_cast<int8_t *>(malloc(tmp_out_size)); | ||||
| if (tmp_out_ == nullptr) { | if (tmp_out_ == nullptr) { | ||||
| @@ -140,6 +144,7 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tmp_out_, 0, tmp_out_size); | memset(tmp_out_, 0, tmp_out_size); | ||||
| /*=============================input_data_============================*/ | |||||
| size_t c8_input_size = in_batch * input_h * input_w * ic8 * C8NUM * sizeof(int16_t); | size_t c8_input_size = in_batch * input_h * input_w * ic8 * C8NUM * sizeof(int16_t); | ||||
| input_data_ = reinterpret_cast<int16_t *>(malloc(c8_input_size)); | input_data_ = reinterpret_cast<int16_t *>(malloc(c8_input_size)); | ||||
| if (input_data_ == nullptr) { | if (input_data_ == nullptr) { | ||||
| @@ -238,4 +243,3 @@ int Convolution3x3Int8CPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -116,6 +116,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { | |||||
| int unit_size = plane_c4 * C4NUM * ic4 * C4NUM; | int unit_size = plane_c4 * C4NUM * ic4 * C4NUM; | ||||
| int packed_input_size = output_tile_count * tile_num_ * unit_size; | int packed_input_size = output_tile_count * tile_num_ * unit_size; | ||||
| /*=============================packed_input_============================*/ | |||||
| packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size)); | packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size)); | ||||
| if (packed_input_ == nullptr) { | if (packed_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc packed_input_ failed."; | MS_LOG(ERROR) << "malloc packed_input_ failed."; | ||||
| @@ -123,6 +124,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); | memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); | ||||
| /*=============================input_sum_============================*/ | |||||
| input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); | input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); | ||||
| if (input_sum_ == nullptr) { | if (input_sum_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc input_sum_ failed."; | MS_LOG(ERROR) << "malloc input_sum_ failed."; | ||||
| @@ -130,6 +132,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); | memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); | ||||
| /*=============================tmp_dst_============================*/ | |||||
| size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); | size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); | ||||
| tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size)); | tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size)); | ||||
| if (tmp_dst_ == nullptr) { | if (tmp_dst_ == nullptr) { | ||||
| @@ -138,12 +141,14 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(tmp_dst_, 0, tmp_dst_size); | memset(tmp_dst_, 0, tmp_dst_size); | ||||
| /*=============================tmp_out_============================*/ | |||||
| tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); | tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); | ||||
| if (tmp_out_ == nullptr) { | if (tmp_out_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp_out_ failed."; | MS_LOG(ERROR) << "malloc tmp_out_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; | size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| if (nhwc4_input_ == nullptr) { | if (nhwc4_input_ == nullptr) { | ||||
| @@ -209,6 +214,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { | |||||
| int unit_size = kernel_plane * ic4 * C4NUM; | int unit_size = kernel_plane * ic4 * C4NUM; | ||||
| int packed_input_size = output_tile_count * tile_num_ * unit_size; | int packed_input_size = output_tile_count * tile_num_ * unit_size; | ||||
| /*=============================packed_input_============================*/ | |||||
| packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size)); | packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size)); | ||||
| if (packed_input_ == nullptr) { | if (packed_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc packed_input_ failed."; | MS_LOG(ERROR) << "malloc packed_input_ failed."; | ||||
| @@ -216,6 +222,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { | |||||
| } | } | ||||
| memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); | memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); | ||||
| /*=============================input_sum_============================*/ | |||||
| input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); | input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); | ||||
| if (input_sum_ == nullptr) { | if (input_sum_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc input_sum_ failed."; | MS_LOG(ERROR) << "malloc input_sum_ failed."; | ||||
| @@ -223,6 +230,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { | |||||
| } | } | ||||
| memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); | memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); | ||||
| /*=============================tmp_dst_============================*/ | |||||
| size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); | size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); | ||||
| tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size)); | tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size)); | ||||
| if (tmp_dst_ == nullptr) { | if (tmp_dst_ == nullptr) { | ||||
| @@ -231,12 +239,14 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { | |||||
| } | } | ||||
| memset(tmp_dst_, 0, tmp_dst_size); | memset(tmp_dst_, 0, tmp_dst_size); | ||||
| /*=============================tmp_out_============================*/ | |||||
| tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); | tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); | ||||
| if (tmp_out_ == nullptr) { | if (tmp_out_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp_out_ failed."; | MS_LOG(ERROR) << "malloc tmp_out_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| /*=============================nhwc4_input_============================*/ | |||||
| size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; | size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; | ||||
| nhwc4_input_ = malloc(nhwc4_input_size); | nhwc4_input_ = malloc(nhwc4_input_size); | ||||
| if (nhwc4_input_ == nullptr) { | if (nhwc4_input_ == nullptr) { | ||||
| @@ -54,7 +54,13 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||||
| } | } | ||||
| } | } | ||||
| (output + out_tile_offset)[0] = tmp_out; | |||||
| (output + out_tile_offset)[0] = tmp_out + bias[i]; | |||||
| if (relu) { | |||||
| (output + out_tile_offset)[0] = (output + out_tile_offset)[0] < 0 ? 0 : (output + out_tile_offset)[0]; | |||||
| } else if (relu6) { | |||||
| (output + out_tile_offset)[0] = (output + out_tile_offset)[0] < 0 ? 0 : (output + out_tile_offset)[0]; | |||||
| (output + out_tile_offset)[0] = (output + out_tile_offset)[0] > 6 ? 6 : (output + out_tile_offset)[0]; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -111,7 +117,8 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| int out_h = conv_param->output_h_; | int out_h = conv_param->output_h_; | ||||
| int out_w = conv_param->output_w_; | int out_w = conv_param->output_w_; | ||||
| int out_channel = conv_param->output_channel_; | int out_channel = conv_param->output_channel_; | ||||
| bool relu = conv_param->is_relu_; | |||||
| bool relu6 = conv_param->is_relu6_; | |||||
| // todo | // todo | ||||
| int thread_count = conv_param->thread_num_; | int thread_count = conv_param->thread_num_; | ||||
| int tile_n = 16; | int tile_n = 16; | ||||
| @@ -125,7 +132,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| // we accumulate 4 channels per time for input blocks | // we accumulate 4 channels per time for input blocks | ||||
| int ic4 = UP_DIV(in_channel, C4NUM); | int ic4 = UP_DIV(in_channel, C4NUM); | ||||
| int oc8 = UP_DIV(in_channel, C8NUM); | |||||
| int conv_depth = kernel_h * kernel_w; | int conv_depth = kernel_h * kernel_w; | ||||
| // bytes from one output's i-th channel to the next output's i-th channel | // bytes from one output's i-th channel to the next output's i-th channel | ||||
| // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward | // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward | ||||
| @@ -137,19 +143,18 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | ||||
| int start_index = thread_id * tile_n; | int start_index = thread_id * tile_n; | ||||
| int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; | int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; | ||||
| float16_t *gemm_input = | |||||
| (float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset); | |||||
| float16_t *gemm_input = (float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset); | |||||
| Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); | Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); | ||||
| int out_offset = thread_id * tile_n * out_channel + out_batch_offset; | int out_offset = thread_id * tile_n * out_channel + out_batch_offset; | ||||
| if (real_cal_num == tile_n) { | if (real_cal_num == tile_n) { | ||||
| float16_t *gemm_output = output_data + out_offset; | float16_t *gemm_output = output_data + out_offset; | ||||
| IndirectGemmFp16_16x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, | IndirectGemmFp16_16x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, | ||||
| oc8 * C8NUM * sizeof(float16_t), 0, 0, 0, 0); | |||||
| out_channel * sizeof(float16_t), 0, 0, relu, relu6); | |||||
| } else { | } else { | ||||
| // res part | // res part | ||||
| IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, | IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, | ||||
| oc8 * C8NUM * sizeof(float16_t), 0, 0, 0, 0); | |||||
| out_channel * sizeof(float16_t), 0, 0, relu, relu6); | |||||
| memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t)); | memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -196,6 +201,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||||
| // get real output | // get real output | ||||
| // todo | // todo | ||||
| bool relu = conv_param->is_relu_; | |||||
| bool relu6 = conv_param->is_relu6_; | |||||
| for (int batch = 0; batch < output_batch; batch++) { | for (int batch = 0; batch < output_batch; batch++) { | ||||
| int batch_size = batch * output_channel * output_h * output_w; | int batch_size = batch * output_channel * output_h * output_w; | ||||
| for (int h = 0; h < output_h; h++) { | for (int h = 0; h < output_h; h++) { | ||||
| @@ -207,10 +214,14 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||||
| C8NUM * (h * out_w_block * output_unit + w) + oc8_res; | C8NUM * (h * out_w_block * output_unit + w) + oc8_res; | ||||
| int dst_offset = (h * output_w + w) * output_channel + c; | int dst_offset = (h * output_w + w) * output_channel + c; | ||||
| (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; | (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; | ||||
| if (relu) { | |||||
| (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; | |||||
| } else if (relu6) { | |||||
| (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; | |||||
| (output_data + dst_offset)[0] = (output_data + dst_offset)[0] > 6 ? 6 : (output_data + dst_offset)[0]; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,10 +56,14 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1 | |||||
| for (int m = 0; m < channel_block; m++) { | for (int m = 0; m < channel_block; m++) { | ||||
| int channel_block_stride = input_x_stride + m * C4NUM; | int channel_block_stride = input_x_stride + m * C4NUM; | ||||
| int channel_block_offset = input_plane_offset + m * 16 * C4NUM; | int channel_block_offset = input_plane_offset + m * 16 * C4NUM; | ||||
| #ifdef ENABLE_ARM64 | |||||
| vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride)); | |||||
| #else | |||||
| (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; | (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; | ||||
| (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; | (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; | ||||
| (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; | (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; | ||||
| (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; | (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; | ||||
| #endif | |||||
| } // channel_block loop | } // channel_block loop | ||||
| } // kernel_w loop | } // kernel_w loop | ||||
| } // kernel_h loop | } // kernel_h loop | ||||
| @@ -459,25 +459,30 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data | |||||
| float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); | float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); | ||||
| float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); | float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); | ||||
| float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04); | |||||
| float16x8_t d01 = vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)); | |||||
| float16x8_t d02 = vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)); | |||||
| float16x8_t d03 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05); | |||||
| float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14); | |||||
| float16x8_t d11 = vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)); | |||||
| float16x8_t d12 = vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)); | |||||
| float16x8_t d13 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15); | |||||
| float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24); | |||||
| float16x8_t d21 = vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)); | |||||
| float16x8_t d22 = vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)); | |||||
| float16x8_t d23 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25); | |||||
| float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34); | |||||
| float16x8_t d31 = vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)); | |||||
| float16x8_t d32 = vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)); | |||||
| float16x8_t d33 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35); | |||||
| float16x8_t bias_ptr = vld1q_f16(bias_data); | |||||
| float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr); | |||||
| float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr); | |||||
| float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr); | |||||
| float16x8_t d03 = | |||||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr); | |||||
| float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr); | |||||
| float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr); | |||||
| float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr); | |||||
| float16x8_t d13 = | |||||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr); | |||||
| float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr); | |||||
| float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), bias_ptr); | |||||
| float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr); | |||||
| float16x8_t d23 = | |||||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr); | |||||
| float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr); | |||||
| float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr); | |||||
| float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr); | |||||
| float16x8_t d33 = | |||||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr); | |||||
| vst1q_f16(output_data, d00); | vst1q_f16(output_data, d00); | ||||
| vst1q_f16(output_data + 8, d01); | vst1q_f16(output_data + 8, d01); | ||||
| @@ -103,6 +103,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||||
| float *gemm_out = buffer_list[1]; | float *gemm_out = buffer_list[1]; | ||||
| float *tmp_out_data = buffer_list[2]; | float *tmp_out_data = buffer_list[2]; | ||||
| float *tmp_data = buffer_list[3]; | float *tmp_data = buffer_list[3]; | ||||
| int trans_input_offset = TILE_NUM * input_unit_square * ic4 * C4NUM; | |||||
| int gemm_out_offset = TILE_NUM * input_unit_square * oc4 * C4NUM; | |||||
| int tmp_data_offset = input_unit_square * C4NUM; | |||||
| // step 1 : filter transform (pre-processed offline) | // step 1 : filter transform (pre-processed offline) | ||||
| // step 2 : input transform (online) | // step 2 : input transform (online) | ||||
| for (int b = 0; b < in_batch; b++) { | for (int b = 0; b < in_batch; b++) { | ||||
| @@ -110,15 +113,16 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||||
| int out_tile_index = thread_id * TILE_NUM; | int out_tile_index = thread_id * TILE_NUM; | ||||
| int cal_num = output_count - thread_id * TILE_NUM; | int cal_num = output_count - thread_id * TILE_NUM; | ||||
| cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num; | cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num; | ||||
| WinogradInputTransform(input_data, trans_input, tmp_data, cal_num, out_tile_index, out_w_block, conv_param, | |||||
| WinogradInputTransform(input_data, trans_input + task_id * trans_input_offset, | |||||
| tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | |||||
| input_trans_func); | input_trans_func); | ||||
| // step 3 : gemm | // step 3 : gemm | ||||
| gemm_func(gemm_out, trans_input, trans_weight, nullptr, input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, | |||||
| 0, 0); | |||||
| gemm_func(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, nullptr, | |||||
| input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0); | |||||
| // step 4 : output transform | // step 4 : output transform | ||||
| WinogradOutputTransform(gemm_out, tmp_out_data, bias_data, cal_num, out_tile_index, out_w_block, conv_param, | |||||
| output_trans_func); | |||||
| WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data, bias_data, cal_num, out_tile_index, | |||||
| out_w_block, conv_param, output_trans_func); | |||||
| } | } | ||||
| } | } | ||||
| // get real output | // get real output | ||||
| @@ -191,20 +195,25 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||||
| float *block_unit_buffer = buffer_list[1]; | float *block_unit_buffer = buffer_list[1]; | ||||
| float *tmp_dst_buffer = buffer_list[2]; | float *tmp_dst_buffer = buffer_list[2]; | ||||
| float *nc4hw4_out = buffer_list[3]; | float *nc4hw4_out = buffer_list[3]; | ||||
| int tile_buffer_offset = TILE_NUM * input_unit_square * ic4 * C4NUM; | |||||
| int block_unit_buffer_offset = input_unit_square * C4NUM; | |||||
| int tmp_dst_buffer_offset = TILE_NUM * input_unit_square * oc4 * C4NUM; | |||||
| int input_batch = conv_param->input_batch_; | int input_batch = conv_param->input_batch_; | ||||
| for (int batch = 0; batch < input_batch; batch++) { | for (int batch = 0; batch < input_batch; batch++) { | ||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | ||||
| int start_index = thread_id * TILE_NUM; | int start_index = thread_id * TILE_NUM; | ||||
| int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | ||||
| Conv3x3Fp32InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block, | |||||
| conv_param); | |||||
| Conv3x3Fp32InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, | |||||
| block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | |||||
| out_w_block, conv_param); | |||||
| gemm_func(tmp_dst_buffer, tile_buffer, transed_weight, nullptr, input_unit_square, ic4, oc4 * C4NUM, | |||||
| gemm_func(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, | |||||
| transed_weight, nullptr, input_unit_square, ic4, oc4 * C4NUM, | |||||
| oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0); | oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0); | ||||
| Conv3x3Fp32OutputTransform(tmp_dst_buffer, nc4hw4_out, bias_data, start_index, real_cal_num, out_w_block, | |||||
| conv_param); | |||||
| Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out, bias_data, start_index, | |||||
| real_cal_num, out_w_block, conv_param); | |||||
| } | } | ||||
| PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel); | PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel); | ||||
| } | } | ||||
| @@ -31,8 +31,13 @@ void AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete | |||||
| int output_h = pooling_param->output_h_; | int output_h = pooling_param->output_h_; | ||||
| int output_batch = pooling_param->output_batch_; | int output_batch = pooling_param->output_batch_; | ||||
| int out_plane = output_w * output_h; | int out_plane = output_w * output_h; | ||||
| int16_t out_min = INT8_MIN; | |||||
| int16_t out_max = INT8_MAX; | |||||
| float input_scale = pooling_param->quant_args_[0][0].scale_; | |||||
| int input_zp = pooling_param->quant_args_[0][0].zp_; | |||||
| float output_scale = pooling_param->quant_args_[1][0].scale_; | |||||
| int output_zp = pooling_param->quant_args_[1][0].zp_; | |||||
| double real_multiplier = input_scale / output_scale; | |||||
| int8_t out_min = INT8_MIN; | |||||
| int8_t out_max = INT8_MAX; | |||||
| for (int batch = 0; batch < output_batch; batch++) { | for (int batch = 0; batch < output_batch; batch++) { | ||||
| int in_batch_offset = batch * in_h * in_w * channel; | int in_batch_offset = batch * in_h * in_w * channel; | ||||
| @@ -60,9 +65,10 @@ void AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete | |||||
| } // win_w loop | } // win_w loop | ||||
| } // win_h loop | } // win_h loop | ||||
| int16_t tmp_out = round((float)tmp_avg / (float)real_count); | int16_t tmp_out = round((float)tmp_avg / (float)real_count); | ||||
| int16_t real_out = tmp_out < out_min ? out_min : tmp_out; | |||||
| tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); | |||||
| int8_t real_out = tmp_out < out_min ? out_min : tmp_out; | |||||
| real_out = real_out > out_max ? out_max : real_out; | real_out = real_out > out_max ? out_max : real_out; | ||||
| *(output_ptr + out_channel_offset) = (int8_t)real_out; | |||||
| *(output_ptr + out_channel_offset) = real_out; | |||||
| } // in_channel loop | } // in_channel loop | ||||
| } // out_plane loop | } // out_plane loop | ||||
| } // out_batch loop | } // out_batch loop | ||||