diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index 2549ff251b..8662894afa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/base/convolution_base.h" +#include #include "schema/model_generated.h" #include "src/kernel_factory.h" #include "include/errorcode.h" @@ -66,13 +67,14 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() { free(conv_quant_arg_->out_act_max_); conv_quant_arg_->out_act_max_ = nullptr; } - - if (conv_quant_arg_->quant_args_ != nullptr) { - for (int i = 0; i < 3; ++i) { - if (*(conv_quant_arg_->quant_args_ + i) != nullptr) { - free(*(conv_quant_arg_->quant_args_ + i)); - } - } + if (conv_quant_arg_->input_quant_args_ != nullptr) { + free(conv_quant_arg_->input_quant_args_); + } + if (conv_quant_arg_->filter_quant_args_ != nullptr) { + free(conv_quant_arg_->filter_quant_args_); + } + if (conv_quant_arg_->output_quant_args_ != nullptr) { + free(conv_quant_arg_->output_quant_args_); } } @@ -103,53 +105,218 @@ int ConvolutionBaseCPUKernel::CheckLayout(lite::tensor::Tensor *input_tensor) { return RET_OK; } -int ConvolutionBaseCPUKernel::SetQuantParam() { - ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_; - conv_quant_arg_->quant_args_ = reinterpret_cast(malloc(3 * sizeof(QuantArg *))); - if (conv_quant_arg_->quant_args_ == nullptr) { - MS_LOG(ERROR) << "malloc quant_args_ failed."; - return RET_ERROR; +int ConvolutionBaseCPUKernel::SetIfPerChannel() { + uint8_t per_channel = 0b0; + if (conv_quant_arg_->input_arg_num_ != kPerTensor) { + int in_channel = conv_param_->input_channel_; + if (conv_quant_arg_->input_arg_num_ != in_channel) { + MS_LOG(ERROR) << "input per channel quant param length is not equal to input channel."; + return RET_ERROR; + } + per_channel = per_channel | INPUT_PER_CHANNEL; + } + 
+ if (conv_quant_arg_->filter_arg_num_ != kPerTensor) { + int filter_num = conv_param_->output_channel_; + if (conv_quant_arg_->filter_arg_num_ != filter_num) { + MS_LOG(ERROR) << "weight per channel quant param length is not equal to filter num."; + return RET_ERROR; + } + per_channel = per_channel | FILTER_PER_CHANNEL; } - // per-tensor init - for (int j = 0; j < 3; ++j) { - conv_quant_arg_->quant_args_[j] = reinterpret_cast(malloc(sizeof(QuantArg))); - if (conv_quant_arg_->quant_args_[j] == nullptr) { - MS_LOG(ERROR) << "malloc quant_args_ failed."; + + if (conv_quant_arg_->output_arg_num_ != kPerTensor) { + int out_channel = conv_param_->output_channel_; + if (conv_quant_arg_->output_arg_num_ != out_channel) { + MS_LOG(ERROR) << "output per channel quant param length is not equal to output channel."; return RET_ERROR; } + per_channel = per_channel | OUTPUT_PER_CHANNEL; + } + conv_quant_arg_->per_channel_ = per_channel; + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetIfAsymmetric() { + uint8_t asymmetric = 0b0; + auto filter_tensor = in_tensors_.at(kWeightIndex); + auto filter_ele_num = filter_tensor->ElementsNum(); + auto filter_data = reinterpret_cast(filter_tensor->Data()); + float min_value = FLT_MAX; + float max_value = -FLT_MAX; + for (int i = 0; i < filter_ele_num; ++i) { + min_value = min_value < filter_data[i] ? min_value : filter_data[i]; + max_value = max_value > filter_data[i] ? 
max_value : filter_data[i]; + } + if (conv_quant_arg_->filter_arg_num_ == kPerTensor) { + auto filter_zp = conv_quant_arg_->filter_quant_args_[0].zp_; + if (!(filter_zp == 0 && min_value >= -127 && max_value <= 127)) { + asymmetric = asymmetric | FILTER_ASYMMETRIC; /* set the bit with '|'; '&' on the initial 0 never could */ + } + } else { + auto filter_arg = conv_quant_arg_->filter_quant_args_; + for (int i = 0; i < conv_param_->output_channel_; ++i) { + if (!(filter_arg[i].zp_ == 0 && min_value >= -127 && max_value <= 127)) { + asymmetric = asymmetric | FILTER_ASYMMETRIC; /* any channel failing the symmetric test flags the filter */ + } + } } + conv_quant_arg_->asymmetric_ = asymmetric; + return RET_OK; +} + +int ConvolutionBaseCPUKernel::MallocQuantParam() { + conv_quant_arg_ = &conv_param_->conv_quant_arg_; auto input_tensor = in_tensors_.at(kInputIndex); auto weight_tensor = in_tensors_.at(kWeightIndex); auto output_tensor = out_tensors_.at(kOutputIndex); - auto input_quant_arg = input_tensor->GetQuantParams().front(); - auto weight_quant_arg = weight_tensor->GetQuantParams().front(); - auto output_quant_arg = output_tensor->GetQuantParams().front(); - // input - conv_quant_arg_->quant_args_[0][0].zp_ = input_quant_arg.zeroPoint; - conv_quant_arg_->quant_args_[0][0].scale_ = input_quant_arg.scale; - // weight - conv_quant_arg_->quant_args_[1][0].zp_ = weight_quant_arg.zeroPoint; - conv_quant_arg_->quant_args_[1][0].scale_ = weight_quant_arg.scale; - // output - conv_quant_arg_->quant_args_[2][0].zp_ = output_quant_arg.zeroPoint; - conv_quant_arg_->quant_args_[2][0].scale_ = output_quant_arg.scale; - - conv_quant_arg_->real_multiplier_ = reinterpret_cast(malloc(sizeof(double))); - conv_quant_arg_->left_shift_ = reinterpret_cast(malloc(sizeof(int32_t))); - conv_quant_arg_->right_shift_ = reinterpret_cast(malloc(sizeof(int32_t))); - conv_quant_arg_->quant_multiplier_ = reinterpret_cast(malloc(sizeof(int32_t))); + size_t input_arg_num = input_tensor->GetQuantParams().size(); + size_t filter_arg_num = weight_tensor->GetQuantParams().size(); + size_t output_arg_num = 
output_tensor->GetQuantParams().size(); + conv_quant_arg_->input_arg_num_ = input_arg_num; + conv_quant_arg_->filter_arg_num_ = filter_arg_num; + conv_quant_arg_->output_arg_num_ = output_arg_num; + + conv_quant_arg_->input_quant_args_ = reinterpret_cast(malloc(input_arg_num * sizeof(QuantArg))); + if (conv_quant_arg_->input_quant_args_ == nullptr) { + MS_LOG(ERROR) << "malloc input_quant_args_ failed."; + return RET_ERROR; + } + conv_quant_arg_->filter_quant_args_ = reinterpret_cast(malloc(filter_arg_num * sizeof(QuantArg))); + if (conv_quant_arg_->filter_quant_args_ == nullptr) { + MS_LOG(ERROR) << "malloc filter_quant_args_ failed."; + return RET_ERROR; + } + conv_quant_arg_->output_quant_args_ = reinterpret_cast(malloc(output_arg_num * sizeof(QuantArg))); + if (conv_quant_arg_->output_quant_args_ == nullptr) { + MS_LOG(ERROR) << "malloc output_quant_args_ failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetInputTensorQuantParam() { + auto input_tensor = in_tensors_.at(kInputIndex); + auto in_arg_num = conv_quant_arg_->input_arg_num_; + if (in_arg_num == kPerTensor) { + auto input_quant_arg = input_tensor->GetQuantParams().front(); + conv_quant_arg_->input_quant_args_[0].zp_ = input_quant_arg.zeroPoint; + conv_quant_arg_->input_quant_args_[0].scale_ = input_quant_arg.scale; + } else { + // per channel + MS_LOG(ERROR) << "Not Support Per Channel for input now."; + return RET_ERROR; + // auto input_quant_arg = input_tensor->GetQuantParams(); + // for (int i = 0; i < in_arg_num; ++i) { + // conv_quant_arg_->input_quant_args_[i].zp_ = input_quant_arg[i].zeroPoint; + // conv_quant_arg_->input_quant_args_[i].scale_ = input_quant_arg[i].scale; + // } + } + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetFilterTensorQuantParam() { + auto weight_tensor = in_tensors_.at(kWeightIndex); + auto weight_arg_num = conv_quant_arg_->filter_arg_num_; + if (weight_arg_num == kPerTensor) { + auto weight_quant_arg = 
weight_tensor->GetQuantParams().front(); + conv_quant_arg_->filter_quant_args_[0].zp_ = weight_quant_arg.zeroPoint; + conv_quant_arg_->filter_quant_args_[0].scale_ = weight_quant_arg.scale; + } else { + auto weight_quant_arg = weight_tensor->GetQuantParams(); + for (int i = 0; i < weight_arg_num; ++i) { + conv_quant_arg_->filter_quant_args_[i].zp_ = weight_quant_arg[i].zeroPoint; + conv_quant_arg_->filter_quant_args_[i].scale_ = weight_quant_arg[i].scale; + } + } + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetOutputTensorQuantParam() { + auto output_tensor = out_tensors_.at(kOutputIndex); + auto out_arg_num = conv_quant_arg_->output_arg_num_; + if (out_arg_num == kPerTensor) { + auto output_quant_arg = output_tensor->GetQuantParams().front(); + conv_quant_arg_->output_quant_args_[0].zp_ = output_quant_arg.zeroPoint; + conv_quant_arg_->output_quant_args_[0].scale_ = output_quant_arg.scale; + } else { + MS_LOG(ERROR) << "Not Support Per Channel for output now."; + return RET_ERROR; + // auto output_quant_arg = output_tensor->GetQuantParams(); + // for (int i = 0; i < out_arg_num; ++i) { + // conv_quant_arg_->output_quant_args_[i].zp_ = output_quant_arg[i].zeroPoint; + // conv_quant_arg_->output_quant_args_[i].scale_ = output_quant_arg[i].scale; + // } + } + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetQuantMultiplier() { + // now only support weight tensor is per channel, others are per tensor. 
+ int weight_arg_num = kPerTensor; + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + weight_arg_num = conv_quant_arg_->filter_arg_num_; + } + conv_quant_arg_->real_multiplier_ = reinterpret_cast(malloc(weight_arg_num * sizeof(double))); + conv_quant_arg_->left_shift_ = reinterpret_cast(malloc(weight_arg_num * sizeof(int32_t))); + conv_quant_arg_->right_shift_ = reinterpret_cast(malloc(weight_arg_num * sizeof(int32_t))); + conv_quant_arg_->quant_multiplier_ = reinterpret_cast(malloc(weight_arg_num * sizeof(int32_t))); conv_quant_arg_->out_act_min_ = reinterpret_cast(malloc(sizeof(int32_t))); conv_quant_arg_->out_act_max_ = reinterpret_cast(malloc(sizeof(int32_t))); - double real_multiplier = weight_quant_arg.scale * input_quant_arg.scale / output_quant_arg.scale; - conv_quant_arg_->real_multiplier_[0] = real_multiplier; - QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[0], &conv_quant_arg_->left_shift_[0], - &conv_quant_arg_->right_shift_[0]); + for (int i = 0; i < weight_arg_num; ++i) { + double real_multiplier = conv_quant_arg_->filter_quant_args_[i].scale_ * + conv_quant_arg_->input_quant_args_[0].scale_ / + conv_quant_arg_->output_quant_args_[0].scale_; + conv_quant_arg_->real_multiplier_[i] = real_multiplier; + QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i], + &conv_quant_arg_->right_shift_[i]); + } + return RET_OK; +} +int ConvolutionBaseCPUKernel::SetQuantParam() { + auto ret = MallocQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Malloc quant param failed."; + return ret; + } + ret = SetInputTensorQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set Input Tensor Quant Param Failed."; + return ret; + } + ret = SetFilterTensorQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set Filter Tensor Quant Param Failed."; + return ret; + } + ret = SetOutputTensorQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set Output Tensor 
Quant Param Failed."; + return ret; + } + ret = SetQuantMultiplier(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set Quant Multiplier Failed."; + return ret; + } + // now only consider per tensor for output CalculateActivationRangeQuantized( - conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.quant_args_[2][0].zp_, - conv_param_->conv_quant_arg_.quant_args_[2][0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0], + conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0], &conv_param_->conv_quant_arg_.out_act_max_[0]); + + ret = SetIfPerChannel(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set if per tensor channel failed."; + return ret; + } + ret = SetIfAsymmetric(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set if per asymmetric failed."; + return ret; + } return RET_OK; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h index 1394293bed..c4adf97885 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h @@ -32,6 +32,7 @@ using mindspore::lite::Context; using mindspore::schema::PadMode; using mindspore::schema::QuantType; +static constexpr int kPerTensor = 1; namespace mindspore::kernel { class ConvolutionBaseCPUKernel : public LiteKernel { @@ -49,7 +50,14 @@ class ConvolutionBaseCPUKernel : public LiteKernel { int ReSize() override { return 0; } int Run() override { return 0; } virtual int CheckLayout(lite::tensor::Tensor *input_tensor); + int SetIfAsymmetric(); + int SetIfPerChannel(); + int MallocQuantParam(); int SetQuantParam(); + int SetInputTensorQuantParam(); + int SetFilterTensorQuantParam(); + int SetOutputTensorQuantParam(); + int SetQuantMultiplier(); void 
FreeQuantParam(); protected: @@ -59,9 +67,9 @@ class ConvolutionBaseCPUKernel : public LiteKernel { void *nhwc4_input_ = nullptr; const Context *ctx_; ConvParameter *conv_param_; + ConvQuantArg *conv_quant_arg_; LayoutConvertor convert_func_; }; -bool CheckSupportFP16(); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 99ecb17c77..e248e107f2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -69,8 +69,8 @@ int ConvolutionInt8CPUKernel::InitWeightBias() { int kernel_plane = kernel_h * kernel_w; int plane_c4 = UP_DIV(kernel_plane, C4NUM); int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * plane_c4 * C4NUM; - int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_; - int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_; + auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_; + int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_; // init weight auto origin_weight = reinterpret_cast(in_tensors_.at(kWeightIndex)->Data()); @@ -99,8 +99,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() { } auto *bias_data = reinterpret_cast(bias_data_); int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM; - for (int i = 0; i < out_channel; i++) { - bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } + } else { + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } } free(weight_sum); return RET_OK; @@ 
-125,7 +131,13 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); /*=============================input_sum_============================*/ - input_sum_ = reinterpret_cast(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); + size_t input_sum_size; + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t); + } else { + input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t); + } + input_sum_ = reinterpret_cast(malloc(input_sum_size)); if (input_sum_ == nullptr) { MS_LOG(ERROR) << "malloc input_sum_ failed."; return RET_ERROR; @@ -168,8 +180,8 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { int oc4 = UP_DIV(out_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; - int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_; - int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_; + auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_; + int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_; // init weight auto origin_weight = reinterpret_cast(in_tensors_.at(kWeightIndex)->Data()); @@ -178,9 +190,9 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { MS_LOG(ERROR) << "malloc packed_weight_ failed."; return RET_ERROR; } - memset(packed_weight_, filter_zp, pack_weight_size); + memset(packed_weight_, 0, pack_weight_size); auto *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); - for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane; + for (int i = 0; i < out_channel; i++) weight_sum[i] = 0; PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum); // init bias @@ -198,8 +210,14 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { } auto *bias_data = reinterpret_cast(bias_data_); int 
c4_kernel_plane_size = kernel_plane * ic4 * C4NUM; - for (int i = 0; i < out_channel; i++) { - bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } + } else { + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } } free(weight_sum); return RET_OK; @@ -223,7 +241,13 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); /*=============================input_sum_============================*/ - input_sum_ = reinterpret_cast(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); + size_t input_sum_size; + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t); + } else { + input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t); + } + input_sum_ = reinterpret_cast(malloc(input_sum_size)); if (input_sum_ == nullptr) { MS_LOG(ERROR) << "malloc input_sum_ failed."; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c index 56f9ed3fcf..1977f49e7c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c @@ -77,7 +77,7 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], 
conv_param->conv_quant_arg_.right_shift_[0], - conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); dst_kernel += sliding->block_channel_; @@ -168,15 +168,15 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t), sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], - conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], - conv_param->conv_quant_arg_.out_act_max_[0]); + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); #else DepthwiseCenterInt8( out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], - conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); #endif } @@ -333,7 +333,7 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in DeconvDepthwisePostFuncInt8( dst_data, output_buffer, bias, sliding->block_channel_, conv_param, conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], - 
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); } // output C4 loop src += sliding->in_step_; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c index b1ed0a8b98..b88a708240 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.c @@ -22,10 +22,10 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, ConvParameter *conv_param) { - int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0]; - int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0]; - int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; - int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int32_t *shift_before = conv_param->conv_quant_arg_.left_shift_; + int32_t *shift_after = conv_param->conv_quant_arg_.right_shift_; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; #ifdef __aarch64__ @@ -63,14 +63,49 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in } // in c4num loop } // ic4 loop } // kernel_plane loop - tmp_dst[dst_tile_offset] -= input_sum[n]; - int result = tmp_dst[dst_tile_offset] + bias[oc]; - result = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after); - 
result += out_zp; - result = result > act_min ? result : act_min; - result = result < act_max ? result : act_max; - dst[dst_tile_offset] = (int8_t)result; + if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), + -shift_after[oc]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } else if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]), + -shift_after[0]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + tmp_dst[dst_tile_offset] -= input_sum[n]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]), + -shift_after[0]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? 
result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), + -shift_after[oc]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } } // tile_num loop } // output_channel loop #endif @@ -79,10 +114,10 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, ConvParameter *conv_param, GEMM_FUNC gemm_func) { - int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0]; - int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0]; - int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; - int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int32_t *shift_before = conv_param->conv_quant_arg_.left_shift_; + int32_t *shift_after = conv_param->conv_quant_arg_.right_shift_; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; if (gemm_func != NULL) { @@ -113,14 +148,49 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const } // in c4num loop } // ic4 loop } // kernel_plane loop - tmp_dst[dst_tile_offset] -= input_sum[n]; - int result = 
tmp_dst[dst_tile_offset] + bias[oc]; - result = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after); - result += out_zp; - result = result > act_min ? result : act_min; - result = result < act_max ? result : act_max; - dst[dst_tile_offset] = (int8_t)result; + if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), + -shift_after[oc]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } else if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]), + -shift_after[0]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + tmp_dst[dst_tile_offset] -= input_sum[n]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]), + -shift_after[0]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? 
result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), + -shift_after[oc]); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } } // tile_num loop } // output_channel loop } @@ -182,7 +252,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c int out_h = conv_param->output_h_; int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; - int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; int tile_n = conv_param->tile_num_; int thread_count = conv_param->thread_num_; @@ -238,7 +308,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight int out_h = conv_param->output_h_; int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; - int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; int tile_n = conv_param->tile_num_; int thread_count = conv_param->thread_num_; int output_count = out_h * out_w; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c index 691ae7ba0c..e46d9e9047 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c @@ -19,8 +19,8 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t 
*output, size_t row8, size_t col8, size_t deep, ConvParameter *conv_param) { - MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.quant_args_[0][0].zp_, - conv_param->conv_quant_arg_.quant_args_[1][0].zp_); + MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, + conv_param->conv_quant_arg_.filter_quant_args_[0].zp_); return NNACL_OK; } @@ -65,7 +65,7 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8), conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], - conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c index b9341e3ad2..dbc713dde6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c @@ -115,7 +115,6 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p int oc4 = UP_DIV(out_channel, C4NUM); int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; - int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; int unit_size = C4NUM * C4NUM; int block_size = pack_weight_size / oc4; @@ -143,7 +142,7 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p if (packed_data_ptr[0] == -128) { packed_data_ptr[0] = -127; } - weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - filter_zp); + weight_sum[j * C4NUM + k] += 
(int32_t)(packed_data_ptr[0]); } } // kernel block loop } // inchannel block loop @@ -241,7 +240,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real int32_t *input_sum, ConvParameter *conv_param) { // input format : nhwc int tile_num = conv_param->tile_num_; - int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_; int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; int stride_h = conv_param->stride_h_; @@ -292,7 +291,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real } // channel_block loop } // kernel_w loop } // kernel_h loop - input_sum[i] = input_accumulator * filter_zp; + if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) { + return; + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + int cal_num_offset = i * conv_param->output_channel_; + for (int l = 0; l < conv_param->output_channel_; ++l) { + input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_; + } + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + input_sum[i] = input_accumulator * filter_arg[0].zp_; + } } // tile num loop } @@ -300,7 +310,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r int32_t *input_sum, ConvParameter *conv_param) { // input format : nhwc int tile_num = conv_param->tile_num_; - int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_; int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; int stride_h = conv_param->stride_h_; @@ -348,13 +358,23 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r int 
block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM; for (int c = 0; c < ic4; c++) { int ic4_offset = block_offset + c * tile_num * C4NUM; - input_accumulator += (packed_input + ic4_offset)[0]; - input_accumulator += (packed_input + ic4_offset)[1]; - input_accumulator += (packed_input + ic4_offset)[2]; - input_accumulator += (packed_input + ic4_offset)[3]; + for (int k = 0; k < C4NUM; ++k) { + input_accumulator += (packed_input + ic4_offset)[k]; + } + } + } + if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) { + return; + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + int cal_num_offset = i * conv_param->output_channel_; + for (int l = 0; l < conv_param->output_channel_; ++l) { + input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_; } + } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && + !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + input_sum[i] = input_accumulator * filter_arg[0].zp_; } - input_sum[i] = input_accumulator * filter_zp; } // tile num loop } @@ -387,7 +407,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight int input_channel = conv_param->input_channel_; int ic8 = UP_DIV(input_channel, C8NUM); int output_channel = conv_param->output_channel_; - int filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_; int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; for (int k = 0; k < kernel_plane; k++) { @@ -401,7 +421,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight int c8_block_rem = i % C8NUM; int src_ic_offset = src_oc_offset + i; int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; - (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - 
filter_zp); + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp[o].zp_); } } } @@ -806,7 +826,7 @@ void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) { } void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { - int input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); int unit = conv_param->input_h_ * conv_param->input_w_; @@ -824,7 +844,7 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter } void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, const ConvParameter *conv_param) { - int weight_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; int unit = conv_param->kernel_h_ * conv_param->kernel_w_; for (int c = 0; c < conv_param->output_channel_; c++) { int c4_block_num = c / C4NUM; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h index b4ac17b4cf..64e6f534cd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h @@ -17,25 +17,37 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_ -#include #include -#include #include #include "nnacl/op_base.h" +#define INPUT_ASYMMETRIC 0b001 +#define FILTER_ASYMMETRIC 0b010 +#define OUTPUT_ASYMMETRIC 0b100 +#define INPUT_PER_CHANNEL 0b001 +#define FILTER_PER_CHANNEL 0b010 +#define OUTPUT_PER_CHANNEL 0b100 + typedef struct QuantArg { double scale_; int32_t zp_; } QuantArg; typedef struct ConvQuantArg { - QuantArg **quant_args_; + QuantArg 
*input_quant_args_; + QuantArg *filter_quant_args_; + QuantArg *output_quant_args_; double *real_multiplier_; int32_t *left_shift_; int32_t *right_shift_; int32_t *quant_multiplier_; int32_t *out_act_min_; int32_t *out_act_max_; + size_t input_arg_num_; + size_t filter_arg_num_; + size_t output_arg_num_; + uint8_t asymmetric_; + uint8_t per_channel_; } ConvQuantArg; typedef struct ConcatQuantArg { diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.c index 85d4f7e95d..87e330e446 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.c @@ -854,7 +854,7 @@ void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int pad_w = conv_param->pad_w_; int pad_h = conv_param->pad_h_; ConvQuantArg quant_arg = conv_param->conv_quant_arg_; - int input_zp = quant_arg.quant_args_[0][0].zp_; + int input_zp = quant_arg.input_quant_args_[0].zp_; int ic8 = UP_DIV(input_channel, C8NUM); int input_unit = 4; @@ -1155,11 +1155,11 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh } void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, - bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param) { - int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; - int right_shift = conv_param->conv_quant_arg_.right_shift_[0]; - int quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; - int output_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int output_zp = 
conv_param->conv_quant_arg_.output_quant_args_[0].zp_; int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; @@ -1202,12 +1202,21 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); - int32x4_t out_multiplier = vdupq_n_s32(quant_multiplier); + int32x4_t out_multiplier; + int32x4_t ls; + int32x4_t rs; + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + out_multiplier = vld1q_s32(quant_multiplier); + ls = vld1q_s32(left_shift); + rs = vld1q_s32(right_shift); + } else { + out_multiplier = vdupq_n_s32(quant_multiplier); + ls = vdupq_n_s32(left_shift); + rs = vdupq_n_s32(right_shift); + } int32x4_t out_zp = vdupq_n_s32(output_zp); int32x4_t output_min = vdupq_n_s32(out_min); int32x4_t output_max = vdupq_n_s32(out_max); - int32x4_t ls = vdupq_n_s32(left_shift); - int32x4_t rs = vdupq_n_s32(right_shift); d00 = vqshlq_s32(d00, ls); d00 = vqrdmulhq_s32(d00, out_multiplier); @@ -1261,78 +1270,166 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i } } #else - for (int i = 0; i < C4NUM; i++) { - const int32_t *local_ptr = gemm_out + i; - const int32_t *bias_ptr = bias_data + i; - - int32_t s00 = local_ptr[0]; - int32_t s01 = (local_ptr + 4)[0]; - int32_t s02 = (local_ptr + 8)[0]; - int32_t s03 = (local_ptr + 12)[0]; - - int32_t s10 = (local_ptr + 16)[0]; - int32_t s11 = (local_ptr + 20)[0]; - int32_t s12 = (local_ptr + 24)[0]; - int32_t s13 = (local_ptr + 28)[0]; - - int32_t s20 = (local_ptr + 32)[0]; - int32_t s21 = (local_ptr + 36)[0]; - int32_t s22 = (local_ptr + 40)[0]; - int32_t s23 = (local_ptr + 44)[0]; - - int32_t s30 = (local_ptr + 48)[0]; - int32_t s31 = (local_ptr + 52)[0]; - int32_t s32 = (local_ptr + 56)[0]; - int32_t s33 = (local_ptr + 
60)[0]; - - int32_t t00 = (s00 + s10 + s20) / 2; - int32_t t01 = (s01 + s11 + s21) / 2; - int32_t t02 = (s02 + s12 + s22) / 2; - int32_t t03 = (s03 + s13 + s23) / 2; - - int32_t t10 = (s10 - s20 - s30) / 2; - int32_t t11 = (s11 - s21 - s31) / 2; - int32_t t12 = (s12 - s22 - s32) / 2; - int32_t t13 = (s13 - s23 - s33) / 2; - - int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; - int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; - - int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; - int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; - - d00 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); - d00 += output_zp; - d00 = d00 > out_min ? d00 : out_min; - d00 = d00 < out_max ? d00 : out_max; - - d01 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); - d01 += output_zp; - d01 = d01 > out_min ? d01 : out_min; - d01 = d01 < out_max ? d01 : out_max; - - d10 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); - d10 += output_zp; - d10 = d10 > out_min ? d10 : out_min; - d10 = d10 < out_max ? d10 : out_max; - - d11 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); - d11 += output_zp; - d11 = d11 > out_min ? d11 : out_min; - d11 = d11 < out_max ? 
d11 : out_max; - - (output_data + i)[0] = (int8_t)d00; - if (w_not_bound) { - (output_data + i + C4NUM)[0] = (int8_t)d01; + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + int oc_index = oc_start + i; + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? 
d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } } - if (h_not_bound) { - (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + + } else { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 
+ bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; if (w_not_bound) { - (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } } } } @@ -1364,7 +1461,8 @@ void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, cons int real_num = (output_channel - j * C4NUM) < C4NUM ? 
(output_channel - j * C4NUM) : C4NUM; bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; - Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, conv_param); + Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, + conv_param); } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h index 173dc30889..31db84742c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h @@ -65,7 +65,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh int kernel_plane); void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, - bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param); + bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param); void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);