Browse Source

!13974 fix conv1x1

From: @zhujingxuan
Reviewed-by: @wangchengyuan
Signed-off-by: @wangchengyuan
pull/13974/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
ba0b4c9e35
7 changed files with 98 additions and 58 deletions
  1. +2
    -2
      mindspore/lite/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc
  2. +67
    -7
      mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc
  3. +7
    -3
      mindspore/lite/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc
  4. +1
    -1
      mindspore/lite/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc
  5. +1
    -1
      mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc
  6. +10
    -43
      mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.c
  7. +10
    -1
      mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.h

+ 2
- 2
mindspore/lite/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc View File

@@ -103,8 +103,8 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) {
}
code.CodeStruct("concat_param", *concat_param_, in_tensor_count, input_tensor_->shape().size(),
output_tensor_->shape().size());
code.CodeBaseStruct("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_,
before_axis_size, count_unit_);
code.CodeBaseStruct<false>("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_,
before_axis_size, count_unit_);
if (support_parallel_) {
code.CodeFunction(kParallelLaunch, gThreadPool, "ConcatInt8Run", kRunArgsAddr, gThreadNum);
} else {


+ 67
- 7
mindspore/lite/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc View File

@@ -21,6 +21,7 @@
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "coder/opcoders/file_collector.h"
#include "coder/log.h"
#include "coder/opcoders/parallel.h"
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"

namespace mindspore::lite::micro::nnacl {
@@ -43,8 +44,9 @@ int Conv2D1x1Int8Coder::Prepare(CoderContext *const context) {

int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
Collect(context,
{"nnacl/int8/conv1x1_int8.h", "nnacl/common_func.h", "wrapper/int8/conv1x1_init_int8_wrapper.h",
"wrapper/int8/conv1x1_run_int8_wrapper.h"},
{"wrapper/int8/conv1x1_init_int8_wrapper.h", "wrapper/int8/conv1x1_run_int8_wrapper.h", "nnacl/common_func.h",
"nnacl/base/conv1x1_base.h", "nnacl/int8/matmul_int8.h", "nnacl/int8/pack_int8.h",
"nnacl/int8/conv1x1_int8.h", "nnacl/errorcode.h"},
{"common_func.c", "pack_int8.c", "conv1x1_int8.c", "matmul_int8.c", "fixed_point.c",
"conv1x1_init_int8_wrapper.c", "conv1x1_run_int8_wrapper.c", "conv1x1_base.c"},
{"MatmulInt8Opt.S"});
@@ -54,11 +56,69 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
code.CodeStruct("conv_param", *conv_param_);
code.CodeStruct("matmul_param", *matmul_param_);

code.CodeBaseStruct("Conv1x1Args", "args", input_sum_, filter_zp_ptr_, left_shift_, right_shift_, multiplier_,
packed_weight_, bias_data_, packed_input_, nullptr, nullptr, 0, 0, "&conv_param", "&matmul_param",
matmul_func_, pre_trans_input_, support_optimize_, filter_peroc_);

code.CodeFunction("Conv1x1Run", input_tensor_, "(Conv1x1Args *)&args", output_tensor_);
code.CodeBaseStruct<false>("Conv1x1Args", kRunArgs, input_sum_, filter_zp_ptr_, left_shift_, right_shift_,
multiplier_, packed_weight_, bias_data_, packed_input_, nullptr, nullptr, 0, 0, 0, 0,
"&conv_param", "&matmul_param", matmul_func_, pre_trans_input_, "GetSupportOptFlag()",
filter_peroc_, false);

code.CodeFunction("Conv1x1PreRun", kRunArgsAddr, gThreadNum);
code << "for (int batch_index = 0; batch_index < " << conv_param_->input_batch_ << "; batch_index++) {\n";
std::string src_in = allocator_->GetRuntimeAddr(input_tensor_) + " + batch_index * " +
std::to_string(conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_);
std::string src_out = allocator_->GetRuntimeAddr(output_tensor_) + " + batch_index * " +
std::to_string(matmul_param_->row_ * matmul_param_->col_);
code.CodeFunction("Pre1x1Trans", kRunArgsAddr, src_in, src_out);
code << "if (args.parallel_by_oc_) {\n";
/* input transpose and input sum */
code << "if (GetSupportOptFlag()) {\n";
if (support_parallel_) {
code.CodeFunction(kParallelLaunch, gThreadPool, "OcOptPre", kRunArgsAddr, "args.thread_count_hw");
} else {
code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId);
}
code << "} else {\n";
code << "RowMajor2Row16x4MajorInt8(args.input_ptr_, args.packed_input_, args.matmul_param_->row_, "
"args.matmul_param_->deep_);\n";
if (filter_peroc_) {
code << "PackInputSum16x4PerLayer(args.packed_input_, args.input_sum_, 1, args.matmul_param_->row_4_, "
"args.matmul_param_->deep_16_);\n";
} else {
code << "PackInputSum16x4PerLayer(args.packed_input_, "
"args.input_sum_,args.conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, "
"args.matmul_param_->row_4_, args.matmul_param_->deep_16_);\n";
}
code << "}\n";
/* matmul parallel by oc */
code << "if (GetSupportOptFlag()) {\n";
if (support_parallel_) {
code.CodeFunction(kParallelLaunch, gThreadPool, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc");
} else {
code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId);
}
code << "} else {\n";
if (support_parallel_) {
code.CodeFunction(kParallelLaunch, gThreadPool, "RunArmOc", kRunArgsAddr, "args.thread_count_oc");
} else {
code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId);
}
code << "}\n";
code << "} else {\n";
/* matmul parallel by hw */
code << "if (GetSupportOptFlag()) {\n";
if (support_parallel_) {
code.CodeFunction(kParallelLaunch, gThreadPool, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw");
} else {
code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId);
}
code << "} else {\n";
if (support_parallel_) {
code.CodeFunction(kParallelLaunch, gThreadPool, "RunArmHw", kRunArgsAddr, "args.thread_count_hw");
} else {
code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId);
}
code << "}\n";
code << "}\n";
code << "}\n";

