| @@ -97,6 +97,7 @@ set(CODER_OPCODERS_SRC | |||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/concat_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/concat_int8_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc | |||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/matmul_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/matmul_int8_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc | ||||
| @@ -9,6 +9,8 @@ file(GLOB KERNEL_SRC | |||||
| if(MICRO_BUILD_ARM64) | if(MICRO_BUILD_ARM64) | ||||
| file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm64/*.S) | file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm64/*.S) | ||||
| file(GLOB OPT_SRC ${NNACL_DIR}/assembly/opt/*.S) | |||||
| list(APPEND ASSEMBLY_SRC ${OPT_SRC}) | |||||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | ||||
| endif() | endif() | ||||
| @@ -127,7 +127,8 @@ int Conv2D3x3Int8Coder::Prepare(CoderContext *const context) { | |||||
| } | } | ||||
| int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) { | int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) { | ||||
| Collect(context, {"nnacl/int8/conv_int8.h"}, {"pack_int8.c", "conv_int8.c", "fixed_point.c"}); | |||||
| Collect(context, {"nnacl/int8/conv_int8.h", "nnacl/int8/conv3x3_int8.h"}, | |||||
| {"pack_int8.c", "conv_int8.c", "conv3x3_int8.c", "fixed_point.c"}); | |||||
| nnacl::NNaclInt8Serializer code; | nnacl::NNaclInt8Serializer code; | ||||
| code.precision(kPrecision); | code.precision(kPrecision); | ||||
| // call the op function | // call the op function | ||||
| @@ -16,133 +16,27 @@ | |||||
| #include "coder/opcoders/nnacl/int8/fullconnection_int8_coder.h" | #include "coder/opcoders/nnacl/int8/fullconnection_int8_coder.h" | ||||
| #include "nnacl/int8/matmul_int8.h" | #include "nnacl/int8/matmul_int8.h" | ||||
| #include "coder/opcoders/file_collector.h" | |||||
| #include "coder/log.h" | #include "coder/log.h" | ||||
| #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" | |||||
| using mindspore::schema::PrimitiveType_FullConnection; | using mindspore::schema::PrimitiveType_FullConnection; | ||||
| namespace mindspore::lite::micro::nnacl { | namespace mindspore::lite::micro::nnacl { | ||||
| FullConnectionInt8Coder ::~FullConnectionInt8Coder() { | |||||
| FreeQuantParam(); | |||||
| filter_tensor_ = nullptr; | |||||
| bias_tensor_ = nullptr; | |||||
| pack_a_ptr_ = nullptr; | |||||
| pack_b_ptr_ = nullptr; | |||||
| input_sums_ = nullptr; | |||||
| weight_bias_sums_ = nullptr; | |||||
| bias_ptr_ = nullptr; | |||||
| } | |||||
| int FullConnectionInt8Coder::MallocQuantParam() { | |||||
| filter_tensor_ = input_tensors_.at(kWeightIndex); | |||||
| MS_CHECK_PTR(filter_tensor_); | |||||
| std::vector<QuantArg> weight_quant_params = filter_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!filter_tensor_->shape().empty(), "filter tensor shape is empty"); | |||||
| int col = filter_tensor_->shape().front(); | |||||
| filter_per_channel_ = (weight_quant_params.size() > 1); | |||||
| init_size_ = filter_per_channel_ ? col : 1; | |||||
| quant_.filter_scale_ = reinterpret_cast<float *>(malloc(init_size_ * sizeof(float))); | |||||
| MS_CHECK_PTR(quant_.filter_scale_); | |||||
| quant_.filter_zp_ = reinterpret_cast<int32_t *>(malloc(init_size_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.filter_zp_); | |||||
| quant_.left_shift_ = reinterpret_cast<int32_t *>(malloc(init_size_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.left_shift_); | |||||
| quant_.right_shift_ = reinterpret_cast<int32_t *>(malloc(init_size_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.right_shift_); | |||||
| quant_.quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(init_size_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.quant_multiplier_); | |||||
| return RET_OK; | |||||
| } | |||||
| void FullConnectionInt8Coder::FreeQuantParam() { | |||||
| if (quant_.filter_scale_ != nullptr) { | |||||
| free(quant_.filter_scale_); | |||||
| quant_.filter_scale_ = nullptr; | |||||
| } | |||||
| if (quant_.filter_zp_ != nullptr) { | |||||
| free(quant_.filter_zp_); | |||||
| quant_.filter_zp_ = nullptr; | |||||
| } | |||||
| if (quant_.left_shift_ != nullptr) { | |||||
| free(quant_.left_shift_); | |||||
| quant_.left_shift_ = nullptr; | |||||
| } | |||||
| if (quant_.right_shift_ != nullptr) { | |||||
| free(quant_.right_shift_); | |||||
| quant_.right_shift_ = nullptr; | |||||
| } | |||||
| if (quant_.quant_multiplier_ != nullptr) { | |||||
| free(quant_.quant_multiplier_); | |||||
| quant_.quant_multiplier_ = nullptr; | |||||
| } | |||||
| } | |||||
| void FullConnectionInt8Coder::InitParam() { | |||||
| int FullConnectionInt8Coder::ReSize(CoderContext *const context) { | |||||
| int row = 1; | int row = 1; | ||||
| int out_put_tensor_size = static_cast<int>(output_tensor_->shape().size()); | |||||
| for (int i = 0; i < out_put_tensor_size - 1; ++i) { | |||||
| for (size_t i = 0; i < output_tensor_->shape().size() - 1; ++i) { | |||||
| row *= (output_tensor_->shape()).at(i); | row *= (output_tensor_->shape()).at(i); | ||||
| } | } | ||||
| fc_param_->row_ = row; | |||||
| fc_param_->col_ = output_tensor_->shape().back(); | |||||
| fc_param_->deep_ = filter_tensor_->shape().at(1); | |||||
| fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | |||||
| fc_param_->col_4_ = UP_ROUND(fc_param_->col_, C4NUM); | |||||
| fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM); | |||||
| fc_param_->deep_16_ = UP_ROUND(fc_param_->deep_, C16NUM); | |||||
| thread_count_ = MSMIN(thread_num_, UP_DIV(fc_param_->col_4_, C4NUM)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_4_, C4NUM), thread_count_); | |||||
| } | |||||
| int FullConnectionInt8Coder::ReSize(CoderContext *const context) { | |||||
| InitParam(); | |||||
| pack_a_ptr_size_ = static_cast<size_t>(fc_param_->row_4_ * fc_param_->deep_16_ * sizeof(int8_t)); | |||||
| pack_a_ptr_ = reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, pack_a_ptr_size_, kOfflinePackWeight)); | |||||
| MS_CHECK_PTR(pack_a_ptr_); | |||||
| pack_b_ptr_size_ = static_cast<size_t>(fc_param_->col_4_ * fc_param_->deep_16_ * sizeof(int8_t)); | |||||
| weight_bias_sums_size_ = static_cast<size_t>(fc_param_->col_4_ * sizeof(int)); | |||||
| if (fc_param_->b_const_) { | |||||
| pack_b_ptr_ = reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, kOnlineSize, kOnlinePackWeight)); | |||||
| MS_CHECK_PTR(pack_b_ptr_); | |||||
| weight_bias_sums_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); | |||||
| } else { | |||||
| pack_b_ptr_ = reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, pack_b_ptr_size_, kOfflinePackWeight)); | |||||
| MS_CHECK_PTR(pack_b_ptr_); | |||||
| weight_bias_sums_ = | |||||
| reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, weight_bias_sums_size_, kOfflinePackWeight)); | |||||
| } | |||||
| MS_CHECK_PTR(weight_bias_sums_); | |||||
| input_sums_size_ = static_cast<size_t>(fc_param_->row_4_ * sizeof(int)); | |||||
| input_sums_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, input_sums_size_, kOfflinePackWeight)); | |||||
| MS_CHECK_PTR(input_sums_); | |||||
| if (input_tensors_.size() == kInputSize2) { | |||||
| bias_ptr_size_ = static_cast<size_t>(fc_param_->col_4_ * sizeof(int)); | |||||
| bias_ptr_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); | |||||
| MS_CHECK_PTR(bias_ptr_); | |||||
| } else { | |||||
| bias_ptr_ = nullptr; | |||||
| } | |||||
| NNaclInt8Serializer init_code; | |||||
| if (input_tensors_.size() == kInputSize2) { | |||||
| init_code.CodeMallocExpression(bias_ptr_, bias_ptr_size_); | |||||
| init_code.CodeFunction("memset", bias_ptr_, 0, bias_ptr_size_); | |||||
| init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, bias_ptr_); | |||||
| } | |||||
| if (fc_param_->b_const_) { | |||||
| init_code.CodeMallocExpression(pack_b_ptr_, pack_b_ptr_size_); | |||||
| init_code.CodeMallocExpression(weight_bias_sums_, weight_bias_sums_size_); | |||||
| init_code.CodeFunction("RowMajor2Row16x4MajorInt8", filter_tensor_, pack_b_ptr_, fc_param_->col_, fc_param_->deep_); | |||||
| init_code.CodeFunction("CalcWeightBiasSums", filter_tensor_, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, | |||||
| quant_.filter_zp_, bias_ptr_, weight_bias_sums_, ColMajor, filter_per_channel_); | |||||
| } | |||||
| context->AppendInitCode(init_code.str()); | |||||
| param_->row_ = row; | |||||
| param_->col_ = output_tensor_->shape().back(); | |||||
| param_->deep_ = (filter_tensor_->shape()).at(1); | |||||
| MS_CHECK_RET_CODE(MatMulBaseInt8Coder::ReSize(context), "MatMulBaseInt8Coder::ReSize is nullptr"); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int FullConnectionInt8Coder::Init() { | |||||
| fc_param_ = reinterpret_cast<MatMulParameter *>(parameter_); | |||||
| int FullConnectionInt8Coder::Prepare(CoderContext *const context) { | |||||
| // only support one thread currently | |||||
| thread_count_ = thread_num_; | |||||
| param_ = reinterpret_cast<MatMulParameter *>(parameter_); | |||||
| filter_tensor_ = input_tensors_.at(kWeightIndex); | filter_tensor_ = input_tensors_.at(kWeightIndex); | ||||
| MS_CHECK_PTR(filter_tensor_); | MS_CHECK_PTR(filter_tensor_); | ||||
| if (input_tensors_.size() == kInputSize2) { | if (input_tensors_.size() == kInputSize2) { | ||||
| @@ -150,91 +44,16 @@ int FullConnectionInt8Coder::Init() { | |||||
| MS_CHECK_PTR(bias_tensor_); | MS_CHECK_PTR(bias_tensor_); | ||||
| MS_CHECK_PTR(bias_tensor_->data_c()); | MS_CHECK_PTR(bias_tensor_->data_c()); | ||||
| } | } | ||||
| fc_param_->a_const_ = (input_tensor_->data_c() != nullptr); | |||||
| fc_param_->b_const_ = (filter_tensor_->data_c() != nullptr); | |||||
| int ret = MallocQuantParam(); | |||||
| if (ret != RET_OK) { | |||||
| FreeQuantParam(); | |||||
| return ret; | |||||
| } | |||||
| std::vector<QuantArg> in_quant_params = input_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!in_quant_params.empty(), "in_quant_params empty is empty"); | |||||
| quant_.input_.zp_ = in_quant_params.front().zeroPoint; | |||||
| quant_.input_.scale_ = static_cast<float>(in_quant_params.front().scale); | |||||
| std::vector<QuantArg> out_quant_params = output_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!out_quant_params.empty(), "out_quant_params empty is empty"); | |||||
| quant_.output_.zp_ = out_quant_params.front().zeroPoint; | |||||
| quant_.output_.scale_ = static_cast<float>(out_quant_params.front().scale); | |||||
| int weight_quant_num = filter_per_channel_ ? static_cast<int>(filter_tensor_->shape().front()) : 1; | |||||
| std::vector<QuantArg> weight_quant_params = filter_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!weight_quant_params.empty(), "weight_quant_params empty is empty"); | |||||
| for (int i = 0; i < weight_quant_num; i++) { | |||||
| quant_.filter_zp_[i] = weight_quant_params[i].zeroPoint; | |||||
| quant_.filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale); | |||||
| } | |||||
| for (int i = 0; i < weight_quant_num; ++i) { | |||||
| auto in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]); | |||||
| double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_); | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i], | |||||
| &quant_.right_shift_[i]); | |||||
| } | |||||
| CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, | |||||
| quant_.output_.zp_, quant_.output_.scale_, &quant_.out_act_min_, | |||||
| &quant_.out_act_max_); | |||||
| return RET_OK; | |||||
| } | |||||
| int FullConnectionInt8Coder::Prepare(CoderContext *const context) { | |||||
| // only support one thread currently | |||||
| thread_count_ = thread_num_; | |||||
| MS_CHECK_RET_CODE(Init(), "FullConnectionInt8Coder init failed"); | |||||
| param_->batch = 1; | |||||
| param_->a_transpose_ = false; | |||||
| param_->b_transpose_ = true; | |||||
| MatMulBaseInt8Coder::InitParameter(); | |||||
| MS_CHECK_RET_CODE(MatMulBaseInt8Coder::Init(), "Init failed"); | |||||
| return ReSize(context); | return ReSize(context); | ||||
| } | } | ||||
| int FullConnectionInt8Coder::DoCode(CoderContext *const context) { | int FullConnectionInt8Coder::DoCode(CoderContext *const context) { | ||||
| Collect(context, {"nnacl/common_func.h", "nnacl/int8/common_func_int8.h", "nnacl/int8/matmul_int8.h"}, | |||||
| {"common_func.c", "common_func_int8.c", "matmul_int8.c"}); | |||||
| NNaclInt8Serializer code; | |||||
| code.precision(kPrecision); | |||||
| code.CodeFunction("memset", input_sums_, 0, input_sums_size_); | |||||
| code.CodeFunction("memset", pack_a_ptr_, 0, pack_a_ptr_size_); | |||||
| code.CodeFunction("RowMajor2Row16x4MajorInt8", input_tensor_, pack_a_ptr_, fc_param_->row_, fc_param_->deep_); | |||||
| int32_t tmp_weight_zp = filter_per_channel_ ? 1 : quant_.filter_zp_[0]; | |||||
| code.CodeFunction("CalcInputSums", input_tensor_, fc_param_->row_, fc_param_->deep_, tmp_weight_zp, input_sums_, | |||||
| RowMajor); | |||||
| if (!fc_param_->b_const_) { | |||||
| code.CodeFunction("memset", pack_b_ptr_, 0, pack_b_ptr_size_); | |||||
| code.CodeFunction("memset", weight_bias_sums_, 0, weight_bias_sums_size_); | |||||
| code.CodeFunction("RowMajor2Row16x4MajorInt8", filter_tensor_, pack_b_ptr_, fc_param_->col_, fc_param_->deep_); | |||||
| code.CodeFunction("CalcWeightBiasSums", filter_tensor_, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, | |||||
| quant_.filter_zp_, bias_ptr_, weight_bias_sums_, ColMajor, filter_per_channel_); | |||||
| } | |||||
| int stride = thread_stride_ * C4NUM; | |||||
| int res_stride = fc_param_->col_; | |||||
| int cur_oc = MSMIN(stride, res_stride); | |||||
| if (cur_oc <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| int32_t *cur_left = quant_.left_shift_; | |||||
| int32_t *cur_right = quant_.right_shift_; | |||||
| int32_t *cur_mul = quant_.quant_multiplier_; | |||||
| int32_t *cur_zp = quant_.filter_zp_; | |||||
| code.CodeArray("cur_left_shift", cur_left, init_size_, true); | |||||
| code.CodeArray("cur_right_shift", cur_right, init_size_, true); | |||||
| code.CodeArray("cur_multiplier", cur_mul, init_size_, true); | |||||
| code.CodeArray("cur_filter_zp", cur_zp, init_size_, true); | |||||
| code.CodeFunction("MatmulInt8Opt", pack_a_ptr_, pack_b_ptr_, output_tensor_->data_c(), fc_param_->row_, cur_oc, | |||||
| fc_param_->deep_16_, input_sums_, weight_bias_sums_, quant_.out_act_min_, quant_.out_act_max_, | |||||
| quant_.output_.zp_, "&cur_multiplier", "&cur_left_shift", "&cur_right_shift", fc_param_->col_, | |||||
| filter_per_channel_, "&cur_filter_zp"); | |||||
| MS_LOG(DEBUG) << "FullConnectionInt8Coder has been called"; | |||||
| context->AppendCode(code.str()); | |||||
| MS_CHECK_RET_CODE(MatMulBaseInt8Coder::DoCode(context), "matmul int8 do code failed"); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -20,49 +20,25 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "coder/opcoders/op_coder.h" | |||||
| #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" | |||||
| #include "nnacl/int8/quantize.h" | #include "nnacl/int8/quantize.h" | ||||
| #include "nnacl/matmul_parameter.h" | #include "nnacl/matmul_parameter.h" | ||||
| namespace mindspore::lite::micro::nnacl { | namespace mindspore::lite::micro::nnacl { | ||||
| class FullConnectionInt8Coder final : public OperatorCoder { | |||||
| class FullConnectionInt8Coder final : public MatMulBaseInt8Coder { | |||||
| public: | public: | ||||
| FullConnectionInt8Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | FullConnectionInt8Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | ||||
| const Model::Node *node, size_t node_index, Target target) | const Model::Node *node, size_t node_index, Target target) | ||||
| : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} | |||||
| : MatMulBaseInt8Coder(in_tensors, out_tensors, node, node_index, target) {} | |||||
| ~FullConnectionInt8Coder() override; | |||||
| ~FullConnectionInt8Coder() override = default; | |||||
| int Prepare(CoderContext *const context) override; | int Prepare(CoderContext *const context) override; | ||||
| int DoCode(CoderContext *const context) override; | int DoCode(CoderContext *const context) override; | ||||
| private: | private: | ||||
| int Init(); | |||||
| int ReSize(CoderContext *const context); | |||||
| int MallocQuantParam(); | |||||
| void FreeQuantParam(); | |||||
| void InitParam(); | |||||
| private: | |||||
| MatmulQuantParameter quant_{0}; | |||||
| MatMulParameter *fc_param_{nullptr}; | |||||
| Tensor *filter_tensor_{nullptr}; | |||||
| Tensor *bias_tensor_{nullptr}; | |||||
| size_t pack_a_ptr_size_{0}; | |||||
| int8_t *pack_a_ptr_ = nullptr; | |||||
| size_t pack_b_ptr_size_{0}; | |||||
| int8_t *pack_b_ptr_{nullptr}; | |||||
| size_t input_sums_size_{0}; | |||||
| int *input_sums_{nullptr}; | |||||
| size_t weight_bias_sums_size_{0}; | |||||
| int *weight_bias_sums_{nullptr}; | |||||
| size_t bias_ptr_size_{0}; | |||||
| int *bias_ptr_{nullptr}; | |||||
| int thread_count_{1}; | |||||
| int thread_stride_{0}; | |||||
| bool filter_per_channel_{true}; | |||||
| int init_size_{0}; | |||||
| int ReSize(CoderContext *const context) override; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::micro::nnacl | } // namespace mindspore::lite::micro::nnacl | ||||
| #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_FULLCONNECTION_INT8_CODER_H_ | #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_FULLCONNECTION_INT8_CODER_H_ | ||||
| @@ -0,0 +1,255 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" | |||||
| #include "coder/opcoders/file_collector.h" | |||||
| namespace mindspore::lite::micro::nnacl { | |||||
| // Refreshes the aligned dimensions / thread split and then reserves the temporary buffers. | |||||
| // Returns RET_OK on success. If the buffers cannot be reserved, the quantization arrays are | |||||
| // released and RET_ERROR is returned. | |||||
| int MatMulBaseInt8Coder::ReSize(CoderContext *const context) { | |||||
| ResizeParameter(); | |||||
| const int buffer_status = InitTmpBuffer(); | |||||
| if (buffer_status == RET_OK) { | |||||
| return RET_OK; | |||||
| } | |||||
| // buffer reservation failed: drop the quant arrays before reporting the error | |||||
| FreeQuantParam(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| // Computes the byte sizes of the packed-A, packed-B, input-sum and weight-bias-sum buffers and | |||||
| // requests each of them from allocator_. When the B matrix is constant, its packed copy and its | |||||
| // bias sums are requested as kOnlinePackWeight with kOnlineSize; every other buffer is kWorkspace. | |||||
| // Each result is guarded by MS_CHECK_PTR; RET_OK is returned once all four requests went through. | |||||
| int MatMulBaseInt8Coder::InitTmpBuffer() { | |||||
| // packed left-hand matrix | |||||
| a_pack_ptr_size_ = param_->row_align_ * param_->deep_16_ * sizeof(int8_t); | |||||
| pack_a_ptr_ = reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, a_pack_ptr_size_, kWorkspace)); | |||||
| MS_CHECK_PTR(pack_a_ptr_); | |||||
| // packed right-hand matrix | |||||
| b_pack_ptr_size_ = param_->batch * param_->col_align_ * param_->deep_16_ * sizeof(int8_t); | |||||
| pack_b_ptr_ = reinterpret_cast<int8_t *>( | |||||
| param_->b_const_ ? allocator_->Malloc(kNumberTypeInt8, kOnlineSize, kOnlinePackWeight) | |||||
| : allocator_->Malloc(kNumberTypeInt8, b_pack_ptr_size_, kWorkspace)); | |||||
| MS_CHECK_PTR(pack_b_ptr_); | |||||
| // per-row sums of the input | |||||
| input_sums_size_ = static_cast<size_t>(param_->row_align_ * sizeof(int)); | |||||
| input_sums_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt32, input_sums_size_, kWorkspace)); | |||||
| MS_CHECK_PTR(input_sums_); | |||||
| // per-column weight/bias sums | |||||
| weight_bias_sums_size_ = static_cast<size_t>(param_->batch * param_->col_align_ * sizeof(int)); | |||||
| weight_bias_sums_ = reinterpret_cast<int *>( | |||||
| param_->b_const_ ? allocator_->Malloc(kNumberTypeInt32, kOnlineSize, kOnlinePackWeight) | |||||
| : allocator_->Malloc(kNumberTypeInt32, weight_bias_sums_size_, kWorkspace)); | |||||
| MS_CHECK_PTR(weight_bias_sums_); | |||||
| return RET_OK; | |||||
| } | |||||
| // Releases the quantization arrays held in quant_. | |||||
| MatMulBaseInt8Coder::~MatMulBaseInt8Coder() { | |||||
| FreeQuantParam(); | |||||
| } | |||||
| // Derives the aligned matmul dimensions from row_/col_/deep_ and splits the output columns | |||||
| // (in tiles of C4NUM) across threads. | |||||
| // Writes param_->row_align_, param_->col_align_, param_->deep_16_, thread_count_ and thread_stride_. | |||||
| void MatMulBaseInt8Coder::ResizeParameter() { | |||||
| param_->row_align_ = UP_ROUND(param_->row_, C4NUM); | |||||
| param_->col_align_ = UP_ROUND(param_->col_, C4NUM); | |||||
| param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM); | |||||
| const int col_tiles = UP_DIV(param_->col_align_, C4NUM); | |||||
| thread_count_ = MSMIN(param_->op_parameter_.thread_num_, col_tiles); | |||||
| // thread_count_ is the divisor of the stride computation below, so a zero thread_num_ or an | |||||
| // empty column dimension must not reach it as 0. | |||||
| if (thread_count_ < 1) { | |||||
| thread_count_ = 1; | |||||
| } | |||||
| thread_stride_ = UP_DIV(col_tiles, thread_count_); | |||||
| } | |||||
| // Frees the five quantization arrays owned by quant_ and resets each pointer to nullptr, | |||||
| // so the function can be called repeatedly. free() accepts a null pointer, hence no guards. | |||||
| void MatMulBaseInt8Coder::FreeQuantParam() { | |||||
| free(quant_.filter_scale_); | |||||
| quant_.filter_scale_ = nullptr; | |||||
| free(quant_.filter_zp_); | |||||
| quant_.filter_zp_ = nullptr; | |||||
| free(quant_.left_shift_); | |||||
| quant_.left_shift_ = nullptr; | |||||
| free(quant_.right_shift_); | |||||
| quant_.right_shift_ = nullptr; | |||||
| free(quant_.quant_multiplier_); | |||||
| quant_.quant_multiplier_ = nullptr; | |||||
| } | |||||
| // Decides between per-tensor and per-channel weight quantization and allocates the matching | |||||
| // arrays in quant_ (filter scale, filter zero point, left/right shift, multiplier). | |||||
| // Per-channel mode is selected when the filter tensor carries more than one quant param; the | |||||
| // arrays then hold one entry per element of the filter's first dimension, otherwise one entry. | |||||
| // Returns RET_OK on success. On failure the already allocated arrays are left in quant_ for | |||||
| // the caller (or the destructor) to release through FreeQuantParam(). | |||||
| int MatMulBaseInt8Coder::MallocQuantParam() { | |||||
| MS_CHECK_PTR(filter_tensor_); | |||||
| // front() on an empty shape is undefined behaviour, so reject it up front | |||||
| MS_CHECK_TRUE(!filter_tensor_->shape().empty(), "filter tensor shape is empty"); | |||||
| std::vector<QuantArg> weight_quant_params = filter_tensor_->quant_params(); | |||||
| int col = filter_tensor_->shape().front(); | |||||
| filter_per_channel_ = (weight_quant_params.size() > 1); | |||||
| weight_quant_num_ = filter_per_channel_ ? col : 1; | |||||
| // a non-positive count would turn into a zero-sized or wrapped-around malloc request | |||||
| MS_CHECK_TRUE(weight_quant_num_ > 0, "weight quant num must be positive"); | |||||
| quant_.filter_scale_ = reinterpret_cast<float *>(malloc(weight_quant_num_ * sizeof(float))); | |||||
| MS_CHECK_PTR(quant_.filter_scale_); | |||||
| quant_.filter_zp_ = reinterpret_cast<int32_t *>(malloc(weight_quant_num_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.filter_zp_); | |||||
| quant_.left_shift_ = reinterpret_cast<int32_t *>(malloc(weight_quant_num_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.left_shift_); | |||||
| quant_.right_shift_ = reinterpret_cast<int32_t *>(malloc(weight_quant_num_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.right_shift_); | |||||
| quant_.quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(weight_quant_num_ * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_.quant_multiplier_); | |||||
| return RET_OK; | |||||
| } | |||||
| // Fills quant_ from the tensors' quantization parameters: | |||||
| //  - input/output zero point and scale come from the first quant param of each tensor; | |||||
| //  - for each of the weight_quant_num_ weight entries the filter zero point and scale are copied | |||||
| //    and the fixed-point multiplier plus left/right shift are derived from | |||||
| //    (input_scale * filter_scale) / output_scale; | |||||
| //  - the quantized activation range is computed for Relu / Relu6. | |||||
| // Expects the quant_ arrays to have been allocated with weight_quant_num_ entries beforehand. | |||||
| // Returns RET_OK on success; the MS_CHECK_TRUE guards reject missing quant params. | |||||
| int MatMulBaseInt8Coder::InitQuantParam() { | |||||
| std::vector<QuantArg> in_quant_params = input_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!in_quant_params.empty(), "in_quant_params is empty"); | |||||
| quant_.input_.zp_ = in_quant_params.front().zeroPoint; | |||||
| quant_.input_.scale_ = static_cast<float>(in_quant_params.front().scale); | |||||
| std::vector<QuantArg> out_quant_params = output_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!out_quant_params.empty(), "out_quant_params is empty"); | |||||
| quant_.output_.zp_ = out_quant_params.front().zeroPoint; | |||||
| quant_.output_.scale_ = static_cast<float>(out_quant_params.front().scale); | |||||
| std::vector<QuantArg> weight_quant_params = filter_tensor_->quant_params(); | |||||
| // the loop below indexes weight_quant_params[0 .. weight_quant_num_), so the tensor must | |||||
| // provide at least that many entries | |||||
| MS_CHECK_TRUE(weight_quant_num_ > 0, "weight_quant_num is not positive"); | |||||
| MS_CHECK_TRUE(weight_quant_params.size() >= static_cast<size_t>(weight_quant_num_), | |||||
| "weight_quant_params has fewer entries than weight_quant_num"); | |||||
| for (int i = 0; i < weight_quant_num_; ++i) { | |||||
| quant_.filter_zp_[i] = weight_quant_params[i].zeroPoint; | |||||
| quant_.filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale); | |||||
| // the product is intentionally formed in float before widening, as in the original code | |||||
| const auto in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]); | |||||
| double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_); | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i], | |||||
| &quant_.right_shift_[i]); | |||||
| } | |||||
| CalculateActivationRangeQuantized(param_->act_type_ == ActType_Relu, param_->act_type_ == ActType_Relu6, | |||||
| quant_.output_.zp_, quant_.output_.scale_, &quant_.out_act_min_, | |||||
| &quant_.out_act_max_); | |||||
| return RET_OK; | |||||
| } | |||||
| // Records whether the A (input) and B (filter) matrices are compile-time constants. | |||||
| // A matrix counts as constant only when its tensor already carries data (data_c() != nullptr); | |||||
| // a tensor object that merely exists says nothing about const-ness, and treating it as constant | |||||
| // would make the non-const weight path in DoCode unreachable. | |||||
| void MatMulBaseInt8Coder::InitParameter() { | |||||
| param_->a_const_ = (input_tensor_ != nullptr) && (input_tensor_->data_c() != nullptr); | |||||
| param_->b_const_ = (filter_tensor_ != nullptr) && (filter_tensor_->data_c() != nullptr); | |||||
| } | |||||
| // Reserves the bias buffer when a bias tensor is present; does nothing otherwise. | |||||
| // bias_ptr_size_ is the bias element count rounded up to a multiple of C4NUM, in bytes of int. | |||||
| // The buffer itself is requested as kOnlinePackWeight with kOnlineSize. | |||||
| // Returns RET_OK; MS_CHECK_PTR guards the allocator result. | |||||
| int MatMulBaseInt8Coder::InitBias() { | |||||
| if (bias_tensor_ == nullptr) { | |||||
| return RET_OK; | |||||
| } | |||||
| const int padded_bias_elements = UP_ROUND(bias_tensor_->ElementsNum(), C4NUM); | |||||
| // to pack to init | |||||
| bias_ptr_size_ = static_cast<size_t>(padded_bias_elements * sizeof(int)); | |||||
| bias_ptr_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt32, kOnlineSize, kOnlinePackWeight)); | |||||
| MS_CHECK_PTR(bias_ptr_); | |||||
| return RET_OK; | |||||
| } | |||||
| // Runs the one-time setup: allocate the quant arrays, fill them, then set up the bias buffer. | |||||
| // Returns RET_OK on success. A failure while allocating the quant arrays or setting up the bias | |||||
| // releases the quant arrays and returns RET_ERROR; a failure while filling them is reported | |||||
| // through MS_CHECK_RET_CODE. | |||||
| int MatMulBaseInt8Coder::Init() { | |||||
| const int malloc_status = MallocQuantParam(); | |||||
| if (malloc_status != RET_OK) { | |||||
| FreeQuantParam(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_CHECK_RET_CODE(InitQuantParam(), "matmul int8 init quant_param failed"); | |||||
| const int bias_status = InitBias(); | |||||
| if (bias_status != RET_OK) { | |||||
| FreeQuantParam(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| // The base class has nothing to prepare on its own; it always reports RET_OK. | |||||
| int MatMulBaseInt8Coder::Prepare(CoderContext *const context) { | |||||
| return RET_OK; | |||||
| } | |||||
| // Emits the generated C code for an int8 matmul. | |||||
| //  - Registers the nnacl/wrapper headers and sources the generated code depends on. | |||||
| //  - Init code: allocates and fills the bias buffer; for a constant B matrix also allocates the | |||||
| //    packed-B / weight-bias-sum buffers and packs B once via InitInt8MatrixB. | |||||
| //  - Run code: for a non-constant B matrix packs B on every run; then, per batch, packs A, | |||||
| //    computes the input sums and calls MatmulInt8Opt on the batch's slice of B, C and the sums. | |||||
| // Only task 0 is generated (task_id is fixed to 0). | |||||
| // Returns RET_OK; when there is no output column to compute nothing is appended to the context. | |||||
| int MatMulBaseInt8Coder::DoCode(CoderContext *const context) { | |||||
| Collect(context, | |||||
| {"nnacl/common_func.h", "nnacl/int8/common_func_int8.h", "nnacl/int8/matmul_int8.h", | |||||
| "wrapper/int8/matmul_int8_wrapper.h"}, | |||||
| {"common_func.c", "common_func_int8.c", "matmul_int8.c", "matmul_int8_wrapper.c"}); | |||||
| std::string value_str_end = ";\n"; | |||||
| NNaclInt8Serializer init_code; | |||||
| NNaclInt8Serializer code; | |||||
| if (bias_ptr_) { | |||||
| init_code.CodeMallocExpression(bias_ptr_, bias_ptr_size_); | |||||
| init_code.CodeFunction("memset", bias_ptr_, 0, bias_ptr_size_); | |||||
| // bias_ptr_size_ is rounded up to a multiple of C4NUM elements; copy only what the bias tensor | |||||
| // actually holds so the generated memcpy does not read past it (the memset keeps the padding 0). | |||||
| const size_t bias_copy_size = static_cast<size_t>(bias_tensor_->ElementsNum()) * sizeof(int); | |||||
| init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, bias_copy_size); | |||||
| } | |||||
| if (param_->b_const_) { | |||||
| init_code.CodeMallocExpression(weight_bias_sums_, weight_bias_sums_size_); | |||||
| init_code.CodeFunction("memset", weight_bias_sums_, 0, weight_bias_sums_size_); | |||||
| init_code.CodeMallocExpression(pack_b_ptr_, b_pack_ptr_size_); | |||||
| init_code.CodeFunction("memset", pack_b_ptr_, 0, b_pack_ptr_size_); | |||||
| init_code.CodeArray("init_filter_zp", quant_.filter_zp_, weight_quant_num_); | |||||
| init_code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, | |||||
| param_->deep_, param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, | |||||
| "init_filter_zp", bias_ptr_, param_->b_transpose_, filter_per_channel_); | |||||
| } else { | |||||
| // the run-time call below names init_filter_zp, so the array has to be declared in the run | |||||
| // code as well; the init code only declares it for the constant-B case | |||||
| code.CodeArray("init_filter_zp", quant_.filter_zp_, weight_quant_num_); | |||||
| code.CodeFunction("InitInt8MatrixB", filter_tensor_, weight_bias_sums_, pack_b_ptr_, param_->batch, param_->deep_, | |||||
| param_->col_, param_->col_align_, param_->deep_16_, quant_.input_.zp_, "init_filter_zp", | |||||
| bias_ptr_, param_->b_transpose_, filter_per_channel_); | |||||
| } | |||||
| int task_id = 0; | |||||
| std::string a_ptr_str = allocator_->GetRuntimeAddr(input_tensor_); | |||||
| std::string c_ptr_str = allocator_->GetRuntimeAddr(output_tensor_); | |||||
| std::string pack_b_ptr_str = allocator_->GetRuntimeAddr(pack_b_ptr_); | |||||
| std::string weight_bias_sums_str = allocator_->GetRuntimeAddr(weight_bias_sums_); | |||||
| code.precision(kPrecision); | |||||
| std::string tmp_weight_zp = | |||||
| "int32_t tmp_weight_zp = " + (filter_per_channel_ ? std::to_string(1) : std::to_string(quant_.filter_zp_[0])); | |||||
| code << tmp_weight_zp << value_str_end; | |||||
| // the column split does not depend on the batch index, so compute it once | |||||
| const int stride = thread_stride_ * C4NUM; | |||||
| const int cur_stride = task_id * stride; | |||||
| const int res_stride = param_->col_ - cur_stride; | |||||
| const int cur_oc = MSMIN(stride, res_stride); | |||||
| for (int i = 0; i < param_->batch; i++) { | |||||
| std::string current_src_a = a_ptr_str + "+" + std::to_string(i * param_->row_ * param_->deep_); | |||||
| if (param_->a_transpose_) { | |||||
| code.CodeFunction("RowMajor2Col16x4MajorInt8", current_src_a, param_->deep_, param_->row_, pack_a_ptr_); | |||||
| code.CodeFunction("CalcInputSums", current_src_a, param_->row_, param_->deep_, "tmp_weight_zp", input_sums_, | |||||
| ColMajor); | |||||
| } else { | |||||
| code.CodeFunction("RowMajor2Row16x4MajorInt8", current_src_a, pack_a_ptr_, param_->row_, param_->deep_); | |||||
| code.CodeFunction("CalcInputSums", current_src_a, param_->row_, param_->deep_, "tmp_weight_zp", input_sums_, | |||||
| RowMajor); | |||||
| } | |||||
| std::string batch_b_ptr_str = pack_b_ptr_str + "+" + std::to_string(i * param_->col_align_ * param_->deep_16_); | |||||
| std::string batch_c_ptr_str = c_ptr_str + "+" + std::to_string(i * param_->row_ * param_->col_); | |||||
| if (cur_oc <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (i == 0) { | |||||
| // declare the quant struct and the cur_* pointers once; emitting them for every batch | |||||
| // would redeclare the same names inside one generated scope | |||||
| code.CodeStruct("matmul_quant_parameter", quant_, weight_quant_num_); | |||||
| std::string cur_left = "int32_t *cur_left = matmul_quant_parameter.left_shift_"; | |||||
| std::string cur_right = "int32_t *cur_right = matmul_quant_parameter.right_shift_"; | |||||
| std::string cur_mul = "int32_t *cur_mul = matmul_quant_parameter.quant_multiplier_ "; | |||||
| std::string cur_zp = "int32_t *cur_zp = matmul_quant_parameter.filter_zp_ "; | |||||
| if (filter_per_channel_) { | |||||
| code << cur_left << " + " << cur_stride << value_str_end; | |||||
| code << cur_right << " + " << cur_stride << value_str_end; | |||||
| code << cur_mul << " + " << cur_stride << value_str_end; | |||||
| code << cur_zp << " + " << cur_stride << value_str_end; | |||||
| } else { | |||||
| code << cur_left << value_str_end; | |||||
| code << cur_right << value_str_end; | |||||
| code << cur_mul << value_str_end; | |||||
| code << cur_zp << value_str_end; | |||||
| } | |||||
| } | |||||
| std::string batch_b_ptr_str_final = batch_b_ptr_str + " + " + std::to_string(cur_stride * param_->deep_16_); | |||||
| std::string batch_c_ptr_final = batch_c_ptr_str + "+" + std::to_string(cur_stride); | |||||
| // weight_bias_sums_ holds col_align_ entries per batch, so step to this batch's slice first | |||||
| std::string weight_bias_sums_str_final = | |||||
| weight_bias_sums_str + "+" + std::to_string(i * param_->col_align_ + cur_stride); | |||||
| code.CodeFunction("MatmulInt8Opt", pack_a_ptr_, batch_b_ptr_str_final, batch_c_ptr_final, param_->row_, cur_oc, | |||||
| param_->deep_16_, input_sums_, weight_bias_sums_str_final, quant_.out_act_min_, | |||||
| quant_.out_act_max_, quant_.output_.zp_, "cur_mul", "cur_left", "cur_right", param_->col_, | |||||
| filter_per_channel_, "cur_zp"); | |||||
| } | |||||
| MS_LOG(DEBUG) << "MatMulBaseInt8Coder has been called"; | |||||
| context->AppendInitCode(init_code.str()); | |||||
| context->AppendCode(code.str()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::lite::micro::nnacl | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ | |||||
| #define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ | |||||
| #include <vector> | |||||
| #include "coder/opcoders/op_coder.h" | |||||
| #include "nnacl/matmul_parameter.h" | |||||
| namespace mindspore::lite::micro::nnacl { | |||||
| class MatMulBaseInt8Coder : public OperatorCoder { /* Base operator coder for int8 matmul-style ops; MatMulInt8Coder derives from it. */ | |||||
| public: | |||||
| MatMulBaseInt8Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||||
| const Model::Node *node, size_t node_index, Target target) | |||||
| : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} /* Forwards every argument unchanged to OperatorCoder. */ | |||||
| ~MatMulBaseInt8Coder() override; /* Defined out of line; presumably releases the arrays inside quant_ (see FreeQuantParam) - TODO confirm. */ | |||||
| int Prepare(CoderContext *const context) override; | |||||
| int DoCode(CoderContext *const context) override; /* Appends the generated init code and kernel code to the CoderContext. */ | |||||
| protected: | |||||
| int Init(); /* Invoked by MatMulInt8Coder::Prepare right after InitParameter(). */ | |||||
| void InitParameter(); | |||||
| virtual int ReSize(CoderContext *const context); /* Virtual: MatMulInt8Coder overrides it, fills batch/row_/col_/deep_ in param_, then calls this base version. */ | |||||
| private: | |||||
| void ResizeParameter(); | |||||
| int MallocQuantParam(); /* NOTE(review): body not visible here; together with FreeQuantParam it presumably owns the arrays in quant_ - confirm in the .cc. */ | |||||
| void FreeQuantParam(); | |||||
| int InitQuantParam(); | |||||
| int InitBias(); | |||||
| int InitTmpBuffer(); | |||||
| protected: | |||||
| bool filter_per_channel_{true}; /* When true, DoCode offsets the shift, multiplier and filter zero-point arrays by the task's column stride. */ | |||||
| int thread_count_{1}; | |||||
| int thread_stride_{0}; | |||||
| MatmulQuantParameter quant_{}; /* Serialized into the generated code as "matmul_quant_parameter". */ | |||||
| Tensor *filter_tensor_{nullptr}; | |||||
| Tensor *bias_tensor_{nullptr}; | |||||
| MatMulParameter *param_{nullptr}; /* Set by the subclass from parameter_ before InitParameter() is called. */ | |||||
| size_t a_pack_ptr_size_{0}; | |||||
| size_t b_pack_ptr_size_{0}; | |||||
| int8_t *pack_a_ptr_{nullptr}; /* Passed as the first argument of the generated MatmulInt8Opt call. */ | |||||
| int8_t *pack_b_ptr_{nullptr}; | |||||
| size_t bias_ptr_size_{0}; | |||||
| int *bias_ptr_{nullptr}; | |||||
| size_t input_sums_size_{0}; | |||||
| int *input_sums_{nullptr}; /* Passed to the generated MatmulInt8Opt call. */ | |||||
| size_t weight_bias_sums_size_{0}; | |||||
| int *weight_bias_sums_{nullptr}; | |||||
| private: | |||||
| int weight_quant_num_{0}; /* Element count handed to NNaclInt8Serializer::CodeStruct for the quantization arrays. */ | |||||
| }; | |||||
| } // namespace mindspore::lite::micro::nnacl | |||||
| #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_BASE_INT8_CODER_H_ | |||||
| @@ -15,117 +15,10 @@ | |||||
| */ | */ | ||||
| #include "coder/opcoders/nnacl/int8/matmul_int8_coder.h" | #include "coder/opcoders/nnacl/int8/matmul_int8_coder.h" | ||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" | |||||
| #include "coder/opcoders/file_collector.h" | |||||
| #include "coder/opcoders/op_coder.h" | |||||
| using mindspore::schema::PrimitiveType_MatMul; | |||||
| namespace mindspore::lite::micro::nnacl { | namespace mindspore::lite::micro::nnacl { | ||||
| // Derives the matmul geometry from the input/output shapes, reserves the packing and sum buffers, | |||||
| // computes the per-tensor requantization parameters and, when the weight is constant, emits the one-time | |||||
| // weight-packing code into the init section. Returns RET_OK on success and an error code otherwise. | |||||
| int MatMulInt8Coder::ReSize(CoderContext *const context) { | |||||
| int batch = 1; | |||||
| std::vector<int> x_shape = input_tensor_->shape(); | |||||
| std::vector<int> o_shape = output_tensor_->shape(); | |||||
| // A plain 2-D matmul is valid: only shapes with fewer than two dimensions are rejected. | |||||
| if (x_shape.size() < 2 || o_shape.size() < 2) { | |||||
| MS_LOG(ERROR) << "x_shape.size() or o_shape.size() is less than two"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // Every dimension in front of the last two is folded into the batch count. | |||||
| for (size_t i = 0; i < x_shape.size() - 2; ++i) { | |||||
| batch *= x_shape.at(i); | |||||
| } | |||||
| params_->batch = batch; | |||||
| params_->row_ = o_shape.at(o_shape.size() - 2); | |||||
| params_->col_ = o_shape.at(o_shape.size() - 1); | |||||
| params_->deep_ = params_->a_transpose_ ? x_shape.at(x_shape.size() - 2) : x_shape.at(x_shape.size() - 1); | |||||
| params_->row_4_ = UP_ROUND(params_->row_, C4NUM); | |||||
| params_->col_4_ = UP_ROUND(params_->col_, C4NUM); | |||||
| params_->deep_16_ = UP_ROUND(params_->deep_, C16NUM); | |||||
| a_pack_ptr_size_ = static_cast<size_t>(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)); | |||||
| a_pack_ptr_ = reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, a_pack_ptr_size_, kOfflinePackWeight)); | |||||
| MS_CHECK_PTR(a_pack_ptr_); | |||||
| input_sums_size_ = static_cast<size_t>(params_->row_4_ * sizeof(int)); | |||||
| input_sums_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, input_sums_size_, kOfflinePackWeight)); | |||||
| MS_CHECK_PTR(input_sums_); | |||||
| b_pack_batch_ptr_size_ = static_cast<size_t>(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)); | |||||
| // One int32 sum per padded column and batch; DoCode steps through this buffer by col_4_ per batch. | |||||
| // Without this assignment the size stayed 0 and the buffer was allocated and cleared with zero bytes. | |||||
| weight_bias_sums_batch_size_ = static_cast<size_t>(params_->batch * params_->col_4_ * sizeof(int)); | |||||
| if (params_->b_const_) { | |||||
| // Constant weight: the buffers are malloc'ed by the generated init code emitted further down. | |||||
| b_pack_batch_ptr_ = reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, kOnlineSize, kOnlinePackWeight)); | |||||
| MS_CHECK_PTR(b_pack_batch_ptr_); | |||||
| weight_bias_sums_batch_ = | |||||
| reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); | |||||
| } else { | |||||
| b_pack_batch_ptr_ = | |||||
| reinterpret_cast<int8_t *>(allocator_->Malloc(kNumberTypeInt8, b_pack_batch_ptr_size_, kOfflinePackWeight)); | |||||
| MS_CHECK_PTR(b_pack_batch_ptr_); | |||||
| weight_bias_sums_batch_ = | |||||
| reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, weight_bias_sums_batch_size_, kOfflinePackWeight)); | |||||
| } | |||||
| MS_CHECK_PTR(weight_bias_sums_batch_); | |||||
| if (input_tensors_.size() == 3) { | |||||
| bias_prt_size_ = static_cast<size_t>(params_->col_4_ * sizeof(int)); | |||||
| bias_ptr_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, kOnlineSize, kOnlinePackWeight)); | |||||
| MS_CHECK_PTR(bias_ptr_); | |||||
| } else { | |||||
| bias_ptr_ = nullptr; | |||||
| } | |||||
| thread_count_ = MSMIN(thread_num_, UP_DIV(params_->col_4_, C4NUM)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, C4NUM), thread_count_); | |||||
| // Only the first quantization parameter of each tensor is used (per-tensor quantization). | |||||
| std::vector<QuantArg> params = input_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!params.empty(), "params is empty"); | |||||
| quant_params_.input_.zp_ = params.front().zeroPoint; | |||||
| quant_params_.input_.scale_ = static_cast<float>(params.front().scale); | |||||
| params = filter_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!params.empty(), "params is empty"); | |||||
| quant_params_.weight_.zp_ = params.front().zeroPoint; | |||||
| quant_params_.weight_.scale_ = static_cast<float>(params.front().scale); | |||||
| params = output_tensor_->quant_params(); | |||||
| MS_CHECK_TRUE(!params.empty(), "params is empty"); | |||||
| quant_params_.output_.zp_ = params.front().zeroPoint; | |||||
| quant_params_.output_.scale_ = static_cast<float>(params.front().scale); | |||||
| double real_multiplier = quant_params_.input_.scale_ * quant_params_.weight_.scale_ / quant_params_.output_.scale_; | |||||
| QuantizeRoundParameterWithDoublePrecision(real_multiplier, quant_params_.quant_multiplier_, quant_params_.left_shift_, | |||||
| quant_params_.right_shift_); | |||||
| if (params_->b_const_) { | |||||
| NNaclInt8Serializer init_code; | |||||
| if (bias_ptr_) { | |||||
| init_code.CodeMallocExpression(bias_ptr_, bias_prt_size_); | |||||
| init_code.CodeFunction("memset", bias_ptr_, 0, bias_prt_size_); | |||||
| // NOTE(review): bias_prt_size_ covers col_4_ ints; confirm the bias tensor holds at least that many | |||||
| // bytes, otherwise this emitted memcpy reads past the end of the tensor data. | |||||
| init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_->data_c(), bias_prt_size_); | |||||
| } | |||||
| init_code.CodeMallocExpression(weight_bias_sums_batch_, weight_bias_sums_batch_size_); | |||||
| init_code.CodeFunction("memset", weight_bias_sums_batch_, 0, weight_bias_sums_batch_size_); | |||||
| init_code.CodeMallocExpression(b_pack_batch_ptr_, b_pack_batch_ptr_size_); | |||||
| init_code.CodeFunction("memset", b_pack_batch_ptr_, 0, b_pack_batch_ptr_size_); | |||||
| init_code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n"; | |||||
| // The wrapper function is spelled InitInt8MatrixB; the former "InitIn8MatrixB" does not exist. | |||||
| init_code.CodeFunction("InitInt8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_, | |||||
| params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_, | |||||
| quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_); | |||||
| context->AppendInitCode(init_code.str()); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| // Releases the three heap-allocated quantization arrays held in quant_params_ and nulls each pointer. | |||||
| MatMulInt8Coder::~MatMulInt8Coder() { | |||||
| int32_t **const owned_arrays[] = {&quant_params_.quant_multiplier_, &quant_params_.right_shift_, | |||||
| &quant_params_.left_shift_}; | |||||
| for (int32_t **slot : owned_arrays) { | |||||
| // free() accepts a null pointer and does nothing, so no explicit null test is required. | |||||
| free(*slot); | |||||
| *slot = nullptr; | |||||
| } | |||||
| } | |||||
| // Caches the weight (and optional bias) tensor, binds param_ to the operator parameter, runs the shared | |||||
| // base-class set-up and finally ReSize. Returns RET_OK on success and an error code otherwise. | |||||
| int MatMulInt8Coder::Init() { | |||||
| params_ = reinterpret_cast<MatMulParameter *>(parameter_); | |||||
| int MatMulInt8Coder::Prepare(CoderContext *const context) { | |||||
| filter_tensor_ = input_tensors_.at(kWeightIndex); | filter_tensor_ = input_tensors_.at(kWeightIndex); | ||||
| MS_CHECK_PTR(filter_tensor_); | MS_CHECK_PTR(filter_tensor_); | ||||
| if (input_tensors_.size() == kInputSize2) { | if (input_tensors_.size() == kInputSize2) { | ||||
| @@ -133,85 +26,34 @@ int MatMulInt8Coder::Init() { | |||||
| MS_CHECK_PTR(bias_tensor_); | MS_CHECK_PTR(bias_tensor_); | ||||
| MS_CHECK_PTR(bias_tensor_->data_c()); | MS_CHECK_PTR(bias_tensor_->data_c()); | ||||
| } | } | ||||
| params_->b_const_ = (filter_tensor_->data_c() != nullptr); | |||||
| quant_params_.quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_params_.quant_multiplier_); | |||||
| quant_params_.left_shift_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_params_.left_shift_); | |||||
| quant_params_.right_shift_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t))); | |||||
| MS_CHECK_PTR(quant_params_.right_shift_); | |||||
| return RET_OK; | |||||
| param_ = reinterpret_cast<MatMulParameter *>(parameter_); | |||||
| MatMulBaseInt8Coder::InitParameter(); | |||||
| // The message names the call that actually failed; it previously read "ParallelLaunch failed". | |||||
| MS_CHECK_RET_CODE(MatMulBaseInt8Coder::Init(), "MatMulBaseInt8Coder Init failed"); | |||||
| return ReSize(context); | |||||
| } | } | ||||
| int MatMulInt8Coder::Prepare(CoderContext *const context) { | |||||
| MS_CHECK_RET_CODE(Init(), "MatMulInt8Coder Init failed"); | |||||
| MS_CHECK_RET_CODE(ReSize(context), "MatMulInt8Coder ReSize failed"); | |||||
| // Derives batch, row_, col_ and deep_ from the input/output tensor shapes, stores them in param_ and then | |||||
| // runs MatMulBaseInt8Coder::ReSize. Returns RET_OK on success and an error code when a shape has fewer | |||||
| // than two dimensions or the base-class ReSize fails. | |||||
| int MatMulInt8Coder::ReSize(CoderContext *const context) { | |||||
| int batch = 1; | |||||
| std::vector<int> x_shape = input_tensor_->shape(); | |||||
| std::vector<int> o_shape = output_tensor_->shape(); | |||||
| // These are boolean conditions, not status codes, so they need MS_CHECK_TRUE: MS_CHECK_RET_CODE treats | |||||
| // any value other than RET_OK as a failure and would therefore reject every valid shape. | |||||
| MS_CHECK_TRUE(x_shape.size() >= kBiasIndex, "x_shape size is less than two"); | |||||
| MS_CHECK_TRUE(o_shape.size() >= kBiasIndex, "o_shape size is less than two"); | |||||
| // kBiasIndex and kWeightIndex are reused here as the offsets (from the end) of the last two dimensions; | |||||
| // every dimension in front of them is folded into the batch count. | |||||
| for (size_t i = 0; i < x_shape.size() - kBiasIndex; ++i) { | |||||
| batch *= x_shape[i]; | |||||
| } | |||||
| param_->batch = batch; | |||||
| param_->row_ = o_shape[o_shape.size() - kBiasIndex]; | |||||
| param_->col_ = o_shape[o_shape.size() - kWeightIndex]; | |||||
| param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kBiasIndex] : x_shape[x_shape.size() - kWeightIndex]; | |||||
| MS_CHECK_RET_CODE(MatMulBaseInt8Coder::ReSize(context), "MatMulBaseInt8Coder::ReSize failed"); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| // Emits the per-batch int8 matmul code: packs matrix A, computes the input sums, selects the per-batch | |||||
| // B/bias/output pointers and calls the target-specific kernel. Returns RET_OK on success and the failing | |||||
| // status otherwise. | |||||
| int MatMulInt8Coder::DoCode(CoderContext *const context) { | int MatMulInt8Coder::DoCode(CoderContext *const context) { | ||||
| Collect(context, {"nnacl/common_func.h", "nnacl/int8/common_func_int8.h", "nnacl/int8/matmul_int8.h"}, | |||||
| {"common_func.c", "common_func_int8.c", "matmul_int8.c"}); | |||||
| std::string a_ptr_str = allocator_->GetRuntimeAddr(input_tensor_); | |||||
| std::string c_ptr_str = allocator_->GetRuntimeAddr(output_tensor_); | |||||
| int a_stride = params_->row_ * params_->deep_; | |||||
| int c_stride = params_->row_ * params_->col_; | |||||
| NNaclInt8Serializer code; | |||||
| code.precision(kPrecision); | |||||
| int task_id = 0; | |||||
| int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, C4NUM) - task_id * thread_stride_); | |||||
| if (cur_oc <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n"; | |||||
| if (!params_->b_const_) { | |||||
| // The wrapper function is spelled InitInt8MatrixB; the former "InitIn8MatrixB" does not exist. | |||||
| code.CodeFunction("InitInt8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_, | |||||
| params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_, | |||||
| quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_); | |||||
| } | |||||
| std::string b_batch_str = allocator_->GetRuntimeAddr(b_pack_batch_ptr_); | |||||
| std::string weight_bias_sums_batch_str = allocator_->GetRuntimeAddr(weight_bias_sums_batch_); | |||||
| code.CodeFunction("memset", input_sums_, 0, input_sums_size_); | |||||
| code.CodeFunction("memset", a_pack_ptr_, 0, a_pack_ptr_size_); | |||||
| code << "for (int i = 0; i < " << params_->batch << "; ++i) {\n"; | |||||
| code << " int8_t* cur_a_ptr = " << a_ptr_str << " + i * " << a_stride << ";\n"; | |||||
| if (params_->a_transpose_) { | |||||
| code.CodeFunction("RowMajor2Col16x4MajorInt8", "cur_a_ptr", params_->deep_, params_->row_, a_pack_ptr_); | |||||
| // Same argument list as the non-transposed branch, only the data order differs; the row count was | |||||
| // missing here, which produced a call with one argument too few. | |||||
| code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight_.zp_, | |||||
| input_sums_, ColMajor); | |||||
| } else { | |||||
| code.CodeFunction("RowMajor2Row16x4MajorInt8", "cur_a_ptr", a_pack_ptr_, params_->row_, params_->deep_); | |||||
| code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight_.zp_, | |||||
| input_sums_, RowMajor); | |||||
| } | |||||
| // NOTE(review): b_pack_ptr_, weight_bias_sums_ and c_ptr_ are assigned in the emitted snippet without a | |||||
| // declaration; confirm they are declared elsewhere in the generated file. | |||||
| code << " b_pack_ptr_ = " << b_batch_str << " + i * " << params_->col_4_ * params_->deep_16_ << ";\n"; | |||||
| code << " weight_bias_sums_ = " << weight_bias_sums_batch_str << " + i * " << params_->col_4_ << ";\n"; | |||||
| code << " c_ptr_ = " << c_ptr_str << " + i * " << c_stride << ";\n"; | |||||
| int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM); | |||||
| code << " int8_t* cur_b = b_pack_ptr_ + " << task_id * thread_stride_ * C4NUM * params_->deep_16_ << ";\n"; | |||||
| code << " int32_t* cur_bias = weight_bias_sums_ + " << task_id * thread_stride_ * C4NUM << ";\n"; | |||||
| code << " int8_t *cur_c = c_ptr_ + " << task_id * thread_stride_ * C4NUM << ";\n"; | |||||
| code << " static const int left_shift = " << quant_params_.left_shift_[0] << ";\n"; | |||||
| code << " static const int right_shift = " << quant_params_.right_shift_[0] << ";\n"; | |||||
| code << " static const int quant_multiplier = " << quant_params_.quant_multiplier_[0] << ";\n"; | |||||
| if (target_ == kARM64) { | |||||
| code.CodeFunction("MatmulInt8Neon64", "cur_a_ptr", "cur_b", "cur_c", params_->row_4_, cur_oc * C4NUM, | |||||
| params_->deep_16_, input_sums_, "cur_bias", INT8_MIN, INT8_MAX, quant_params_.output_.zp_, | |||||
| "&quant_multiplier", "&left_shift", "&right_shift", params_->row_, cur_oc_res, params_->col_, | |||||
| false); | |||||
| } else { | |||||
| code.CodeFunction("MatMulInt8_16x4_r", "cur_a_ptr", "cur_b", "cur_c", params_->row_, cur_oc_res, params_->deep_16_, | |||||
| params_->col_, input_sums_, "cur_bias", "&left_shift", "&right_shift", "&quant_multiplier", | |||||
| quant_params_.output_.zp_, INT8_MIN, INT8_MAX, false); | |||||
| } | |||||
| code << "}\n"; | |||||
| MS_LOG(DEBUG) << "MatMulInt8Coder has been called"; | |||||
| context->AppendCode(code.str()); | |||||
| MS_CHECK_RET_CODE(MatMulBaseInt8Coder::DoCode(context), "matmul int8 do code failed"); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_MatMul, CPUOpCoderCreator<MatMulInt8Coder>) /* Registers MatMulInt8Coder as the int8 MatMul coder for all targets. */ | |||||
| } // namespace mindspore::lite::micro::nnacl | } // namespace mindspore::lite::micro::nnacl | ||||
| @@ -16,42 +16,26 @@ | |||||
| #ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ | #ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ | ||||
| #define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ | #define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "coder/opcoders/op_coder.h" | #include "coder/opcoders/op_coder.h" | ||||
| #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" | |||||
| #include "nnacl/matmul_parameter.h" | #include "nnacl/matmul_parameter.h" | ||||
| namespace mindspore::lite::micro::nnacl { | namespace mindspore::lite::micro::nnacl { | ||||
| class MatMulInt8Coder final : public OperatorCoder { | |||||
| class MatMulInt8Coder final : public MatMulBaseInt8Coder { /* Int8 MatMul coder; packing buffers, quant parameters and thread settings are declared in MatMulBaseInt8Coder. */ | |||||
| public: | public: | ||||
| MatMulInt8Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | MatMulInt8Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | ||||
| const Model::Node *node, size_t node_index, Target target) | const Model::Node *node, size_t node_index, Target target) | ||||
| : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} | |||||
| ~MatMulInt8Coder() override; | |||||
| : MatMulBaseInt8Coder(in_tensors, out_tensors, node, node_index, target) {} /* Forwards every argument to the base-class constructor. */ | |||||
| ~MatMulInt8Coder() override = default; | |||||
| int Prepare(CoderContext *const context) override; | int Prepare(CoderContext *const context) override; /* Caches the filter/bias tensors, binds param_, runs the base Init, then ReSize. */ | ||||
| int DoCode(CoderContext *const context) override; | int DoCode(CoderContext *const context) override; /* Calls MatMulBaseInt8Coder::DoCode and reports its failure. */ | ||||
| private: | private: | ||||
| int Init(); | |||||
| int ReSize(CoderContext *const context); | |||||
| private: | |||||
| Tensor *filter_tensor_{nullptr}; | |||||
| Tensor *bias_tensor_{nullptr}; | |||||
| MatMulParameter *params_{nullptr}; | |||||
| MatmulQuantParameter quant_params_{0}; | |||||
| size_t a_pack_ptr_size_{0}; | |||||
| int8_t *a_pack_ptr_{nullptr}; | |||||
| size_t b_pack_batch_ptr_size_{0}; | |||||
| int8_t *b_pack_batch_ptr_{nullptr}; | |||||
| size_t bias_prt_size_{0}; | |||||
| int *bias_ptr_{nullptr}; | |||||
| size_t input_sums_size_{0}; | |||||
| int *input_sums_{nullptr}; | |||||
| size_t weight_bias_sums_batch_size_{0}; | |||||
| int *weight_bias_sums_batch_{nullptr}; | |||||
| int thread_stride_{0}; | |||||
| int thread_count_{0}; | |||||
| int ReSize(CoderContext *const context) override; /* Derives batch/row_/col_/deep_ from the tensor shapes, then calls the base ReSize. */ | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::micro::nnacl | } // namespace mindspore::lite::micro::nnacl | ||||
| #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ | #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_INT8_MATMUL_INT8_CODER_H_ | ||||
| @@ -198,11 +198,16 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ReshapeQuant | |||||
| reshape_quant_arg.output_activation_min_, reshape_quant_arg.output_activation_max_); | reshape_quant_arg.output_activation_min_, reshape_quant_arg.output_activation_max_); | ||||
| } | } | ||||
| void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg) { | |||||
| void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg, | |||||
| int weight_quant_num) { /* Serializes a MatmulQuantParameter called `name`; weight_quant_num is the element count of each quantization array. */ | |||||
| CodeArray("filter_scale", matmul_quant_arg.filter_scale_, weight_quant_num); /* The five quantization arrays are emitted first, each under a fixed identifier. */ | |||||
| CodeArray("filter_zp", matmul_quant_arg.filter_zp_, weight_quant_num); | |||||
| CodeArray("left_shift", matmul_quant_arg.left_shift_, weight_quant_num); | |||||
| CodeArray("right_shift", matmul_quant_arg.right_shift_, weight_quant_num); | |||||
| CodeArray("multiplier", matmul_quant_arg.quant_multiplier_, weight_quant_num); | |||||
| CodeBaseStruct("MatmulQuantParameter", name, matmul_quant_arg.input_, matmul_quant_arg.weight_, | CodeBaseStruct("MatmulQuantParameter", name, matmul_quant_arg.input_, matmul_quant_arg.weight_, | ||||
| matmul_quant_arg.output_, matmul_quant_arg.out_act_min_, matmul_quant_arg.out_act_max_, | |||||
| matmul_quant_arg.left_shift_[0], matmul_quant_arg.right_shift_[0], | |||||
| matmul_quant_arg.quant_multiplier_[0]); | |||||
| matmul_quant_arg.output_, matmul_quant_arg.out_act_min_, matmul_quant_arg.out_act_max_, "filter_scale", | |||||
| "filter_zp", "left_shift", "right_shift", "multiplier"); /* The struct initializer refers to the arrays by identifier instead of embedding element [0]. NOTE(review): the identifiers are fixed, so two emissions in one scope may clash - confirm how CodeArray declares them. */ | |||||
| } | } | ||||
| void NNaclInt8Serializer::CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg) { | void NNaclInt8Serializer::CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg) { | ||||
| @@ -52,7 +52,7 @@ class NNaclInt8Serializer : public Serializer { | |||||
| void CodeStruct(const std::string &name, const ::QuantMulArg &quant_mul_arg); | void CodeStruct(const std::string &name, const ::QuantMulArg &quant_mul_arg); | ||||
| void CodeStruct(const std::string &name, const ReduceQuantArg &reduce_quant_arg); | void CodeStruct(const std::string &name, const ReduceQuantArg &reduce_quant_arg); | ||||
| void CodeStruct(const std::string &name, const ReshapeQuantArg &reshape_quant_arg); | void CodeStruct(const std::string &name, const ReshapeQuantArg &reshape_quant_arg); | ||||
| void CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg); | |||||
| void CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg, int weight_quant_num); | |||||
| void CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg); | void CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg); | ||||
| void CodeStruct(const std::string &name, const DivQuantArg &div_quant_arg); | void CodeStruct(const std::string &name, const DivQuantArg &div_quant_arg); | ||||
| void CodeStruct(const std::string &name, const ReluXQuantArg &relu_quant_arg); | void CodeStruct(const std::string &name, const ReluXQuantArg &relu_quant_arg); | ||||
| @@ -30,15 +30,16 @@ void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int | |||||
| } | } | ||||
| } | } | ||||
| void InitInt8MatrixB(int8_t *src_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, int col, | |||||
| int col_4, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, bool b_transpose) { | |||||
| void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, | |||||
| int col, int col_align, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, | |||||
| bool b_transpose, bool filter_per_channel) { | |||||
| for (int i = 0; i < batch; ++i) { | for (int i = 0; i < batch; ++i) { | ||||
| int8_t *cur_b = src_ptr + i * deep * col; | |||||
| int8_t *cur_b_pack = dst_ptr + i * col_4 * deep_16; | |||||
| int32_t *cur_sums = weight_bias_sums_batch_ + i * col_4; | |||||
| int8_t *cur_b = weight_ptr + i * deep * col; | |||||
| int8_t *cur_b_pack = dst_ptr + i * col_align * deep_16; | |||||
| int32_t *cur_sums = weight_bias_sums_batch_ + i * col_align; | |||||
| if (b_transpose) { | if (b_transpose) { | ||||
| RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, col, deep); | RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, col, deep); | ||||
| CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, ColMajor, false); | |||||
| CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, ColMajor, filter_per_channel); | |||||
| } else { | } else { | ||||
| RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack); | RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack); | ||||
| CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false); | CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false); | ||||
| @@ -25,8 +25,9 @@ extern "C" { | |||||
| void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int batch, int row, int deep, int input_zp, | void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int batch, int row, int deep, int input_zp, | ||||
| const int *weight_zp, bool a_transpose); | const int *weight_zp, bool a_transpose); | ||||
| void InitInt8MatrixB(int8_t *src_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, int col, | |||||
| int col_4, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, bool b_transpose); | |||||
| void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_t *dst_ptr, int batch, int deep, | |||||
| int col, int col_align, int deep_16, int input_zp, int *weight_zp, const int *bias_ptr, | |||||
| bool b_transpose, bool filter_per_channel); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||