From: @zhujingxuan Reviewed-by: @wangchengyuan Signed-off-by: @wangchengyuanpull/13974/MERGE
| @@ -103,8 +103,8 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) { | |||
| } | |||
| code.CodeStruct("concat_param", *concat_param_, in_tensor_count, input_tensor_->shape().size(), | |||
| output_tensor_->shape().size()); | |||
| code.CodeBaseStruct("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_, | |||
| before_axis_size, count_unit_); | |||
| code.CodeBaseStruct<false>("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_, | |||
| before_axis_size, count_unit_); | |||
| if (support_parallel_) { | |||
| code.CodeFunction(kParallelLaunch, gThreadPool, "ConcatInt8Run", kRunArgsAddr, gThreadNum); | |||
| } else { | |||
| @@ -21,6 +21,7 @@ | |||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | |||
| #include "coder/opcoders/file_collector.h" | |||
| #include "coder/log.h" | |||
| #include "coder/opcoders/parallel.h" | |||
| #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" | |||
| namespace mindspore::lite::micro::nnacl { | |||
| @@ -43,8 +44,9 @@ int Conv2D1x1Int8Coder::Prepare(CoderContext *const context) { | |||
| int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) { | |||
| Collect(context, | |||
| {"nnacl/int8/conv1x1_int8.h", "nnacl/common_func.h", "wrapper/int8/conv1x1_init_int8_wrapper.h", | |||
| "wrapper/int8/conv1x1_run_int8_wrapper.h"}, | |||
| {"wrapper/int8/conv1x1_init_int8_wrapper.h", "wrapper/int8/conv1x1_run_int8_wrapper.h", "nnacl/common_func.h", | |||
| "nnacl/base/conv1x1_base.h", "nnacl/int8/matmul_int8.h", "nnacl/int8/pack_int8.h", | |||
| "nnacl/int8/conv1x1_int8.h", "nnacl/errorcode.h"}, | |||
| {"common_func.c", "pack_int8.c", "conv1x1_int8.c", "matmul_int8.c", "fixed_point.c", | |||
| "conv1x1_init_int8_wrapper.c", "conv1x1_run_int8_wrapper.c", "conv1x1_base.c"}, | |||
| {"MatmulInt8Opt.S"}); | |||
| @@ -54,11 +56,69 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) { | |||
| code.CodeStruct("conv_param", *conv_param_); | |||
| code.CodeStruct("matmul_param", *matmul_param_); | |||
| code.CodeBaseStruct("Conv1x1Args", "args", input_sum_, filter_zp_ptr_, left_shift_, right_shift_, multiplier_, | |||
| packed_weight_, bias_data_, packed_input_, nullptr, nullptr, 0, 0, "&conv_param", "&matmul_param", | |||
| matmul_func_, pre_trans_input_, support_optimize_, filter_peroc_); | |||
| code.CodeFunction("Conv1x1Run", input_tensor_, "(Conv1x1Args *)&args", output_tensor_); | |||
| code.CodeBaseStruct<false>("Conv1x1Args", kRunArgs, input_sum_, filter_zp_ptr_, left_shift_, right_shift_, | |||
| multiplier_, packed_weight_, bias_data_, packed_input_, nullptr, nullptr, 0, 0, 0, 0, | |||
| "&conv_param", "&matmul_param", matmul_func_, pre_trans_input_, "GetSupportOptFlag()", | |||
| filter_peroc_, false); | |||
| code.CodeFunction("Conv1x1PreRun", kRunArgsAddr, gThreadNum); | |||
| code << "for (int batch_index = 0; batch_index < " << conv_param_->input_batch_ << "; batch_index++) {\n"; | |||
| std::string src_in = allocator_->GetRuntimeAddr(input_tensor_) + " + batch_index * " + | |||
| std::to_string(conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_); | |||
| std::string src_out = allocator_->GetRuntimeAddr(output_tensor_) + " + batch_index * " + | |||
| std::to_string(matmul_param_->row_ * matmul_param_->col_); | |||
| code.CodeFunction("Pre1x1Trans", kRunArgsAddr, src_in, src_out); | |||
| code << "if (args.parallel_by_oc_) {\n"; | |||
| /* input transpose and input sum */ | |||
| code << "if (GetSupportOptFlag()) {\n"; | |||
| if (support_parallel_) { | |||
| code.CodeFunction(kParallelLaunch, gThreadPool, "OcOptPre", kRunArgsAddr, "args.thread_count_hw"); | |||
| } else { | |||
| code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId); | |||
| } | |||
| code << "} else {\n"; | |||
| code << "RowMajor2Row16x4MajorInt8(args.input_ptr_, args.packed_input_, args.matmul_param_->row_, " | |||
| "args.matmul_param_->deep_);\n"; | |||
| if (filter_peroc_) { | |||
| code << "PackInputSum16x4PerLayer(args.packed_input_, args.input_sum_, 1, args.matmul_param_->row_4_, " | |||
| "args.matmul_param_->deep_16_);\n"; | |||
| } else { | |||
| code << "PackInputSum16x4PerLayer(args.packed_input_, " | |||
| "args.input_sum_,args.conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, " | |||
| "args.matmul_param_->row_4_, args.matmul_param_->deep_16_);\n"; | |||
| } | |||
| code << "}\n"; | |||
| /* matmul parallel by oc */ | |||
| code << "if (GetSupportOptFlag()) {\n"; | |||
| if (support_parallel_) { | |||
| code.CodeFunction(kParallelLaunch, gThreadPool, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc"); | |||
| } else { | |||
| code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId); | |||
| } | |||
| code << "} else {\n"; | |||
| if (support_parallel_) { | |||
| code.CodeFunction(kParallelLaunch, gThreadPool, "RunArmOc", kRunArgsAddr, "args.thread_count_oc"); | |||
| } else { | |||
| code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId); | |||
| } | |||
| code << "}\n"; | |||
| code << "} else {\n"; | |||
| /* matmul parallel by hw */ | |||
| code << "if (GetSupportOptFlag()) {\n"; | |||
| if (support_parallel_) { | |||
| code.CodeFunction(kParallelLaunch, gThreadPool, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw"); | |||
| } else { | |||
| code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId); | |||
| } | |||
| code << "} else {\n"; | |||
| if (support_parallel_) { | |||
| code.CodeFunction(kParallelLaunch, gThreadPool, "RunArmHw", kRunArgsAddr, "args.thread_count_hw"); | |||
| } else { | |||
| code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId); | |||
| } | |||
| code << "}\n"; | |||
| code << "}\n"; | |||
| code << "}\n"; | |||
| context->AppendCode(code.str()); | |||
| return RET_OK; | |||
| @@ -86,9 +86,13 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) { | |||
| {"nnacl/int8/conv_depthwise_int8.h", "nnacl/int8/pack_int8.h", "wrapper/int8/convolution_depthwise_int8_wrapper.h"}, | |||
| {"conv_depthwise_int8.c", "fixed_point.c", "pack_int8.c", "conv_int8.c", "winograd_transform.c", | |||
| "convolution_depthwise_int8_wrapper.c"}, | |||
| {"ConvDwInt8Row.S", "ConvDwInt8PostAlign4.S", "ConvDwInt8PostAlign4PerChannel.S", "ConvDw3x3Int8Stride2.S", | |||
| "ConvDw3x3Int8.S", "ConvDw3x3Int8Vertical.S", "ConvDw3x3Int8Horizontal.S", "ConvDw3x3Int8Corner.S", | |||
| "MatmulOptR4Int8.S", "ConvDwInt8Center.S", "DeconvDwInt8Center.S", "DeconvDwInt8Post.S", "MatmulDpInt8Opt.S"}); | |||
| {"ConvDwInt8Row.S", "ConvDwInt8PostAlign4.S", "ConvDwInt8PostAlign4PerChannel.S", "ConvDwInt8Center.S", | |||
| "DeconvDwInt8Center.S", "DeconvDwInt8Post.S"}); | |||
| if (target_ == kARM64) { | |||
| Collect(context, {}, {}, | |||
| {"ConvDw3x3Int8.S", "ConvDw3x3Int8Corner.S", "ConvDw3x3Int8Horizontal.S", "ConvDw3x3Int8Stride2.S", | |||
| "ConvDw3x3Int8Vertical.S", "MatmulDpInt8Opt.S", "MatmulOptR4Int8.S"}); | |||
| } | |||
| nnacl::NNaclInt8Serializer code; | |||
| code.precision(kPrecision); | |||
| // call the op function | |||
| @@ -90,7 +90,7 @@ int ResizeInt8Coder::DoCode(CoderContext *const context) { | |||
| code.CodeStruct("quant_in", *quant_in_); | |||
| code.CodeStruct("quant_out", *quant_out_); | |||
| code.CodeStruct("multiplier", *multiplier_); | |||
| code.CodeFunction("ResizeNearestNeighborInt8", input_tensor_, output_tensor_, "&input_shape", "&output_shape", | |||
| code.CodeFunction("ResizeNearestNeighborInt8", input_tensor_, output_tensor_, "input_shape", "output_shape", | |||
| align_corners, "multiplier", "quant_in", "quant_out", kDefaultTaskId, gThreadNum); | |||
| } | |||
| break; | |||
| @@ -173,7 +173,7 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ConcatParame | |||
| CodeArray(output_shapes_name, concat_parameter.output_shapes_, out_shape, false); | |||
| CodeBaseStruct<false>("ConcatParameter", name, concat_parameter.op_parameter_, quant_arg_name, concat_parameter.axis_, | |||
| concat_parameter.thread_count_, concat_parameter.input_num_, input_shapes_name, | |||
| concat_parameter.thread_count_, concat_parameter.input_num_, "(int **)" + input_shapes_name, | |||
| output_shapes_name, concat_parameter.after_axis_size, concat_parameter.count_unit_); | |||
| } | |||
| @@ -164,7 +164,7 @@ int RunArmHw(void *cdata, int task_id) { | |||
| return NNACL_OK; | |||
| } | |||
| void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out) { | |||
| void Conv1x1PreRun(Conv1x1Args *args, int thread_num) { | |||
| int row_pack_count = C4NUM; | |||
| int col_pack_count; | |||
| @@ -177,49 +177,16 @@ void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out) { | |||
| col_pack_count = C4NUM; | |||
| } | |||
| #endif | |||
| int thread_num = 1; | |||
| int hw_thread_count = UP_DIV(args->matmul_param_->row_, row_pack_count); | |||
| int oc_thread_count = UP_DIV(args->matmul_param_->col_, col_pack_count); | |||
| size_t thread_count_hw = MSMIN(thread_num, hw_thread_count); | |||
| args->thread_stride_hw_ = UP_DIV(hw_thread_count, thread_count_hw); | |||
| size_t thread_count_oc = MSMIN(thread_num, oc_thread_count); | |||
| args->thread_stride_oc_ = UP_DIV(oc_thread_count, thread_count_oc); | |||
| bool parallel_by_oc = oc_thread_count > thread_num; | |||
| for (int batch_index = 0; batch_index < args->conv_param_->input_batch_; batch_index++) { | |||
| Pre1x1Trans(args, | |||
| src_in + batch_index * args->conv_param_->input_h_ * args->conv_param_->input_w_ * | |||
| args->conv_param_->input_channel_, | |||
| src_out + batch_index * args->matmul_param_->row_ * args->matmul_param_->col_); | |||
| if (parallel_by_oc) { | |||
| /* input transpose and input sum */ | |||
| if (args->support_optimize_) { | |||
| OcOptPre(args, 0); | |||
| } else { | |||
| RowMajor2Row16x4MajorInt8(args->input_ptr_, args->packed_input_, args->matmul_param_->row_, | |||
| args->matmul_param_->deep_); | |||
| if (args->filter_peroc_) { | |||
| PackInputSum16x4PerLayer(args->packed_input_, args->input_sum_, 1, args->matmul_param_->row_4_, | |||
| args->matmul_param_->deep_16_); | |||
| } else { | |||
| PackInputSum16x4PerLayer(args->packed_input_, args->input_sum_, | |||
| args->conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, | |||
| args->matmul_param_->row_4_, args->matmul_param_->deep_16_); | |||
| } | |||
| } | |||
| /* matmul parallel by oc */ | |||
| if (args->support_optimize_) { | |||
| RunArm64OptOc(args, 0); | |||
| } else { | |||
| RunArmOc(args, 0); | |||
| } | |||
| } else { | |||
| /* matmul parallel by hw */ | |||
| if (args->support_optimize_) { | |||
| RunArm64OptHw(args, 0); | |||
| } else { | |||
| RunArmHw(args, 0); | |||
| } | |||
| } | |||
| args->thread_count_hw = MSMIN(thread_num, hw_thread_count); | |||
| args->thread_stride_hw_ = UP_DIV(hw_thread_count, args->thread_count_hw); | |||
| args->thread_count_oc = MSMIN(thread_num, oc_thread_count); | |||
| args->thread_stride_oc_ = UP_DIV(oc_thread_count, args->thread_count_oc); | |||
| args->parallel_by_oc_ = oc_thread_count > thread_num; | |||
| if (!args->filter_peroc_) { | |||
| args->right_shift_ = args->conv_param_->conv_quant_arg_.right_shift_; | |||
| args->left_shift_ = args->conv_param_->conv_quant_arg_.left_shift_; | |||
| args->multiplier_ = args->conv_param_->conv_quant_arg_.quant_multiplier_; | |||
| } | |||
| } | |||
| @@ -33,7 +33,9 @@ typedef struct { | |||
| int8_t *packed_input_; | |||
| int8_t *input_ptr_; | |||
| int8_t *output_ptr_; | |||
| size_t thread_count_hw; | |||
| size_t thread_stride_hw_; | |||
| size_t thread_count_oc; | |||
| size_t thread_stride_oc_; | |||
| ConvParameter *conv_param_; | |||
| MatMulParameter *matmul_param_; | |||
| @@ -41,8 +43,15 @@ typedef struct { | |||
| bool pre_trans_input_; | |||
| bool support_optimize_; | |||
| bool filter_peroc_; | |||
| bool parallel_by_oc_; | |||
| } Conv1x1Args; | |||
| void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out); | |||
| void Conv1x1PreRun(Conv1x1Args *args, int thread_num); | |||
| void Pre1x1Trans(Conv1x1Args *args, int8_t *src_input, int8_t *src_output); | |||
| int OcOptPre(void *cdata, int task_id); | |||
| int RunArm64OptOc(void *cdata, int task_id); | |||
| int RunArmOc(void *cdata, int task_id); | |||
| int RunArm64OptHw(void *cdata, int task_id); | |||
| int RunArmHw(void *cdata, int task_id); | |||
| #endif // MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_INT8_CONV1X1_RUN_H_ | |||