2. fix bug of int8 conv per-layer quant param 3. Add return error when set quant param failed
@@ -141,23 +141,23 @@ int ConvolutionBaseCPUKernel::SetIfAsymmetric() {
   uint8_t asymmetric = 0b0;
   auto filter_tensor = in_tensors_.at(kWeightIndex);
   auto filter_ele_num = filter_tensor->ElementsNum();
-  auto filter_data = reinterpret_cast<float *>(filter_tensor->Data());
-  float min_value = FLT_MAX;
-  float max_value = -FLT_MAX;
+  auto filter_data = reinterpret_cast<int8_t *>(filter_tensor->Data());
+  int min_value = INT8_MAX;
+  int max_value = INT8_MIN;
   for (int i = 0; i < filter_ele_num; ++i) {
     min_value = min_value < filter_data[i] ? min_value : filter_data[i];
     max_value = max_value > filter_data[i] ? max_value : filter_data[i];
   }
   if (conv_quant_arg_->filter_arg_num_ == kPerTensor) {
     auto filter_zp = conv_quant_arg_->filter_quant_args_[0].zp_;
-    if (filter_zp == 0 && min_value >= -127 && max_value <= 127) {
-      asymmetric = asymmetric & FILTER_ASYMMETRIC;
+    if (filter_zp != 0 && min_value >= -128 && max_value <= 127) {
+      asymmetric = asymmetric | FILTER_ASYMMETRIC;
     }
   } else {
     auto filter_arg = conv_quant_arg_->filter_quant_args_;
     for (int i = 0; i < conv_param_->output_channel_; ++i) {
-      if (filter_arg[i].zp_ == 0 && min_value >= -127 && max_value <= 127) {
-        asymmetric = asymmetric & FILTER_ASYMMETRIC;
+      if (filter_arg[i].zp_ != 0 && min_value >= -128 && max_value <= 127) {
+        asymmetric = asymmetric | FILTER_ASYMMETRIC;
       }
     }
   }
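Two separate bugs are fixed in this hunk: the weight tensor of a quantized conv holds int8 data, so reading it through a float pointer produced meaningless min/max bounds (the range check also widens from -127 to -128, the true int8 minimum); and because asymmetric starts at 0b0, AND-ing it with FILTER_ASYMMETRIC could never set a bit, making the old update a no-op. A minimal sketch of the set/test/clear bit-flag idiom the corrected code relies on (the mask value here is illustrative, not the source's):

#include <cstdint>

constexpr uint8_t FILTER_ASYMMETRIC = 0b01;  // illustrative mask value

int main() {
  uint8_t flags = 0b0;
  flags |= FILTER_ASYMMETRIC;                          // set: OR with the mask
  bool asym = (flags & FILTER_ASYMMETRIC) != 0;        // test: AND with the mask
  flags &= static_cast<uint8_t>(~FILTER_ASYMMETRIC);   // clear: AND with the complement
  return asym ? 0 : 1;
}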
@@ -282,6 +282,39 @@ int Convolution3x3FP16CPUKernel::Run() {
     return RET_ERROR;
   }
+  // get real output
+  // todo
+  int out_w_block = UP_DIV(conv_param_->output_w_, C4NUM);
+  int out_h_block = UP_DIV(conv_param_->output_h_, C4NUM);
+  int oc8 = UP_DIV(conv_param_->output_channel_, C8NUM);
+  bool relu = conv_param_->is_relu_;
+  bool relu6 = conv_param_->is_relu6_;
+  for (int batch = 0; batch < conv_param_->output_batch_; batch++) {
+    int tmp_out_batch_offset =
+      batch * oc8 * C8NUM * out_w_block * out_h_block * conv_param_->output_unit_ * conv_param_->output_unit_;
+    int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_;
+    const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset;
+    float16_t *batch_out = fp16_out_ + ro_batch_size;
+    for (int h = 0; h < conv_param_->output_h_; h++) {
+      for (int w = 0; w < conv_param_->output_w_; w++) {
+        for (int c = 0; c < conv_param_->output_channel_; c++) {
+          int oc8_block = c / C8NUM;
+          int oc8_res = c % C8NUM;
+          int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
+                           C8NUM * (h * out_w_block * conv_param_->output_unit_ + w) + oc8_res;
+          int dst_offset = (h * conv_param_->output_w_ + w) * conv_param_->output_channel_ + c;
+          (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
+          if (relu) {
+            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
+          } else if (relu6) {
+            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
+            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0];
+          }
+        }
+      }
+    }
+  }
   // cast fp16 out to fp32 data
   auto out_tensor = out_tensors_.at(kOutputIndex);
   auto out_ele_num = out_tensor->ElementsNum();
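This block is the output-unpack loop relocated from Conv3x3Fp16 in nnacl (removed further down): it walks the Winograd tile buffer tmp_out_ (channel blocks of C8NUM over out_h_block x out_w_block tiles of output_unit_ x output_unit_ pixels), writes plain NHWC, and fuses the ReLU/ReLU6 clamp into the copy. All the block counts come from round-up division; a small sketch assuming the conventional UP_DIV definition:

#include <cassert>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // assumed: integer division rounded up

int main() {
  assert(UP_DIV(10, 4) == 3);  // 10 output columns in blocks of 4 -> 3 tile columns
  assert(UP_DIV(17, 8) == 3);  // 17 output channels in blocks of 8 -> 3 C8 blocks
  assert(UP_DIV(16, 8) == 2);  // exact multiples need no padding block
  return 0;
}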
@@ -263,6 +263,13 @@ int ConvolutionSWFP16CPUKernel::Run() {
   auto out_tensor = out_tensors_.at(kOutputIndex);
   auto out_ele_num = out_tensor->ElementsNum();
   auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
+  // output nhwc4
+  int oc4_res = conv_param_->output_channel_ % C4NUM;
+  if (oc4_res != 0) {
+    PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(fp16_out_),
+                        conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_,
+                        conv_param_->output_channel_);
+  }
   Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
   return RET_OK;
 }
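The sliding-window kernels produce NHWC4 output, i.e. NHWC with the channel dimension zero-padded up to a multiple of C4NUM. When output_channel_ is already a multiple of 4 the two layouts coincide and no repack is needed, which is why the pack runs only when oc4_res != 0. The call also moves here from ConvSWFp16/ConvSWFp32 (removed below), keeping the nnacl routines layout-agnostic. A minimal sketch of the assumed NHWC4 -> NHWC semantics, not the library routine:

#include <cstring>

// Strip the channel padding: each (batch, pixel) row shrinks from
// c4 = UP_DIV(channel, 4) * 4 floats down to `channel` floats.
void Nhwc4ToNhwc(const float *src, float *dst, int batch, int plane, int channel) {
  int c4 = ((channel + 3) / 4) * 4;  // padded channel width
  for (int b = 0; b < batch; ++b) {
    for (int p = 0; p < plane; ++p) {
      std::memcpy(dst + (b * plane + p) * channel, src + (b * plane + p) * c4, channel * sizeof(float));
    }
  }
}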
@@ -206,6 +206,14 @@ int ConvolutionSWCPUKernel::Run() {
     MS_LOG(ERROR) << "conv error error_code[" << error_code << "]";
     return RET_ERROR;
   }
+  // output nhwc4
+  auto out_tensor = out_tensors_.front();
+  auto out_data = reinterpret_cast<float *>(out_tensor->Data());
+  int oc4_res = conv_param_->output_channel_ % C4NUM;
+  if (oc4_res != 0) {
+    PackNHWC4ToNHWCFp32(tmp_output_block_, out_data, conv_param_->output_batch_,
+                        conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
+  }
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -170,7 +170,11 @@ int Convolution3x3Int8CPUKernel::Init() {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
     return RET_ERROR;
   }
-  SetQuantParam();
+  ret = SetQuantParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Set quant param failed.";
+    return ret;
+  }
   ret = InitWeightBias();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init weight bias failed.";
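SetQuantParam() now returns a status instead of being fire-and-forget, and every call site checks it; the same change repeats in the depthwise, common int8, and deconv-depthwise kernels below. A self-contained sketch of the check-and-propagate pattern, with stand-ins for the real MS_LOG macro and kernel call:

#include <iostream>

constexpr int RET_OK = 0;

int SetQuantParam() { return RET_OK; }  // stand-in for the real fallible call

int Init() {
  int ret = SetQuantParam();
  if (ret != RET_OK) {
    std::cerr << "Set quant param failed." << std::endl;
    return ret;  // propagate instead of silently continuing with bad params
  }
  // ... remaining init steps follow the same pattern ...
  return RET_OK;
}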
@@ -249,6 +253,11 @@ int Convolution3x3Int8CPUKernel::Run() {
     MS_LOG(ERROR) << "conv3x3 int8 error error_code[" << error_code << "]";
     return RET_ERROR;
   }
+  // get real output
+  auto out_tensor = out_tensors_.front();
+  auto out_data = reinterpret_cast<int8_t *>(out_tensor->Data());
+  PackNC4HW4ToNHWCInt8(tmp_out_, out_data, conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_,
+                       conv_param_->output_channel_);
   return RET_OK;
 }
 }  // namespace mindspore::kernel
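NC4HW4 differs from NHWC4: channels are grouped into blocks of 4 and each block is stored planar over the spatial plane, with the last block zero-padded. This pack call replaces the one removed from Conv3x3Int8 further down, where it incorrectly ran once per batch-loop iteration over the whole tensor. A sketch of the assumed index mapping, not the library routine:

// Assumed NC4HW4 addressing: element (b, p, c) of a tensor with `plane`
// pixels and c4 = UP_DIV(channel, 4) channel blocks lives at
//   src[((b * c4 + c / 4) * plane + p) * 4 + (c % 4)]
inline int Nc4hw4Offset(int b, int p, int c, int plane, int channel) {
  int c4 = (channel + 3) / 4;
  return ((b * c4 + c / 4) * plane + p) * 4 + (c % 4);
}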
@@ -111,10 +111,14 @@ int ConvolutionDepthwiseInt8CPUKernel::Init() {
   InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
 
   // init quant param
-  ConvolutionBaseCPUKernel::SetQuantParam();
+  auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Set quant param failed.";
+    return ret;
+  }
 
   // init weight and bias
-  auto ret = InitWeightBias();
+  ret = InitWeightBias();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!";
     return ret;
@@ -305,7 +305,11 @@ int ConvolutionInt8CPUKernel::Init() {
   // config input output
   ConfigInputOutput();
   CheckSupportOptimize();
-  SetQuantParam();
+  ret = SetQuantParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Set quant param failed.";
+    return ret;
+  }
   // init for opt
   if (support_optimize_) {
     ret = InitOpt();
@@ -148,10 +148,14 @@ int DeconvolutionDepthwiseInt8CPUKernel::Init() {
   ConvolutionBaseCPUKernel::Init();
 
   // init quant param
-  ConvolutionBaseCPUKernel::SetQuantParam();
+  auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Set quant param failed.";
+    return ret;
+  }
 
   // init weight and bias
-  auto ret = InitWeightBias();
+  ret = InitWeightBias();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Deconv Depthwise int8 InitWeightBias error!";
     return ret;
@@ -307,11 +307,6 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
     src += slidingWindow_param->in_step_;
     dst += slidingWindow_param->out_step_;
   }  // batch loop
-  // output nhwc4
-  if (oc4_res != 0) {
-    PackNHWC4ToNHWCFp16((const void *)tmp_out_block, (void *)output_data, conv_param->output_batch_,
-                        conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_);
-  }
 }
 
 // fp16 convolution common (im2col+gemm)
@@ -381,11 +376,6 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
   int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
   int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
-  int output_batch = conv_param->output_batch_;
-  int output_channel = conv_param->output_channel_;
-  int output_w = conv_param->output_w_;
-  int output_h = conv_param->output_h_;
   int out_w_block = UP_DIV(conv_param->output_w_, C4NUM);
   int out_h_block = UP_DIV(conv_param->output_h_, C4NUM);
   int output_count = out_w_block * out_h_block;
@@ -414,33 +404,4 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
                        bias_data, start_index, real_cal_num, out_w_block, conv_param);
     }
   }
-  // get real output
-  // todo
-  bool relu = conv_param->is_relu_;
-  bool relu6 = conv_param->is_relu6_;
-  for (int batch = 0; batch < output_batch; batch++) {
-    int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
-    int ro_batch_size = batch * output_channel * output_h * output_w;
-    const float16_t *batch_tmp_out = tmp_out + tmp_out_batch_offset;
-    float16_t *batch_out = output_data + ro_batch_size;
-    for (int h = 0; h < output_h; h++) {
-      for (int w = 0; w < output_w; w++) {
-        for (int c = 0; c < output_channel; c++) {
-          int oc8_block = c / C8NUM;
-          int oc8_res = c % C8NUM;
-          int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
-                           C8NUM * (h * out_w_block * output_unit + w) + oc8_res;
-          int dst_offset = (h * output_w + w) * output_channel + c;
-          (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
-          if (relu) {
-            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
-          } else if (relu6) {
-            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
-            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0];
-          }
-        }
-      }
-    }
-  }
 }
@@ -182,11 +182,6 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
     src += slidingWindow_param->in_step_;
     dst += slidingWindow_param->out_step_;
   }  // batch loop
-  // output nhwc4
-  if (oc4_res != 0) {
-    PackNHWC4ToNHWCFp32(tmp_out_block, output_data, conv_param->output_batch_,
-                        conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_);
-  }
 }
 
 // fp32 conv common
@@ -264,7 +264,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
   int packed_input_size = output_tile_count * tile_n * unit_size;
   for (int b = 0; b < in_batch; b++) {
-    int in_batch_offset = b * in_channel * in_h * in_w;
+    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
     int out_batch_offset = b * out_channel * out_h * out_w;
     int gemm_in_batch_offset = b * packed_input_size;
     for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
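The input to ConvInt8 is packed to ic4 * C4NUM channels (the channel dimension rounded up to a multiple of 4), so the per-batch stride must use the padded width; computing it from the raw in_channel made every batch after the first read from the wrong offset. A worked stride comparison under assumed sizes:

// Assumed sizes for illustration: in_channel = 3, in_h = in_w = 2.
constexpr int C4NUM = 4;
constexpr int in_channel = 3, in_h = 2, in_w = 2;
constexpr int ic4 = (in_channel + C4NUM - 1) / C4NUM;  // = 1

// Wrong stride (old code):  3 * 2 * 2 = 12 elements per batch.
constexpr int wrong_stride = in_channel * in_h * in_w;
// Right stride (fixed code): 4 * 2 * 2 = 16 elements per batch,
// matching the padded layout the data was actually packed into.
constexpr int right_stride = ic4 * C4NUM * in_h * in_w;
static_assert(wrong_stride != right_stride, "batch 1+ would be misaligned");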
@@ -319,7 +319,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
   int packed_input_size = output_tile_count * tile_n * unit_size;
   for (int b = 0; b < in_batch; b++) {
-    int in_batch_offset = b * in_channel * in_h * in_w;
+    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
     int out_batch_offset = b * out_channel * out_h * out_w;
     int gemm_in_batch_offset = b * packed_input_size;
     for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
@@ -358,10 +358,7 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
                  int task_id, ConvParameter *conv_param) {
   int thread_count = conv_param->thread_num_;
   int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
-  int output_batch = conv_param->output_batch_;
   int output_channel = conv_param->output_channel_;
-  int output_w = conv_param->output_w_;
-  int output_h = conv_param->output_h_;
   int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
   int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
   int output_count = out_w_block * out_h_block;
@@ -373,22 +370,21 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
   int input_batch = conv_param->input_batch_;
   for (int batch = 0; batch < input_batch; batch++) {
+    int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
+    int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
     for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
       int start_index = thread_id * TILE_NUM;
       int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
-      Conv3x3Uint8InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
+      Conv3x3Uint8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
                                  block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
                                  out_w_block, conv_param);
       Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
                        transed_weight, output_channel, ic8, real_cal_num);
-      Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
-                                  real_cal_num, out_w_block, conv_param);
+      Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
+                                  bias_data, start_index, real_cal_num, out_w_block, conv_param);
     }
   }
-  // get real output
-  PackNC4HW4ToNHWCInt8(tmp_out, output_data, output_batch, output_h * output_w, output_channel);
 }
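Previously Conv3x3Int8 read input_data and wrote tmp_out without any per-batch offset, so for input_batch > 1 every iteration reprocessed batch 0; the added in_batch_offset and tmp_out_batch_offset (strides in the padded ic8 * C8NUM and oc4 * C4NUM layouts) fix that, and the final NC4HW4 -> NHWC pack moves into Convolution3x3Int8CPUKernel::Run above. A compilable sketch of the per-batch pointer arithmetic, with stand-ins for the transform stages:

#include <cstdint>

constexpr int C4NUM = 4;
constexpr int C8NUM = 8;

// Each batch advances the base pointers by the padded plane sizes.
void ProcessBatches(const int16_t *input_data, int8_t *tmp_out, int input_batch, int ic8, int oc4, int input_h,
                    int input_w, int output_h, int output_w) {
  for (int batch = 0; batch < input_batch; ++batch) {
    const int16_t *in = input_data + batch * ic8 * C8NUM * input_h * input_w;
    int8_t *out = tmp_out + batch * oc4 * C4NUM * output_h * output_w;
    (void)in;   // stand-ins: the real code runs the input transform,
    (void)out;  // gemm, and output transform on these offset pointers
  }
}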
@@ -417,6 +417,12 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
     int src_kernel_offset = k * input_channel;
     int dst_kernel_offset = k * C8NUM;
     for (int o = 0; o < output_channel; o++) {
+      int32_t zp;
+      if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
+        zp = filter_zp[0].zp_;
+      } else {
+        zp = filter_zp[o].zp_;
+      }
       int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel;
       int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM;
       for (int i = 0; i < input_channel; i++) {
@@ -424,7 +430,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
         int c8_block_rem = i % C8NUM;
         int src_ic_offset = src_oc_offset + i;
         int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem;
-        (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp[o].zp_);
+        (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp);
       }
     }
   }
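This is the other half of the per-layer quant fix: with per-tensor (per-layer) quantization, filter_zp holds a single entry, so indexing it with the output-channel index o read past the array for every o > 0. The fixed code picks the zero point once per output channel based on filter_arg_num_. A condensed sketch of the selection logic, under an assumed minimal shape for the quant-arg struct:

#include <cstdint>

struct QuantArg {  // assumed minimal shape of the real struct
  double scale_;
  int32_t zp_;
};

// Per-tensor mode stores one QuantArg for the whole filter; per-channel
// mode stores one per output channel. Pick accordingly.
inline int32_t FilterZp(const QuantArg *filter_zp, int filter_arg_num, int oc) {
  return (filter_arg_num == 1) ? filter_zp[0].zp_ : filter_zp[oc].zp_;
}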