diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc index 17a131fe18..fa9f0b9f5c 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc @@ -103,8 +103,8 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) { } code.CodeStruct("concat_param", *concat_param_, in_tensor_count, input_tensor_->shape().size(), output_tensor_->shape().size()); - code.CodeBaseStruct("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_, - before_axis_size, count_unit_); + code.CodeBaseStruct("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_, + before_axis_size, count_unit_); if (support_parallel_) { code.CodeFunction(kParallelLaunch, gThreadPool, "ConcatInt8Run", kRunArgsAddr, gThreadNum); } else { diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc index 1ec484cd42..df81cf033e 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc @@ -21,6 +21,7 @@ #include "src/runtime/kernel/arm/base/convolution_base.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" +#include "coder/opcoders/parallel.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" namespace mindspore::lite::micro::nnacl { @@ -43,8 +44,9 @@ int Conv2D1x1Int8Coder::Prepare(CoderContext *const context) { int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) { Collect(context, - {"nnacl/int8/conv1x1_int8.h", "nnacl/common_func.h", "wrapper/int8/conv1x1_init_int8_wrapper.h", - "wrapper/int8/conv1x1_run_int8_wrapper.h"}, + {"wrapper/int8/conv1x1_init_int8_wrapper.h", "wrapper/int8/conv1x1_run_int8_wrapper.h", "nnacl/common_func.h", + "nnacl/base/conv1x1_base.h", "nnacl/int8/matmul_int8.h", "nnacl/int8/pack_int8.h", + "nnacl/int8/conv1x1_int8.h", "nnacl/errorcode.h"}, {"common_func.c", "pack_int8.c", "conv1x1_int8.c", "matmul_int8.c", "fixed_point.c", "conv1x1_init_int8_wrapper.c", "conv1x1_run_int8_wrapper.c", "conv1x1_base.c"}, {"MatmulInt8Opt.S"}); @@ -54,11 +56,69 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) { code.CodeStruct("conv_param", *conv_param_); code.CodeStruct("matmul_param", *matmul_param_); - code.CodeBaseStruct("Conv1x1Args", "args", input_sum_, filter_zp_ptr_, left_shift_, right_shift_, multiplier_, - packed_weight_, bias_data_, packed_input_, nullptr, nullptr, 0, 0, "&conv_param", "&matmul_param", - matmul_func_, pre_trans_input_, support_optimize_, filter_peroc_); - - code.CodeFunction("Conv1x1Run", input_tensor_, "(Conv1x1Args *)&args", output_tensor_); + code.CodeBaseStruct("Conv1x1Args", kRunArgs, input_sum_, filter_zp_ptr_, left_shift_, right_shift_, + multiplier_, packed_weight_, bias_data_, packed_input_, nullptr, nullptr, 0, 0, 0, 0, + "&conv_param", "&matmul_param", matmul_func_, pre_trans_input_, "GetSupportOptFlag()", + filter_peroc_, false); + + code.CodeFunction("Conv1x1PreRun", kRunArgsAddr, gThreadNum); + code << "for (int batch_index = 0; batch_index < " << conv_param_->input_batch_ << "; batch_index++) {\n"; + std::string src_in = allocator_->GetRuntimeAddr(input_tensor_) + " + batch_index * " + + std::to_string(conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_); + std::string src_out = allocator_->GetRuntimeAddr(output_tensor_) + " + batch_index * " + + std::to_string(matmul_param_->row_ * matmul_param_->col_); + code.CodeFunction("Pre1x1Trans", kRunArgsAddr, src_in, src_out); + code << "if (args.parallel_by_oc_) {\n"; + /* input transpose and input sum */ + code << "if (GetSupportOptFlag()) {\n"; + if (support_parallel_) { + code.CodeFunction(kParallelLaunch, gThreadPool, "OcOptPre", kRunArgsAddr, "args.thread_count_hw"); + } else { + code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId); + } + code << "} else {\n"; + code << "RowMajor2Row16x4MajorInt8(args.input_ptr_, args.packed_input_, args.matmul_param_->row_, " + "args.matmul_param_->deep_);\n"; + if (filter_peroc_) { + code << "PackInputSum16x4PerLayer(args.packed_input_, args.input_sum_, 1, args.matmul_param_->row_4_, " + "args.matmul_param_->deep_16_);\n"; + } else { + code << "PackInputSum16x4PerLayer(args.packed_input_, " + "args.input_sum_,args.conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, " + "args.matmul_param_->row_4_, args.matmul_param_->deep_16_);\n"; + } + code << "}\n"; + /* matmul parallel by oc */ + code << "if (GetSupportOptFlag()) {\n"; + if (support_parallel_) { + code.CodeFunction(kParallelLaunch, gThreadPool, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc"); + } else { + code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId); + } + code << "} else {\n"; + if (support_parallel_) { + code.CodeFunction(kParallelLaunch, gThreadPool, "RunArmOc", kRunArgsAddr, "args.thread_count_oc"); + } else { + code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId); + } + code << "}\n"; + code << "} else {\n"; + /* matmul parallel by hw */ + code << "if (GetSupportOptFlag()) {\n"; + if (support_parallel_) { + code.CodeFunction(kParallelLaunch, gThreadPool, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw"); + } else { + code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId); + } + code << "} else {\n"; + if (support_parallel_) { + code.CodeFunction(kParallelLaunch, gThreadPool, "RunArmHw", kRunArgsAddr, "args.thread_count_hw"); + } else { + code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId); + } + code << "}\n"; + code << "}\n"; + code << "}\n"; context->AppendCode(code.str()); return RET_OK; diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc index da6c0778fd..97806a6cbd 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc @@ -86,9 +86,13 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) { {"nnacl/int8/conv_depthwise_int8.h", "nnacl/int8/pack_int8.h", "wrapper/int8/convolution_depthwise_int8_wrapper.h"}, {"conv_depthwise_int8.c", "fixed_point.c", "pack_int8.c", "conv_int8.c", "winograd_transform.c", "convolution_depthwise_int8_wrapper.c"}, - {"ConvDwInt8Row.S", "ConvDwInt8PostAlign4.S", "ConvDwInt8PostAlign4PerChannel.S", "ConvDw3x3Int8Stride2.S", - "ConvDw3x3Int8.S", "ConvDw3x3Int8Vertical.S", "ConvDw3x3Int8Horizontal.S", "ConvDw3x3Int8Corner.S", - "MatmulOptR4Int8.S", "ConvDwInt8Center.S", "DeconvDwInt8Center.S", "DeconvDwInt8Post.S", "MatmulDpInt8Opt.S"}); + {"ConvDwInt8Row.S", "ConvDwInt8PostAlign4.S", "ConvDwInt8PostAlign4PerChannel.S", "ConvDwInt8Center.S", + "DeconvDwInt8Center.S", "DeconvDwInt8Post.S"}); + if (target_ == kARM64) { + Collect(context, {}, {}, + {"ConvDw3x3Int8.S", "ConvDw3x3Int8Corner.S", "ConvDw3x3Int8Horizontal.S", "ConvDw3x3Int8Stride2.S", + "ConvDw3x3Int8Vertical.S", "MatmulDpInt8Opt.S", "MatmulOptR4Int8.S"}); + } nnacl::NNaclInt8Serializer code; code.precision(kPrecision); // call the op function diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc index a02c4b3f15..f4a6e88cb3 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc @@ -90,7 +90,7 @@ int ResizeInt8Coder::DoCode(CoderContext *const context) { code.CodeStruct("quant_in", *quant_in_); code.CodeStruct("quant_out", *quant_out_); code.CodeStruct("multiplier", *multiplier_); - code.CodeFunction("ResizeNearestNeighborInt8", input_tensor_, output_tensor_, "&input_shape", "&output_shape", + code.CodeFunction("ResizeNearestNeighborInt8", input_tensor_, output_tensor_, "input_shape", "output_shape", align_corners, "multiplier", "quant_in", "quant_out", kDefaultTaskId, gThreadNum); } break; diff --git a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc index 84fd17b15c..4a0d92b280 100644 --- a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc +++ b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc @@ -173,7 +173,7 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ConcatParame CodeArray(output_shapes_name, concat_parameter.output_shapes_, out_shape, false); CodeBaseStruct("ConcatParameter", name, concat_parameter.op_parameter_, quant_arg_name, concat_parameter.axis_, - concat_parameter.thread_count_, concat_parameter.input_num_, input_shapes_name, + concat_parameter.thread_count_, concat_parameter.input_num_, "(int **)" + input_shapes_name, output_shapes_name, concat_parameter.after_axis_size, concat_parameter.count_unit_); } diff --git a/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.c b/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.c index 4560c4d49f..dec9f70591 100644 --- a/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.c +++ b/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.c @@ -164,7 +164,7 @@ int RunArmHw(void *cdata, int task_id) { return NNACL_OK; } -void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out) { +void Conv1x1PreRun(Conv1x1Args *args, int thread_num) { int row_pack_count = C4NUM; int col_pack_count; @@ -177,49 +177,16 @@ void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out) { col_pack_count = C4NUM; } #endif - int thread_num = 1; int hw_thread_count = UP_DIV(args->matmul_param_->row_, row_pack_count); int oc_thread_count = UP_DIV(args->matmul_param_->col_, col_pack_count); - size_t thread_count_hw = MSMIN(thread_num, hw_thread_count); - args->thread_stride_hw_ = UP_DIV(hw_thread_count, thread_count_hw); - size_t thread_count_oc = MSMIN(thread_num, oc_thread_count); - args->thread_stride_oc_ = UP_DIV(oc_thread_count, thread_count_oc); - bool parallel_by_oc = oc_thread_count > thread_num; - - for (int batch_index = 0; batch_index < args->conv_param_->input_batch_; batch_index++) { - Pre1x1Trans(args, - src_in + batch_index * args->conv_param_->input_h_ * args->conv_param_->input_w_ * - args->conv_param_->input_channel_, - src_out + batch_index * args->matmul_param_->row_ * args->matmul_param_->col_); - if (parallel_by_oc) { - /* input transpose and input sum */ - if (args->support_optimize_) { - OcOptPre(args, 0); - } else { - RowMajor2Row16x4MajorInt8(args->input_ptr_, args->packed_input_, args->matmul_param_->row_, - args->matmul_param_->deep_); - if (args->filter_peroc_) { - PackInputSum16x4PerLayer(args->packed_input_, args->input_sum_, 1, args->matmul_param_->row_4_, - args->matmul_param_->deep_16_); - } else { - PackInputSum16x4PerLayer(args->packed_input_, args->input_sum_, - args->conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, - args->matmul_param_->row_4_, args->matmul_param_->deep_16_); - } - } - /* matmul parallel by oc */ - if (args->support_optimize_) { - RunArm64OptOc(args, 0); - } else { - RunArmOc(args, 0); - } - } else { - /* matmul parallel by hw */ - if (args->support_optimize_) { - RunArm64OptHw(args, 0); - } else { - RunArmHw(args, 0); - } - } + args->thread_count_hw = MSMIN(thread_num, hw_thread_count); + args->thread_stride_hw_ = UP_DIV(hw_thread_count, args->thread_count_hw); + args->thread_count_oc = MSMIN(thread_num, oc_thread_count); + args->thread_stride_oc_ = UP_DIV(oc_thread_count, args->thread_count_oc); + args->parallel_by_oc_ = oc_thread_count > thread_num; + if (!args->filter_peroc_) { + args->right_shift_ = args->conv_param_->conv_quant_arg_.right_shift_; + args->left_shift_ = args->conv_param_->conv_quant_arg_.left_shift_; + args->multiplier_ = args->conv_param_->conv_quant_arg_.quant_multiplier_; } } diff --git a/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.h b/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.h index 310bec30dd..6432e8b91d 100644 --- a/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.h +++ b/mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.h @@ -33,7 +33,9 @@ typedef struct { int8_t *packed_input_; int8_t *input_ptr_; int8_t *output_ptr_; + size_t thread_count_hw; size_t thread_stride_hw_; + size_t thread_count_oc; size_t thread_stride_oc_; ConvParameter *conv_param_; MatMulParameter *matmul_param_; @@ -41,8 +43,15 @@ typedef struct { bool pre_trans_input_; bool support_optimize_; bool filter_peroc_; + bool parallel_by_oc_; } Conv1x1Args; -void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out); +void Conv1x1PreRun(Conv1x1Args *args, int thread_num); +void Pre1x1Trans(Conv1x1Args *args, int8_t *src_input, int8_t *src_output); +int OcOptPre(void *cdata, int task_id); +int RunArm64OptOc(void *cdata, int task_id); +int RunArmOc(void *cdata, int task_id); +int RunArm64OptHw(void *cdata, int task_id); +int RunArmHw(void *cdata, int task_id); #endif // MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_INT8_CONV1X1_RUN_H_