context->AppendCode(code.str());
return RET_OK;


+ 7
- 3
mindspore/lite/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc View File

@@ -86,9 +86,13 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) {
{"nnacl/int8/conv_depthwise_int8.h", "nnacl/int8/pack_int8.h", "wrapper/int8/convolution_depthwise_int8_wrapper.h"},
{"conv_depthwise_int8.c", "fixed_point.c", "pack_int8.c", "conv_int8.c", "winograd_transform.c",
"convolution_depthwise_int8_wrapper.c"},
{"ConvDwInt8Row.S", "ConvDwInt8PostAlign4.S", "ConvDwInt8PostAlign4PerChannel.S", "ConvDw3x3Int8Stride2.S",
"ConvDw3x3Int8.S", "ConvDw3x3Int8Vertical.S", "ConvDw3x3Int8Horizontal.S", "ConvDw3x3Int8Corner.S",
"MatmulOptR4Int8.S", "ConvDwInt8Center.S", "DeconvDwInt8Center.S", "DeconvDwInt8Post.S", "MatmulDpInt8Opt.S"});
{"ConvDwInt8Row.S", "ConvDwInt8PostAlign4.S", "ConvDwInt8PostAlign4PerChannel.S", "ConvDwInt8Center.S",
"DeconvDwInt8Center.S", "DeconvDwInt8Post.S"});
if (target_ == kARM64) {
Collect(context, {}, {},
{"ConvDw3x3Int8.S", "ConvDw3x3Int8Corner.S", "ConvDw3x3Int8Horizontal.S", "ConvDw3x3Int8Stride2.S",
"ConvDw3x3Int8Vertical.S", "MatmulDpInt8Opt.S", "MatmulOptR4Int8.S"});
}
nnacl::NNaclInt8Serializer code;
code.precision(kPrecision);
// call the op function


+ 1
- 1
mindspore/lite/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc View File

@@ -90,7 +90,7 @@ int ResizeInt8Coder::DoCode(CoderContext *const context) {
code.CodeStruct("quant_in", *quant_in_);
code.CodeStruct("quant_out", *quant_out_);
code.CodeStruct("multiplier", *multiplier_);
code.CodeFunction("ResizeNearestNeighborInt8", input_tensor_, output_tensor_, "&input_shape", "&output_shape",
code.CodeFunction("ResizeNearestNeighborInt8", input_tensor_, output_tensor_, "input_shape", "output_shape",
align_corners, "multiplier", "quant_in", "quant_out", kDefaultTaskId, gThreadNum);
}
break;


+ 1
- 1
mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc View File

@@ -173,7 +173,7 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ConcatParame
CodeArray(output_shapes_name, concat_parameter.output_shapes_, out_shape, false);

CodeBaseStruct<false>("ConcatParameter", name, concat_parameter.op_parameter_, quant_arg_name, concat_parameter.axis_,
concat_parameter.thread_count_, concat_parameter.input_num_, input_shapes_name,
concat_parameter.thread_count_, concat_parameter.input_num_, "(int **)" + input_shapes_name,
output_shapes_name, concat_parameter.after_axis_size, concat_parameter.count_unit_);
}



+ 10
- 43
mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.c View File

@@ -164,7 +164,7 @@ int RunArmHw(void *cdata, int task_id) {
return NNACL_OK;
}

