diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake index e9ed0bac18..f37f19d1ad 100644 --- a/mindspore/lite/micro/cmake/file_list.cmake +++ b/mindspore/lite/micro/cmake/file_list.cmake @@ -97,6 +97,7 @@ set(CODER_OPCODERS_SRC ${MICRO_DIR}/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/int8/concat_int8_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc + ${MICRO_DIR}/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/int8/matmul_int8_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc diff --git a/mindspore/lite/micro/cmake/package_nnacl.cmake b/mindspore/lite/micro/cmake/package_nnacl.cmake index f3f26fc46e..ee24e52449 100644 --- a/mindspore/lite/micro/cmake/package_nnacl.cmake +++ b/mindspore/lite/micro/cmake/package_nnacl.cmake @@ -9,6 +9,8 @@ file(GLOB KERNEL_SRC if(MICRO_BUILD_ARM64) file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm64/*.S) + file(GLOB OPT_SRC ${NNACL_DIR}/assembly/opt/*.S) + list(APPEND ASSEMBLY_SRC ${OPT_SRC}) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) endif() diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc index 1f80171690..9d6a27c2aa 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc @@ -127,7 +127,8 @@ int Conv2D3x3Int8Coder::Prepare(CoderContext *const context) { } int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/int8/conv_int8.h"}, {"pack_int8.c", "conv_int8.c", "fixed_point.c"}); + Collect(context, {"nnacl/int8/conv_int8.h", "nnacl/int8/conv3x3_int8.h"}, + {"pack_int8.c", "conv_int8.c", "conv3x3_int8.c", "fixed_point.c"}); nnacl::NNaclInt8Serializer code; code.precision(kPrecision); // call the op function diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc index e1059b886e..e6c7913b5a 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc @@ -16,133 +16,27 @@ #include "coder/opcoders/nnacl/int8/fullconnection_int8_coder.h" #include "nnacl/int8/matmul_int8.h" -#include "coder/opcoders/file_collector.h" #include "coder/log.h" -#include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" using mindspore::schema::PrimitiveType_FullConnection; namespace mindspore::lite::micro::nnacl { - -FullConnectionInt8Coder ::~FullConnectionInt8Coder() { - FreeQuantParam(); - filter_tensor_ = nullptr; - bias_tensor_ = nullptr; - pack_a_ptr_ = nullptr; - pack_b_ptr_ = nullptr; - input_sums_ = nullptr; - weight_bias_sums_ = nullptr; - bias_ptr_ = nullptr; -} - -int FullConnectionInt8Coder::MallocQuantParam() { - filter_tensor_ = input_tensors_.at(kWeightIndex); - MS_CHECK_PTR(filter_tensor_); - std::vector weight_quant_params = filter_tensor_->quant_params(); - MS_CHECK_TRUE(!filter_tensor_->shape().empty(), "filter tensor shape is empty"); - int col = filter_tensor_->shape().front(); - filter_per_channel_ = (weight_quant_params.size() > 1); - init_size_ = filter_per_channel_ ? 
col : 1; - quant_.filter_scale_ = reinterpret_cast(malloc(init_size_ * sizeof(float))); - MS_CHECK_PTR(quant_.filter_scale_); - quant_.filter_zp_ = reinterpret_cast(malloc(init_size_ * sizeof(int32_t))); - MS_CHECK_PTR(quant_.filter_zp_); - quant_.left_shift_ = reinterpret_cast(malloc(init_size_ * sizeof(int32_t))); - MS_CHECK_PTR(quant_.left_shift_); - quant_.right_shift_ = reinterpret_cast(malloc(init_size_ * sizeof(int32_t))); - MS_CHECK_PTR(quant_.right_shift_); - quant_.quant_multiplier_ = reinterpret_cast(malloc(init_size_ * sizeof(int32_t))); - MS_CHECK_PTR(quant_.quant_multiplier_); - return RET_OK; -} - -void FullConnectionInt8Coder::FreeQuantParam() { - if (quant_.filter_scale_ != nullptr) { - free(quant_.filter_scale_); - quant_.filter_scale_ = nullptr; - } - if (quant_.filter_zp_ != nullptr) { - free(quant_.filter_zp_); - quant_.filter_zp_ = nullptr; - } - if (quant_.left_shift_ != nullptr) { - free(quant_.left_shift_); - quant_.left_shift_ = nullptr; - } - if (quant_.right_shift_ != nullptr) { - free(quant_.right_shift_); - quant_.right_shift_ = nullptr; - } - if (quant_.quant_multiplier_ != nullptr) { - free(quant_.quant_multiplier_); - quant_.quant_multiplier_ = nullptr; - } -} - -void FullConnectionInt8Coder::InitParam() { +int FullConnectionInt8Coder::ReSize(CoderContext *const context) { int row = 1; - int out_put_tensor_size = static_cast(output_tensor_->shape().size()); - for (int i = 0; i < out_put_tensor_size - 1; ++i) { + for (size_t i = 0; i < output_tensor_->shape().size() - 1; ++i) { row *= (output_tensor_->shape()).at(i); } - fc_param_->row_ = row; - fc_param_->col_ = output_tensor_->shape().back(); - fc_param_->deep_ = filter_tensor_->shape().at(1); - fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); - fc_param_->col_4_ = UP_ROUND(fc_param_->col_, C4NUM); - fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM); - fc_param_->deep_16_ = UP_ROUND(fc_param_->deep_, C16NUM); - thread_count_ = MSMIN(thread_num_, UP_DIV(fc_param_->col_4_, C4NUM)); - thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_4_, C4NUM), thread_count_); -} - -int FullConnectionInt8Coder::ReSize(CoderContext *const context) { - InitParam(); - pack_a_ptr_size_ = static_cast(fc_param_->row_4_ * fc_param_->deep_16_ * sizeof(int8_t)); - pack_a_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, pack_a_ptr_size_, kOfflinePackWeight)); - MS_CHECK_PTR(pack_a_ptr_); - pack_b_ptr_size_ = static_cast(fc_param_->col_4_ * fc_param_->deep_16_ * sizeof(int8_t)); - weight_bias_sums_size_ = static_cast(fc_param_->col_4_ * sizeof(int)); - if (fc_param_->b_const_) { - pack_b_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, kOnlineSize, kOnlinePackWeight)); - MS_CHECK_PTR(pack_b_ptr_); - weight_bias_sums_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); - } else { - pack_b_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, pack_b_ptr_size_, kOfflinePackWeight)); - MS_CHECK_PTR(pack_b_ptr_); - weight_bias_sums_ = - reinterpret_cast(allocator_->Malloc(kNumberTypeInt, weight_bias_sums_size_, kOfflinePackWeight)); - } - MS_CHECK_PTR(weight_bias_sums_); - input_sums_size_ = static_cast(fc_param_->row_4_ * sizeof(int)); - input_sums_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt, input_sums_size_, kOfflinePackWeight)); - MS_CHECK_PTR(input_sums_); - if (input_tensors_.size() == kInputSize2) { - bias_ptr_size_ = static_cast(fc_param_->col_4_ * sizeof(int)); - bias_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt, kOnlineSize, 
kOnlinePackWeight)); - MS_CHECK_PTR(bias_ptr_); - } else { - bias_ptr_ = nullptr; - } - NNaclInt8Serializer init_code; - if (input_tensors_.size() == kInputSize2) { - init_code.CodeMallocExpression(bias_ptr_, bias_ptr_size_); - init_code.CodeFunction("memset", bias_ptr_, 0, bias_ptr_size_); - init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, bias_ptr_); - } - if (fc_param_->b_const_) { - init_code.CodeMallocExpression(pack_b_ptr_, pack_b_ptr_size_); - init_code.CodeMallocExpression(weight_bias_sums_, weight_bias_sums_size_); - init_code.CodeFunction("RowMajor2Row16x4MajorInt8", filter_tensor_, pack_b_ptr_, fc_param_->col_, fc_param_->deep_); - init_code.CodeFunction("CalcWeightBiasSums", filter_tensor_, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, - quant_.filter_zp_, bias_ptr_, weight_bias_sums_, ColMajor, filter_per_channel_); - } - context->AppendInitCode(init_code.str()); + param_->row_ = row; + param_->col_ = output_tensor_->shape().back(); + param_->deep_ = (filter_tensor_->shape()).at(1); + MS_CHECK_RET_CODE(MatMulBaseInt8Coder::ReSize(context), "MatMulBaseInt8Coder::ReSize is nullptr"); return RET_OK; } -int FullConnectionInt8Coder::Init() { - fc_param_ = reinterpret_cast(parameter_); +int FullConnectionInt8Coder::Prepare(CoderContext *const context) { + // only support one thread currently + thread_count_ = thread_num_; + param_ = reinterpret_cast(parameter_); filter_tensor_ = input_tensors_.at(kWeightIndex); MS_CHECK_PTR(filter_tensor_); if (input_tensors_.size() == kInputSize2) { @@ -150,91 +44,16 @@ int FullConnectionInt8Coder::Init() { MS_CHECK_PTR(bias_tensor_); MS_CHECK_PTR(bias_tensor_->data_c()); } - fc_param_->a_const_ = (input_tensor_->data_c() != nullptr); - fc_param_->b_const_ = (filter_tensor_->data_c() != nullptr); - int ret = MallocQuantParam(); - if (ret != RET_OK) { - FreeQuantParam(); - return ret; - } - std::vector in_quant_params = input_tensor_->quant_params(); - MS_CHECK_TRUE(!in_quant_params.empty(), "in_quant_params empty is empty"); - quant_.input_.zp_ = in_quant_params.front().zeroPoint; - quant_.input_.scale_ = static_cast(in_quant_params.front().scale); - std::vector out_quant_params = output_tensor_->quant_params(); - MS_CHECK_TRUE(!out_quant_params.empty(), "out_quant_params empty is empty"); - quant_.output_.zp_ = out_quant_params.front().zeroPoint; - quant_.output_.scale_ = static_cast(out_quant_params.front().scale); - - int weight_quant_num = filter_per_channel_ ? 
static_cast(filter_tensor_->shape().front()) : 1; - std::vector weight_quant_params = filter_tensor_->quant_params(); - MS_CHECK_TRUE(!weight_quant_params.empty(), "weight_quant_params empty is empty"); - for (int i = 0; i < weight_quant_num; i++) { - quant_.filter_zp_[i] = weight_quant_params[i].zeroPoint; - quant_.filter_scale_[i] = static_cast(weight_quant_params[i].scale); - } - - for (int i = 0; i < weight_quant_num; ++i) { - auto in_scale = static_cast(quant_.input_.scale_ * quant_.filter_scale_[i]); - double real_multiplier = in_scale / static_cast(quant_.output_.scale_); - QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i], - &quant_.right_shift_[i]); - } - CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, - quant_.output_.zp_, quant_.output_.scale_, &quant_.out_act_min_, - &quant_.out_act_max_); - return RET_OK; -} - -int FullConnectionInt8Coder::Prepare(CoderContext *const context) { - // only support one thread currently - thread_count_ = thread_num_; - MS_CHECK_RET_CODE(Init(), "FullConnectionInt8Coder init failed"); + param_->batch = 1; + param_->a_transpose_ = false; + param_->b_transpose_ = true; + MatMulBaseInt8Coder::InitParameter(); + MS_CHECK_RET_CODE(MatMulBaseInt8Coder::Init(), "Init failed"); return ReSize(context); } int FullConnectionInt8Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/common_func.h", "nnacl/int8/common_func_int8.h", "nnacl/int8/matmul_int8.h"}, - {"common_func.c", "common_func_int8.c", "matmul_int8.c"}); - - NNaclInt8Serializer code; - code.precision(kPrecision); - code.CodeFunction("memset", input_sums_, 0, input_sums_size_); - code.CodeFunction("memset", pack_a_ptr_, 0, pack_a_ptr_size_); - code.CodeFunction("RowMajor2Row16x4MajorInt8", input_tensor_, pack_a_ptr_, fc_param_->row_, fc_param_->deep_); - int32_t tmp_weight_zp = filter_per_channel_ ? 
1 : quant_.filter_zp_[0]; - code.CodeFunction("CalcInputSums", input_tensor_, fc_param_->row_, fc_param_->deep_, tmp_weight_zp, input_sums_, - RowMajor); - - if (!fc_param_->b_const_) { - code.CodeFunction("memset", pack_b_ptr_, 0, pack_b_ptr_size_); - code.CodeFunction("memset", weight_bias_sums_, 0, weight_bias_sums_size_); - code.CodeFunction("RowMajor2Row16x4MajorInt8", filter_tensor_, pack_b_ptr_, fc_param_->col_, fc_param_->deep_); - code.CodeFunction("CalcWeightBiasSums", filter_tensor_, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, - quant_.filter_zp_, bias_ptr_, weight_bias_sums_, ColMajor, filter_per_channel_); - } - int stride = thread_stride_ * C4NUM; - int res_stride = fc_param_->col_; - int cur_oc = MSMIN(stride, res_stride); - if (cur_oc <= 0) { - return RET_OK; - } - int32_t *cur_left = quant_.left_shift_; - int32_t *cur_right = quant_.right_shift_; - int32_t *cur_mul = quant_.quant_multiplier_; - int32_t *cur_zp = quant_.filter_zp_; - - code.CodeArray("cur_left_shift", cur_left, init_size_, true); - code.CodeArray("cur_right_shift", cur_right, init_size_, true); - code.CodeArray("cur_multiplier", cur_mul, init_size_, true); - code.CodeArray("cur_filter_zp", cur_zp, init_size_, true); - - code.CodeFunction("MatmulInt8Opt", pack_a_ptr_, pack_b_ptr_, output_tensor_->data_c(), fc_param_->row_, cur_oc, - fc_param_->deep_16_, input_sums_, weight_bias_sums_, quant_.out_act_min_, quant_.out_act_max_, - quant_.output_.zp_, "&cur_multiplier", "&cur_left_shift", "&cur_right_shift", fc_param_->col_, - filter_per_channel_, "&cur_filter_zp"); - MS_LOG(DEBUG) << "FullConnectionInt8Coder has been called"; - context->AppendCode(code.str()); + MS_CHECK_RET_CODE(MatMulBaseInt8Coder::DoCode(context), "matmul int8 do code failed"); return RET_OK; } diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h b/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h index 3dfcfd51b5..dd12382f20 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h @@ -20,49 +20,25 @@ #include #include #include -#include "coder/opcoders/op_coder.h" +#include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" #include "nnacl/int8/quantize.h" #include "nnacl/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { -class FullConnectionInt8Coder final : public OperatorCoder { +class FullConnectionInt8Coder final : public MatMulBaseInt8Coder { public: FullConnectionInt8Coder(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, size_t node_index, Target target) - : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} + : MatMulBaseInt8Coder(in_tensors, out_tensors, node, node_index, target) {} - ~FullConnectionInt8Coder() override; + ~FullConnectionInt8Coder() override = default; int Prepare(CoderContext *const context) override; int DoCode(CoderContext *const context) override; private: - int Init(); - int ReSize(CoderContext *const context); - int MallocQuantParam(); - void FreeQuantParam(); - void InitParam(); - - private: - MatmulQuantParameter quant_{0}; - MatMulParameter *fc_param_{nullptr}; - Tensor *filter_tensor_{nullptr}; - Tensor *bias_tensor_{nullptr}; - size_t pack_a_ptr_size_{0}; - int8_t *pack_a_ptr_ = nullptr; - size_t pack_b_ptr_size_{0}; - int8_t *pack_b_ptr_{nullptr}; - size_t input_sums_size_{0}; - int *input_sums_{nullptr}; - size_t weight_bias_sums_size_{0}; - int 
*weight_bias_sums_{nullptr}; - size_t bias_ptr_size_{0}; - int *bias_ptr_{nullptr}; - int thread_count_{1}; - int thread_stride_{0}; - bool filter_per_channel_{true}; - int init_size_{0}; + int ReSize(CoderContext *const context) override; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_FULLCONNECTION_INT8_CODER_H_ diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc new file mode 100644 index 0000000000..27c04a6634 --- /dev/null +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc @@ -0,0 +1,255 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" +#include +#include +#include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" +#include "coder/opcoders/file_collector.h" +namespace mindspore::lite::micro::nnacl { + +int MatMulBaseInt8Coder::ReSize(CoderContext *const context) { + ResizeParameter(); + if (InitTmpBuffer() != RET_OK) { + FreeQuantParam(); + return RET_ERROR; + } + return RET_OK; +} + +int MatMulBaseInt8Coder::InitTmpBuffer() { + a_pack_ptr_size_ = param_->row_align_ * param_->deep_16_ * sizeof(int8_t); + pack_a_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, a_pack_ptr_size_, kWorkspace)); + MS_CHECK_PTR(pack_a_ptr_); + b_pack_ptr_size_ = param_->batch * param_->col_align_ * param_->deep_16_ * sizeof(int8_t); + if (param_->b_const_) { + pack_b_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, kOnlineSize, kOnlinePackWeight)); + } else { + pack_b_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, b_pack_ptr_size_, kWorkspace)); + } + MS_CHECK_PTR(pack_b_ptr_); + input_sums_size_ = static_cast(param_->row_align_ * sizeof(int)); + input_sums_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt32, input_sums_size_, kWorkspace)); + MS_CHECK_PTR(input_sums_); + weight_bias_sums_size_ = static_cast(param_->batch * param_->col_align_ * sizeof(int)); + if (param_->b_const_) { + weight_bias_sums_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt32, kOnlineSize, kOnlinePackWeight)); + } else { + weight_bias_sums_ = + reinterpret_cast(allocator_->Malloc(kNumberTypeInt32, weight_bias_sums_size_, kWorkspace)); + } + MS_CHECK_PTR(weight_bias_sums_); + return RET_OK; +} + +MatMulBaseInt8Coder::~MatMulBaseInt8Coder() { FreeQuantParam(); } + +void MatMulBaseInt8Coder::ResizeParameter() { + param_->row_align_ = UP_ROUND(param_->row_, C4NUM); + param_->col_align_ = UP_ROUND(param_->col_, C4NUM); + param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM); + thread_count_ = MSMIN(param_->op_parameter_.thread_num_, UP_DIV(param_->col_align_, C4NUM)); + thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_); +} + +void MatMulBaseInt8Coder::FreeQuantParam() { + if (quant_.filter_scale_ != nullptr) { + 
free(quant_.filter_scale_); + quant_.filter_scale_ = nullptr; + } + if (quant_.filter_zp_ != nullptr) { + free(quant_.filter_zp_); + quant_.filter_zp_ = nullptr; + } + if (quant_.left_shift_ != nullptr) { + free(quant_.left_shift_); + quant_.left_shift_ = nullptr; + } + if (quant_.right_shift_ != nullptr) { + free(quant_.right_shift_); + quant_.right_shift_ = nullptr; + } + if (quant_.quant_multiplier_ != nullptr) { + free(quant_.quant_multiplier_); + quant_.quant_multiplier_ = nullptr; + } +} + +int MatMulBaseInt8Coder::MallocQuantParam() { + std::vector weight_quant_params = filter_tensor_->quant_params(); + int col = filter_tensor_->shape().front(); + filter_per_channel_ = (weight_quant_params.size() > 1); + weight_quant_num_ = filter_per_channel_ ? col : 1; + quant_.filter_scale_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(float))); + MS_CHECK_PTR(quant_.filter_scale_); + quant_.filter_zp_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(int32_t))); + MS_CHECK_PTR(quant_.filter_zp_); + quant_.left_shift_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(int32_t))); + MS_CHECK_PTR(quant_.left_shift_); + quant_.right_shift_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(int32_t))); + MS_CHECK_PTR(quant_.right_shift_); + quant_.quant_multiplier_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(int32_t))); + MS_CHECK_PTR(quant_.quant_multiplier_); + return RET_OK; +} + +int MatMulBaseInt8Coder::InitQuantParam() { + std::vector in_quant_params = input_tensor_->quant_params(); + MS_CHECK_TRUE(!in_quant_params.empty(), "in_quant_params is empty"); + quant_.input_.zp_ = in_quant_params.front().zeroPoint; + quant_.input_.scale_ = static_cast(in_quant_params.front().scale); + + std::vector out_quant_params = output_tensor_->quant_params(); + MS_CHECK_TRUE(!out_quant_params.empty(), "out_quant_params is empty"); + quant_.output_.zp_ = out_quant_params.front().zeroPoint; + quant_.output_.scale_ = static_cast(out_quant_params.front().scale); + std::vector weight_quant_params = filter_tensor_->quant_params(); + for (int i = 0; i < weight_quant_num_; i++) { + quant_.filter_zp_[i] = weight_quant_params[i].zeroPoint; + quant_.filter_scale_[i] = static_cast(weight_quant_params[i].scale); + } + + for (int i = 0; i < weight_quant_num_; ++i) { + const auto in_scale = static_cast(quant_.input_.scale_ * quant_.filter_scale_[i]); + double real_multiplier = in_scale / static_cast(quant_.output_.scale_); + QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i], + &quant_.right_shift_[i]); + } + + CalculateActivationRangeQuantized(param_->act_type_ == ActType_Relu, param_->act_type_ == ActType_Relu6, + quant_.output_.zp_, quant_.output_.scale_, &quant_.out_act_min_, + &quant_.out_act_max_); + return RET_OK; +} + +void MatMulBaseInt8Coder::InitParameter() { + param_->a_const_ = (input_tensor_ != nullptr); + param_->b_const_ = (filter_tensor_ != nullptr); +} + +int MatMulBaseInt8Coder::InitBias() { + if (bias_tensor_ != nullptr) { + int max_bias_data_elements = UP_ROUND(bias_tensor_->ElementsNum(), C4NUM); + // to pack to init + bias_ptr_size_ = static_cast(max_bias_data_elements * sizeof(int)); + bias_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt32, kOnlineSize, kOnlinePackWeight)); + MS_CHECK_PTR(bias_ptr_); + } + return RET_OK; +} + +int MatMulBaseInt8Coder::Init() { + if (MallocQuantParam() != RET_OK) { + FreeQuantParam(); + return RET_ERROR; + } + MS_CHECK_RET_CODE(InitQuantParam(), "matmul int8 init quant_param 
failed"); + if (InitBias() != RET_OK) { + FreeQuantParam(); + return RET_ERROR; + } + return RET_OK; +} + +int MatMulBaseInt8Coder::Prepare(CoderContext *const context) { return RET_OK; } + +int MatMulBaseInt8Coder::DoCode(CoderContext *const context) { + Collect(context, + {"nnacl/common_func.h", "nnacl/int8/common_func_int8.h", "nnacl/int8/matmul_int8.h", + "wrapper/int8/matmul_int8_wrapper.h"}, + {"common_func.c", "common_func_int8.c", "matmul_int8.c", "matmul_int8_wrapper.c"}); + std::string value_str_end = ";\n"; + NNaclInt8Serializer init_code; + NNaclInt8Serializer code; + if (bias_ptr_) { + init_code.CodeMallocExpression(bias_ptr_, bias_ptr_size_); + init_code.CodeFunction("memset", bias_ptr_, 0, bias_ptr_size_); + init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, bias_ptr_size_); + } + if (param_->b_const_) { + init_code.CodeMallocExpression(weight_bias_sums_, weight_bias_sums_size_); + init_code.CodeFunction("memset", weight_bias_sums_, 0, weight_bias_sums_size_); + init_code.CodeMallocExpression(pack_b_ptr_, b_pack_ptr_size_); + init_code.CodeFunction("memset", pack_b_ptr_, 0, b_pack_ptr_size_); + init_code.CodeArray("init_filter_zp", quant_.filter_zp_, weight_quant_num_); + init_code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, + param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, + "init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_); + } else { + code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, param_->deep_, + param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, "init_filter_zp", + bias_ptr_, param_->b_transpose_, filter_per_channel_); + } + int task_id = 0; + std::string a_ptr_str = allocator_->GetRuntimeAddr(input_tensor_); + std::string c_ptr_str = allocator_->GetRuntimeAddr(output_tensor_); + std::string pack_b_ptr_str = allocator_->GetRuntimeAddr(pack_b_ptr_); + std::string weight_bias_sums_str = allocator_->GetRuntimeAddr(weight_bias_sums_); + code.precision(kPrecision); + std::string tmp_weight_zp = + "int32_t tmp_weight_zp = " + (filter_per_channel_ ? 
std::to_string(1) : std::to_string(quant_.filter_zp_[0])); + code << tmp_weight_zp << value_str_end; + for (int i = 0; i < param_->batch; i++) { + std::string current_src_a = a_ptr_str + "+" + std::to_string(i * param_->row_ * param_->deep_); + if (param_->a_transpose_) { + code.CodeFunction("RowMajor2Col16x4MajorInt8", current_src_a, param_->deep_, param_->row_, pack_a_ptr_); + code.CodeFunction("CalcInputSums", current_src_a, param_->row_, param_->deep_, "tmp_weight_zp", input_sums_, + ColMajor); + } else { + code.CodeFunction("RowMajor2Row16x4MajorInt8", current_src_a, pack_a_ptr_, param_->row_, param_->deep_); + code.CodeFunction("CalcInputSums", current_src_a, param_->row_, param_->deep_, "tmp_weight_zp", input_sums_, + RowMajor); + } + std::string batch_b_ptr_str = pack_b_ptr_str + "+" + std::to_string(i * param_->col_align_ * param_->deep_16_); + std::string batch_c_ptr_str = c_ptr_str + "+" + std::to_string(i * param_->row_ * param_->col_); + + int stride = thread_stride_ * C4NUM; + int cur_stride = task_id * stride; + int res_stride = param_->col_ - cur_stride; + int cur_oc = MSMIN(stride, res_stride); + if (cur_oc <= 0) { + return RET_OK; + } + code.CodeStruct("matmul_quant_parameter", quant_, weight_quant_num_); + std::string cur_left = "int32_t *cur_left = matmul_quant_parameter.left_shift_"; + std::string cur_right = "int32_t *cur_right = matmul_quant_parameter.right_shift_"; + std::string cur_mul = "int32_t *cur_mul = matmul_quant_parameter.quant_multiplier_ "; + std::string cur_zp = "int32_t *cur_zp = matmul_quant_parameter.filter_zp_ "; + if (filter_per_channel_) { + code << cur_left << " + " << cur_stride << value_str_end; + code << cur_right << " + " << cur_stride << value_str_end; + code << cur_mul << " + " << cur_stride << value_str_end; + code << cur_zp << " + " << cur_stride << value_str_end; + } else { + code << cur_left << value_str_end; + code << cur_right << value_str_end; + code << cur_mul << value_str_end; + code << cur_zp << value_str_end; + } + std::string batch_b_ptr_str_final = batch_b_ptr_str + " + " + std::to_string(cur_stride * param_->deep_16_); + std::string batch_c_ptr_final = batch_c_ptr_str + "+" + std::to_string(cur_stride); + std::string weight_bias_sums_str_final = weight_bias_sums_str + "+" + std::to_string(cur_stride); + code.CodeFunction("MatmulInt8Opt", pack_a_ptr_, batch_b_ptr_str_final, batch_c_ptr_final, param_->row_, cur_oc, + param_->deep_16_, input_sums_, weight_bias_sums_str_final, quant_.out_act_min_, + quant_.out_act_max_, quant_.output_.zp_, "cur_mul", "cur_left", "cur_right", param_->col_, + filter_per_channel_, "cur_zp"); + } + MS_LOG(DEBUG) << "FullConnectionInt8Coder has been called"; + context->AppendInitCode(init_code.str()); + context->AppendCode(code.str()); + return RET_OK; +} +} // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h new file mode 100644 index 0000000000..5f0dbd09de --- /dev/null +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ +#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ +#include +#include "coder/opcoders/op_coder.h" +#include "nnacl/matmul_parameter.h" +namespace mindspore::lite::micro::nnacl { +class MatMulBaseInt8Coder : public OperatorCoder { + public: + MatMulBaseInt8Coder(const std::vector &in_tensors, const std::vector &out_tensors, + const Model::Node *node, size_t node_index, Target target) + : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} + ~MatMulBaseInt8Coder() override; + + int Prepare(CoderContext *const context) override; + + int DoCode(CoderContext *const context) override; + + protected: + int Init(); + void InitParameter(); + virtual int ReSize(CoderContext *const context); + + private: + void ResizeParameter(); + int MallocQuantParam(); + void FreeQuantParam(); + int InitQuantParam(); + int InitBias(); + int InitTmpBuffer(); + + protected: + bool filter_per_channel_{true}; + int thread_count_{1}; + int thread_stride_{0}; + MatmulQuantParameter quant_{}; + Tensor *filter_tensor_{nullptr}; + Tensor *bias_tensor_{nullptr}; + MatMulParameter *param_{nullptr}; + size_t a_pack_ptr_size_{0}; + size_t b_pack_ptr_size_{0}; + int8_t *pack_a_ptr_{nullptr}; + int8_t *pack_b_ptr_{nullptr}; + size_t bias_ptr_size_{0}; + int *bias_ptr_{nullptr}; + size_t input_sums_size_{0}; + int *input_sums_{nullptr}; + size_t weight_bias_sums_size_{0}; + int *weight_bias_sums_{nullptr}; + + private: + int weight_quant_num_{0}; +}; +} // namespace mindspore::lite::micro::nnacl +#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.cc index 3e49164b42..d3bc619591 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.cc @@ -15,117 +15,10 @@ */ #include "coder/opcoders/nnacl/int8/matmul_int8_coder.h" -#include -#include -#include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" -#include "coder/opcoders/file_collector.h" +#include "coder/opcoders/op_coder.h" +using mindspore::schema::PrimitiveType_MatMul; namespace mindspore::lite::micro::nnacl { - -int MatMulInt8Coder::ReSize(CoderContext *const context) { - int batch = 1; - std::vector x_shape = input_tensor_->shape(); - std::vector o_shape = output_tensor_->shape(); - if (x_shape.size() <= 2 || o_shape.size() <= 2) { - MS_LOG(ERROR) << "x_shape.size() or o_shape.size() is less than two"; - return RET_ERROR; - } - for (size_t i = 0; i < x_shape.size() - 2; ++i) { - batch *= x_shape.at(i); - } - params_->batch = batch; - params_->row_ = o_shape.at(o_shape.size() - 2); - params_->col_ = o_shape.at(o_shape.size() - 1); - params_->deep_ = params_->a_transpose_ ? 
x_shape.at(x_shape.size() - 2) : x_shape.at(x_shape.size() - 1); - params_->row_4_ = UP_ROUND(params_->row_, C4NUM); - params_->col_4_ = UP_ROUND(params_->col_, C4NUM); - params_->deep_16_ = UP_ROUND(params_->deep_, C16NUM); - - a_pack_ptr_size_ = static_cast(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)); - a_pack_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, a_pack_ptr_size_, kOfflinePackWeight)); - MS_CHECK_PTR(a_pack_ptr_); - input_sums_size_ = static_cast(params_->row_4_ * sizeof(int)); - input_sums_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt, input_sums_size_, kOfflinePackWeight)); - MS_CHECK_PTR(input_sums_); - b_pack_batch_ptr_size_ = static_cast(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)); - if (params_->b_const_) { - b_pack_batch_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, kOnlineSize, kOnlinePackWeight)); - MS_CHECK_PTR(b_pack_batch_ptr_); - weight_bias_sums_batch_ = - reinterpret_cast(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); - } else { - b_pack_batch_ptr_ = - reinterpret_cast(allocator_->Malloc(kNumberTypeInt8, b_pack_batch_ptr_size_, kOfflinePackWeight)); - MS_CHECK_PTR(b_pack_batch_ptr_); - weight_bias_sums_batch_ = - reinterpret_cast(allocator_->Malloc(kNumberTypeInt, weight_bias_sums_batch_size_, kOfflinePackWeight)); - } - MS_CHECK_PTR(weight_bias_sums_batch_); - if (input_tensors_.size() == 3) { - bias_prt_size_ = static_cast(params_->col_4_ * sizeof(int)); - bias_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); - MS_CHECK_PTR(bias_ptr_); - } else { - bias_ptr_ = nullptr; - } - thread_count_ = MSMIN(thread_num_, UP_DIV(params_->col_4_, C4NUM)); - thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, C4NUM), thread_count_); - - std::vector params = input_tensor_->quant_params(); - MS_CHECK_TRUE(!params.empty(), "params is empty"); - quant_params_.input_.zp_ = params.front().zeroPoint; - quant_params_.input_.scale_ = static_cast(params.front().scale); - - params = filter_tensor_->quant_params(); - MS_CHECK_TRUE(!params.empty(), "params is empty"); - quant_params_.weight_.zp_ = params.front().zeroPoint; - quant_params_.weight_.scale_ = static_cast(params.front().scale); - - params = output_tensor_->quant_params(); - MS_CHECK_TRUE(!params.empty(), "params is empty"); - quant_params_.output_.zp_ = params.front().zeroPoint; - quant_params_.output_.scale_ = static_cast(params.front().scale); - double real_multiplier = quant_params_.input_.scale_ * quant_params_.weight_.scale_ / quant_params_.output_.scale_; - QuantizeRoundParameterWithDoublePrecision(real_multiplier, quant_params_.quant_multiplier_, quant_params_.left_shift_, - quant_params_.right_shift_); - if (params_->b_const_) { - NNaclInt8Serializer init_code; - if (bias_ptr_) { - init_code.CodeMallocExpression(bias_ptr_, bias_prt_size_); - init_code.CodeFunction("memset", bias_ptr_, 0, bias_prt_size_); - init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_->data_c(), bias_prt_size_); - } - init_code.CodeMallocExpression(weight_bias_sums_batch_, weight_bias_sums_batch_size_); - init_code.CodeFunction("memset", weight_bias_sums_batch_, 0, weight_bias_sums_batch_size_); - init_code.CodeMallocExpression(b_pack_batch_ptr_, b_pack_batch_ptr_size_); - init_code.CodeFunction("memset", b_pack_batch_ptr_, 0, b_pack_batch_ptr_size_); - - init_code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n"; - init_code.CodeFunction("InitIn8MatrixB", filter_tensor_->data_c(), 
weight_bias_sums_batch_, b_pack_batch_ptr_, - params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_, - quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_); - context->AppendInitCode(init_code.str()); - } - return RET_OK; -} - -MatMulInt8Coder::~MatMulInt8Coder() { - if (quant_params_.quant_multiplier_ != nullptr) { - free(quant_params_.quant_multiplier_); - quant_params_.quant_multiplier_ = nullptr; - } - if (quant_params_.right_shift_ != nullptr) { - free(quant_params_.right_shift_); - quant_params_.right_shift_ = nullptr; - } - if (quant_params_.left_shift_ != nullptr) { - free(quant_params_.left_shift_); - quant_params_.left_shift_ = nullptr; - } - return; -} - -int MatMulInt8Coder::Init() { - params_ = reinterpret_cast(parameter_); +int MatMulInt8Coder::Prepare(CoderContext *const context) { filter_tensor_ = input_tensors_.at(kWeightIndex); MS_CHECK_PTR(filter_tensor_); if (input_tensors_.size() == kInputSize2) { @@ -133,85 +26,34 @@ int MatMulInt8Coder::Init() { MS_CHECK_PTR(bias_tensor_); MS_CHECK_PTR(bias_tensor_->data_c()); } - params_->b_const_ = (filter_tensor_->data_c() != nullptr); - - quant_params_.quant_multiplier_ = reinterpret_cast(malloc(1 * sizeof(int32_t))); - MS_CHECK_PTR(quant_params_.quant_multiplier_); - quant_params_.left_shift_ = reinterpret_cast(malloc(1 * sizeof(int32_t))); - MS_CHECK_PTR(quant_params_.left_shift_); - quant_params_.right_shift_ = reinterpret_cast(malloc(1 * sizeof(int32_t))); - MS_CHECK_PTR(quant_params_.right_shift_); - - return RET_OK; + param_ = reinterpret_cast(parameter_); + MatMulBaseInt8Coder::InitParameter(); + MS_CHECK_RET_CODE(MatMulBaseInt8Coder::Init(), "ParallelLaunch failed"); + return ReSize(context); } -int MatMulInt8Coder::Prepare(CoderContext *const context) { - MS_CHECK_RET_CODE(Init(), "MatMulInt8Coder Init failed"); - MS_CHECK_RET_CODE(ReSize(context), "MatMulInt8Coder ReSize failed"); +int MatMulInt8Coder::ReSize(CoderContext *const context) { + int batch = 1; + std::vector x_shape = input_tensor_->shape(); + std::vector o_shape = output_tensor_->shape(); + MS_CHECK_RET_CODE(x_shape.size() >= kBiasIndex, "x_shape size is less than two"); + for (size_t i = 0; i < x_shape.size() - kBiasIndex; ++i) { + batch *= x_shape[i]; + } + param_->batch = batch; + MS_CHECK_RET_CODE(o_shape.size() >= kBiasIndex, "o_shape size is less than two"); + param_->row_ = o_shape[o_shape.size() - kBiasIndex]; + param_->col_ = o_shape[o_shape.size() - kWeightIndex]; + param_->deep_ = param_->a_transpose_ ? 
x_shape[x_shape.size() - kBiasIndex] : x_shape[x_shape.size() - kWeightIndex]; + MS_CHECK_RET_CODE(MatMulBaseInt8Coder::ReSize(context), "MatMulBaseInt8Coder::ReSize is nullptr"); return RET_OK; } int MatMulInt8Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/common_func.h", "nnacl/int8/common_func_int8.h", "nnacl/int8/matmul_int8.h"}, - {"common_func.c", "common_func_int8.c", "matmul_int8.c"}); - - std::string a_ptr_str = allocator_->GetRuntimeAddr(input_tensor_); - std::string c_ptr_str = allocator_->GetRuntimeAddr(output_tensor_); - int a_stride = params_->row_ * params_->deep_; - int c_stride = params_->row_ * params_->col_; - - NNaclInt8Serializer code; - code.precision(kPrecision); - int task_id = 0; - int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, C4NUM) - task_id * thread_stride_); - if (cur_oc <= 0) { - return RET_OK; - } - code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n"; - if (!params_->b_const_) { - code.CodeFunction("InitIn8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_, - params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_, - quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_); - } - std::string b_batch_str = allocator_->GetRuntimeAddr(b_pack_batch_ptr_); - std::string weight_bias_sums_batch_str = allocator_->GetRuntimeAddr(weight_bias_sums_batch_); - code.CodeFunction("memset", input_sums_, 0, input_sums_size_); - code.CodeFunction("memset", a_pack_ptr_, 0, a_pack_ptr_size_); - code << "for (int i = 0; i < " << params_->batch << "; ++i) {\n"; - code << " int8_t* cur_a_ptr = " << a_ptr_str << " + i * " << a_stride << ";\n"; - if (params_->a_transpose_) { - code.CodeFunction("RowMajor2Col16x4MajorInt8", "cur_a_ptr", params_->deep_, params_->row_, a_pack_ptr_); - code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->deep_, quant_params_.weight_.zp_, input_sums_, ColMajor); - } else { - code.CodeFunction("RowMajor2Row16x4MajorInt8", "cur_a_ptr", a_pack_ptr_, params_->row_, params_->deep_); - code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight_.zp_, - input_sums_, RowMajor); - } - code << " b_pack_ptr_ = " << b_batch_str << " + i * " << params_->col_4_ * params_->deep_16_ << ";\n"; - code << " weight_bias_sums_ = " << weight_bias_sums_batch_str << " + i * " << params_->col_4_ << ";\n"; - code << " c_ptr_ = " << c_ptr_str << " + i * " << c_stride << ";\n"; - int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM); - - code << " int8_t* cur_b = b_pack_ptr_ + " << task_id * thread_stride_ * C4NUM * params_->deep_16_ << ";\n"; - code << " int32_t* cur_bias = weight_bias_sums_ + " << task_id * thread_stride_ * C4NUM << ";\n"; - code << " int8_t *cur_c = c_ptr_ + " << task_id * thread_stride_ * C4NUM << ";\n"; - code << " static const int left_shift = " << quant_params_.left_shift_[0] << ";\n"; - code << " static const int right_shift = " << quant_params_.right_shift_[0] << ";\n"; - code << " static const int quant_multiplier = " << quant_params_.quant_multiplier_[0] << ";\n"; - if (target_ == kARM64) { - code.CodeFunction("MatmulInt8Neon64", "cur_a_ptr", "cur_b", "cur_c", params_->row_4_, cur_oc * C4NUM, - params_->deep_16_, input_sums_, "cur_bias", INT8_MIN, INT8_MAX, quant_params_.output_.zp_, - "&quant_multiplier", "&left_shift", "&right_shift", params_->row_, cur_oc_res, params_->col_, - false); - } else { - code.CodeFunction("MatMulInt8_16x4_r", 
"cur_a_ptr", "cur_b", "cur_c", params_->row_, cur_oc_res, params_->deep_16_, - params_->col_, input_sums_, "cur_bias", "&left_shift", "&right_shift", "&quant_multiplier", - quant_params_.output_.zp_, INT8_MIN, INT8_MAX, false); - } - code << "}\n"; - MS_LOG(DEBUG) << "FullConnectionInt8Coder has been called"; - context->AppendCode(code.str()); - + MS_CHECK_RET_CODE(MatMulBaseInt8Coder::DoCode(context), "matmul int8 do code failed"); return RET_OK; } + +REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_MatMul, CPUOpCoderCreator) + } // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h index b1d8859fdb..482c86c95f 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h @@ -16,42 +16,26 @@ #ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ #define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ + #include #include "coder/opcoders/op_coder.h" +#include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" #include "nnacl/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { -class MatMulInt8Coder final : public OperatorCoder { +class MatMulInt8Coder final : public MatMulBaseInt8Coder { public: MatMulInt8Coder(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, size_t node_index, Target target) - : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} - ~MatMulInt8Coder() override; + : MatMulBaseInt8Coder(in_tensors, out_tensors, node, node_index, target) {} + + ~MatMulInt8Coder() override = default; int Prepare(CoderContext *const context) override; int DoCode(CoderContext *const context) override; private: - int Init(); - int ReSize(CoderContext *const context); - - private: - Tensor *filter_tensor_{nullptr}; - Tensor *bias_tensor_{nullptr}; - MatMulParameter *params_{nullptr}; - MatmulQuantParameter quant_params_{0}; - size_t a_pack_ptr_size_{0}; - int8_t *a_pack_ptr_{nullptr}; - size_t b_pack_batch_ptr_size_{0}; - int8_t *b_pack_batch_ptr_{nullptr}; - size_t bias_prt_size_{0}; - int *bias_ptr_{nullptr}; - size_t input_sums_size_{0}; - int *input_sums_{nullptr}; - size_t weight_bias_sums_batch_size_{0}; - int *weight_bias_sums_batch_{nullptr}; - int thread_stride_{0}; - int thread_count_{0}; + int ReSize(CoderContext *const context) override; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ diff --git a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc index 82eca83215..c516cab451 100644 --- a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc +++ b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc @@ -198,11 +198,16 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ReshapeQuant reshape_quant_arg.output_activation_min_, reshape_quant_arg.output_activation_max_); } -void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg) { +void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg, + int weight_quant_num) { + CodeArray("filter_scale", 
matmul_quant_arg.filter_scale_, weight_quant_num); + CodeArray("filter_zp", matmul_quant_arg.filter_zp_, weight_quant_num); + CodeArray("left_shift", matmul_quant_arg.left_shift_, weight_quant_num); + CodeArray("right_shift", matmul_quant_arg.right_shift_, weight_quant_num); + CodeArray("multiplier", matmul_quant_arg.quant_multiplier_, weight_quant_num); CodeBaseStruct("MatmulQuantParameter", name, matmul_quant_arg.input_, matmul_quant_arg.weight_, - matmul_quant_arg.output_, matmul_quant_arg.out_act_min_, matmul_quant_arg.out_act_max_, - matmul_quant_arg.left_shift_[0], matmul_quant_arg.right_shift_[0], - matmul_quant_arg.quant_multiplier_[0]); + matmul_quant_arg.output_, matmul_quant_arg.out_act_min_, matmul_quant_arg.out_act_max_, "filter_scale", + "filter_zp", "left_shift", "right_shift", "multiplier"); } void NNaclInt8Serializer::CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg) { diff --git a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h index a31d367048..80c723cab0 100644 --- a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h +++ b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h @@ -52,7 +52,7 @@ class NNaclInt8Serializer : public Serializer { void CodeStruct(const std::string &name, const ::QuantMulArg &quant_mul_arg); void CodeStruct(const std::string &name, const ReduceQuantArg &reduce_quant_arg); void CodeStruct(const std::string &name, const ReshapeQuantArg &reshape_quant_arg); - void CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg); + void CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg, int weight_quant_num); void CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg); void CodeStruct(const std::string &name, const DivQuantArg &div_quant_arg); void CodeStruct(const std::string &name, const ReluXQuantArg &relu_quant_arg); diff --git a/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c b/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c index 6c8208777f..c51d3cafb3 100644 --- a/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c +++ b/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.c @@ -30,15 +30,16 @@ void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int } } -void InitInt8MatrixB(int8_t *src_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, int col, - int col_4, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, bool b_transpose) { +void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, + int col, int col_align, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, + bool b_transpose, bool filter_per_channel) { for (int i = 0; i < batch; ++i) { - int8_t *cur_b = src_ptr + i * deep * col; - int8_t *cur_b_pack = dst_ptr + i * col_4 * deep_16; - int32_t *cur_sums = weight_bias_sums_batch_ + i * col_4; + int8_t *cur_b = weight_ptr + i * deep * col; + int8_t *cur_b_pack = dst_ptr + i * col_align * deep_16; + int32_t *cur_sums = weight_bias_sums_batch_ + i * col_align; if (b_transpose) { RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, col, deep); - CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, 
ColMajor, false);
+      CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, ColMajor, filter_per_channel);
     } else {
       RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack);
       CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false);
diff --git a/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.h b/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.h
index 0c9f48cb48..72fab6a069 100644
--- a/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.h
+++ b/mindspore/lite/micro/coder/operator_library/wrapper/int8/matmul_int8_wrapper.h
@@ -25,8 +25,9 @@ extern "C" {
 void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int batch, int row, int deep, int input_zp,
                      const int *weight_zp, bool a_transpose);
 
-void InitInt8MatrixB(int8_t *src_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, int col,
-                     int col_4, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, bool b_transpose);
+void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep,
+                     int col, int col_align, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr,
+                     bool b_transpose, bool filter_per_channel);
 
 #ifdef __cplusplus
 }
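
The refactor routes both the FullConnection and MatMul int8 coders through MatMulBaseInt8Coder, which emits a call to the updated InitInt8MatrixB wrapper to pack matrix B and fold the bias into per-column sums. Below is a minimal, hypothetical host-side driver for the new signature; only InitInt8MatrixB and its parameter order come from the header above, while the UP_ROUND_TO macro, buffer names, and the main() harness are illustrative stand-ins for the rounding done in InitTmpBuffer(), and it assumes the wrapper and the nnacl pack/sum kernels it calls are linked in.

```c
/* Hypothetical standalone driver for the new InitInt8MatrixB signature.
 * Assumes matmul_int8_wrapper.h and the nnacl int8 kernels are on the include/link path. */
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include "wrapper/int8/matmul_int8_wrapper.h"

#define UP_ROUND_TO(x, n) (((x) + (n)-1) / (n) * (n))  /* local stand-in for nnacl UP_ROUND */

int main(void) {
  const int batch = 1, deep = 8, col = 3;
  const int col_align = UP_ROUND_TO(col, 4);   /* C4NUM rounding, as in ResizeParameter() */
  const int deep_16 = UP_ROUND_TO(deep, 16);   /* C16NUM rounding */
  int8_t weight[1 * 3 * 8] = {0};              /* batch * col * deep, b_transpose layout */
  int filter_zp[3] = {0, 0, 0};                /* per-channel zero points */
  int bias[3] = {0, 0, 0};
  int8_t *packed_b = calloc((size_t)batch * col_align * deep_16, sizeof(int8_t));
  int32_t *weight_bias_sums = calloc((size_t)batch * col_align, sizeof(int32_t));
  if (packed_b == NULL || weight_bias_sums == NULL) return 1;

  /* b_transpose = true mirrors the FullConnection path (filter stored as col x deep). */
  InitInt8MatrixB(weight, weight_bias_sums, packed_b, batch, deep, col, col_align, deep_16,
                  /*input_zp=*/0, filter_zp, bias, /*b_transpose=*/true, /*filter_per_channel=*/true);

  free(packed_b);
  free(weight_bias_sums);
  return 0;
}
```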
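The per-channel arrays that CodeStruct now serializes (filter_zp_, left_shift_, right_shift_, quant_multiplier_) each hold one requantization scale per output channel, derived in InitQuantParam() from input_scale * filter_scale[c] / output_scale. The sketch below is a floating-point reference of what those tables encode; RequantizeRef and its values are illustrative only, since the generated kernels use the fixed-point form produced by QuantizeRoundParameterWithDoublePrecision and applied inside MatmulInt8Opt.

```c
/* Float reference for per-channel requantization of an int32 accumulator to int8. */
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static int8_t RequantizeRef(int32_t acc, float in_scale, float filter_scale_c, float out_scale,
                            int32_t out_zp, int32_t act_min, int32_t act_max) {
  const double real_multiplier = (double)in_scale * filter_scale_c / out_scale;  /* per-channel scale */
  int32_t v = (int32_t)lround((double)acc * real_multiplier) + out_zp;
  if (v < act_min) v = act_min;   /* clamp to the activation range from CalculateActivationRangeQuantized */
  if (v > act_max) v = act_max;
  return (int8_t)v;
}

int main(void) {
  /* Channels 0 and 1 use different filter scales, as in the filter_per_channel_ path. */
  printf("%d\n", RequantizeRef(1200, 0.05f, 0.02f, 0.1f, 0, -128, 127));   /* -> 12 */
  printf("%d\n", RequantizeRef(1200, 0.05f, 0.004f, 0.1f, 0, -128, 127));  /* -> 2 */
  return 0;
}
```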
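ResizeParameter() splits the 4-aligned output columns into thread_stride_ blocks, and DoCode() clamps the final block with MSMIN(stride, col - cur_stride); micro currently emits code only for task_id = 0. This standalone check of the partition uses local stand-ins for the nnacl UP_DIV/UP_ROUND/MSMIN macros and otherwise mirrors the arithmetic in the coder.

```c
/* Column tiling used by ResizeParameter()/DoCode(): prints which output columns each task covers. */
#include <stdio.h>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y)-1) / (y))
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  const int col = 10, thread_num = 2;
  const int col_align = UP_ROUND(col, C4NUM);                              /* 12 */
  const int thread_count = MSMIN(thread_num, UP_DIV(col_align, C4NUM));    /* 2  */
  const int thread_stride = UP_DIV(UP_DIV(col_align, C4NUM), thread_count);/* 2  */
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    const int cur_stride = task_id * thread_stride * C4NUM;
    const int cur_oc = MSMIN(thread_stride * C4NUM, col - cur_stride);
    if (cur_oc <= 0) break;
    printf("task %d: columns [%d, %d)\n", task_id, cur_stride, cur_stride + cur_oc);
  }
  return 0;
}
```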