void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out) {
void Conv1x1PreRun(Conv1x1Args *args, int thread_num) {
int row_pack_count = C4NUM;
int col_pack_count;

@@ -177,49 +177,16 @@ void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out) {
col_pack_count = C4NUM;
}
#endif
int thread_num = 1;
int hw_thread_count = UP_DIV(args->matmul_param_->row_, row_pack_count);
int oc_thread_count = UP_DIV(args->matmul_param_->col_, col_pack_count);
size_t thread_count_hw = MSMIN(thread_num, hw_thread_count);
args->thread_stride_hw_ = UP_DIV(hw_thread_count, thread_count_hw);
size_t thread_count_oc = MSMIN(thread_num, oc_thread_count);
args->thread_stride_oc_ = UP_DIV(oc_thread_count, thread_count_oc);
bool parallel_by_oc = oc_thread_count > thread_num;

for (int batch_index = 0; batch_index < args->conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(args,
src_in + batch_index * args->conv_param_->input_h_ * args->conv_param_->input_w_ *
args->conv_param_->input_channel_,
src_out + batch_index * args->matmul_param_->row_ * args->matmul_param_->col_);
if (parallel_by_oc) {
/* input transpose and input sum */
if (args->support_optimize_) {
OcOptPre(args, 0);
} else {
RowMajor2Row16x4MajorInt8(args->input_ptr_, args->packed_input_, args->matmul_param_->row_,
args->matmul_param_->deep_);
if (args->filter_peroc_) {
PackInputSum16x4PerLayer(args->packed_input_, args->input_sum_, 1, args->matmul_param_->row_4_,
args->matmul_param_->deep_16_);
} else {
PackInputSum16x4PerLayer(args->packed_input_, args->input_sum_,
args->conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
args->matmul_param_->row_4_, args->matmul_param_->deep_16_);
}
}
/* matmul parallel by oc */
if (args->support_optimize_) {
RunArm64OptOc(args, 0);
} else {
RunArmOc(args, 0);
}
} else {
/* matmul parallel by hw */
if (args->support_optimize_) {
RunArm64OptHw(args, 0);
} else {
RunArmHw(args, 0);
}
}
args->thread_count_hw = MSMIN(thread_num, hw_thread_count);
args->thread_stride_hw_ = UP_DIV(hw_thread_count, args->thread_count_hw);
args->thread_count_oc = MSMIN(thread_num, oc_thread_count);
args->thread_stride_oc_ = UP_DIV(oc_thread_count, args->thread_count_oc);
args->parallel_by_oc_ = oc_thread_count > thread_num;
if (!args->filter_peroc_) {
args->right_shift_ = args->conv_param_->conv_quant_arg_.right_shift_;
args->left_shift_ = args->conv_param_->conv_quant_arg_.left_shift_;
args->multiplier_ = args->conv_param_->conv_quant_arg_.quant_multiplier_;
}
}

+ 10
- 1
mindspore/lite/micro/coder/operator_library/wrapper/int8/conv1x1_run_int8_wrapper.h View File

@@ -33,7 +33,9 @@ typedef struct {
int8_t *packed_input_;
int8_t *input_ptr_;
int8_t *output_ptr_;
size_t thread_count_hw;
size_t thread_stride_hw_;
size_t thread_count_oc;
size_t thread_stride_oc_;
ConvParameter *conv_param_;
MatMulParameter *matmul_param_;
@@ -41,8 +43,15 @@ typedef struct {
bool pre_trans_input_;
bool support_optimize_;
bool filter_peroc_;
bool parallel_by_oc_;
} Conv1x1Args;

void Conv1x1Run(int8_t *src_in, Conv1x1Args *args, int8_t *src_out);
/* One-time setup before the per-batch loop. From thread_num and the matmul row_/col_ sizes it fills
 * thread_count_hw, thread_stride_hw_, thread_count_oc, thread_stride_oc_ and parallel_by_oc_ in args.
 * When filter_peroc_ is false it also assigns left_shift_, right_shift_ and multiplier_ from
 * conv_param_->conv_quant_arg_. */
void Conv1x1PreRun(Conv1x1Args *args, int thread_num);
/* Per-batch step: the generated code calls it once per batch, with that batch's input and output
 * addresses, before any of the task functions below.
 * NOTE(review): the body is not visible here; presumably it sets up input_ptr_/output_ptr_ — confirm. */
void Pre1x1Trans(Conv1x1Args *args, int8_t *src_input, int8_t *src_output);
/* The functions below share the (cdata, task_id) task signature: cdata is the Conv1x1Args pointer.
 * The generated code either hands them to the parallel launcher or calls them directly with the
 * default task id. */
/* Input transpose and input sum; used on the parallel-by-oc path when the optimized flag is set. */
int OcOptPre(void *cdata, int task_id);
/* Matmul split by output channel, optimized variant (launched with thread_count_oc tasks). */
int RunArm64OptOc(void *cdata, int task_id);
/* Matmul split by output channel, non-optimized variant (launched with thread_count_oc tasks). */
int RunArmOc(void *cdata, int task_id);
/* Matmul split by hw, optimized variant (launched with thread_count_hw tasks). */
int RunArm64OptHw(void *cdata, int task_id);
/* Matmul split by hw, non-optimized variant (launched with thread_count_hw tasks). */
int RunArmHw(void *cdata, int task_id);

#endif // MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_INT8_CONV1X1_RUN_H_

Loading…
Cancel
Save