
replace 8x8 block with 12x8 block in common conv

tags/v1.1.0
fuzhiye committed 5 years ago
parent commit 2b056b7a28
14 changed files with 69 additions and 1675 deletions

  1. +24 -270   mindspore/lite/nnacl/fp32/conv.c
  2. +1  -28    mindspore/lite/nnacl/fp32/conv.h
  3. +19 -70    mindspore/lite/nnacl/pack.c
  4. +0  -3     mindspore/lite/nnacl/pack.h
  5. +0  -605   mindspore/lite/nnacl/winograd_transform.c
  6. +0  -15    mindspore/lite/nnacl/winograd_transform.h
  7. +5  -0     mindspore/lite/nnacl/winograd_utils.c
  8. +15 -52    mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
  9. +5  -7     mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h
  10. +0 -245   mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc
  11. +0 -80    mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h
  12. +0 -201   mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc
  13. +0 -70    mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.h
  14. +0 -29    mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc

mindspore/lite/nnacl/fp32/conv.c  (+24 -270)

@@ -20,192 +20,9 @@
#include "nnacl/winograd_transform.h"
#include "nnacl/fp32/matmul.h"

void SWBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic4, bool is_relu, bool is_relu6) {
for (int c = 0; c < C4NUM; c++) {
dst[c] = 0;
}
const float *weight_oc = weight;
for (int oc = 0; oc < C4NUM; ++oc) {
const float *weight_kh = weight_oc;
const float *src_kh = src;
for (int kh = 0; kh < height; kh++) {
const float *src_kw = src_kh;
const float *weight_kw = weight_kh;
for (int kw = 0; kw < width; kw++) {
const float *src_ic4 = src_kw;
const float *weight_ic4 = weight_kw;
for (int ic = 0; ic < ic4; ++ic) {
for (int c = 0; c < C4NUM; c++) {
dst[oc] += src_ic4[c] * weight_ic4[c];
}
src_ic4 += C4NUM;
weight_ic4 += C4NUM;
} // ic4 loop
src_kw += in_kw_step;
weight_kw += ic4 * C4NUM;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += kernel_w * ic4 * C4NUM;
} // kernel_h loop
dst[oc] += bias[oc];
dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]);
dst[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]);
weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
} // oc loop
}

void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
int ic4 = sliding->ic4_channel_ / C4NUM;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float *src_h = src + ih * sliding->in_h_step_;

float *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float *src_w = src_h + iw * sliding->ic4_channel_;

const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;

SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4, relu,
relu6);

dst_kernel += sliding->block_channel_;
} // width loop
dst_h += sliding->out_h_step_;
} // height loop
}

#ifndef ENABLE_ARM64
void SWCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, int kernel_h,
int kernel_w, int out_h_step, int block_channel, int ic4, int in_sh_step, int in_sw_step, int in_kh_step,
int in_kw_step, bool is_relu, bool is_relu6) {
float *dst_h = dst;
const float *src_h = src;
for (int oh = 0; oh < height; oh++) {
float *dst_w = dst_h;
const float *src_w = src_h;
for (int ow = 0; ow < width; ow++) {
const float *weight_oc = weight;
for (int c = 0; c < C4NUM; c++) {
dst_w[c] = 0;
}

for (int oc = 0; oc < C4NUM; oc++) {
const float *weight_kh = weight_oc;
const float *src_kh = src_w;
for (int kh = 0; kh < kernel_h; kh++) {
const float *src_kw = src_kh;
const float *weight_kw = weight_kh;
for (int kw = 0; kw < kernel_w; kw++) {
const float *src_ic4 = src_kw;
const float *weight_ic4 = weight_kw;
for (int ic = 0; ic < ic4; ++ic) {
for (int c = 0; c < C4NUM; c++) {
dst_w[oc] += src_ic4[c] * weight_ic4[c];
}

src_ic4 += C4NUM;
weight_ic4 += C4NUM;
} // ic4 loop
src_kw += in_kw_step;
weight_kw += ic4 * C4NUM;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += kernel_w * ic4 * C4NUM;
} // kernel_h loop
// add bias and relu

dst_w[oc] += bias[oc];
dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]);
dst_w[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]);
weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
} // oc block

dst_w += block_channel;
src_w += in_sw_step;
} // dst_width loop
dst_h += out_h_step;
src_h += in_sh_step;
} // dst_height loop
}
#endif

// fp32 sliding window
void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *tmp_out_block,
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param) {
int ic4 = slidingWindow_param->ic4_channel_ / C4NUM;
int oc4_res = conv_param->output_channel_ % C4NUM;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float *src = input_data;
float *dst = NULL;
if (oc4_res == 0) {
dst = output_data;
} else {
dst = tmp_out_block;
}

for (int b = 0; b < conv_param->output_batch_; b++) {
for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) {
const float *src_data = src;
float *dst_data = dst + oc * C4NUM;
const float *weight = packed_weight + oc * slidingWindow_param->kernel_step_;
const float *bias = bias_data + oc * C4NUM;
SWBorder(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param,
slidingWindow_param);
SWBorder(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0,
conv_param->output_w_, conv_param, slidingWindow_param);
SWBorder(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0,
slidingWindow_param->left_, conv_param, slidingWindow_param);
SWBorder(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_,
slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param);

if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t =
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
float *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
slidingWindow_param->left_ * slidingWindow_param->block_channel_;
#ifdef ENABLE_ARM64
ConvSwFp32Center(
out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
slidingWindow_param->out_h_step_ * sizeof(float), slidingWindow_param->block_channel_ * sizeof(float), ic4,
slidingWindow_param->in_sh_step_ * sizeof(float), slidingWindow_param->in_sw_step_ * sizeof(float),
slidingWindow_param->in_kh_step_ * sizeof(float), slidingWindow_param->in_kw_step_ * sizeof(float), relu,
relu6);
#else
SWCenter(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
slidingWindow_param->in_sh_step_, slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
slidingWindow_param->in_kw_step_, relu, relu6);
#endif
}
} // output C4 loop
src += slidingWindow_param->in_step_;
dst += slidingWindow_param->out_step_;
} // batch loop
}
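
Note: the sliding-window path removed above splits each output plane into four padded border strips (handled by SWBorder) and a fully in-bounds center (SWCenter / ConvSwFp32Center). A minimal sketch of how such center bounds can be derived from the conv geometry; treating SlidingWindowParam's top_/bottom_ as exactly these values is an assumption:

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

/* Hedged sketch: first (inclusive) and last (exclusive) output rows whose
 * receptive field lies fully inside the input, so the center kernel can
 * skip all bounds checks.  top_/bottom_ are assumed to hold these values. */
static void hp_center_row_range(int out_h, int in_h, int kernel_h, int stride_h,
                                int pad_u, int dilation_h, int *top, int *bottom) {
  /* smallest oh with oh * stride - pad >= 0 */
  *top = UP_DIV(pad_u, stride_h);
  /* largest oh with oh * stride - pad + (kernel - 1) * dilation <= in_h - 1 */
  int last = (in_h - 1 - (kernel_h - 1) * dilation_h + pad_u) / stride_h;
  *bottom = last + 1; /* exclusive bound */
  if (*top > out_h) *top = out_h;
  if (*bottom > out_h) *bottom = out_h;
  if (*bottom < *top) *bottom = *top; /* empty center: borders cover everything */
}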

// fp32 conv common
void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data,
-             float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
-             GEMM_FUNC_FP32 gemm_func) {
+             float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_batch = conv_param->input_batch_;
@@ -217,42 +34,38 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int out_channel = conv_param->output_channel_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
-  int output_tile_count = UP_DIV(output_count, TILE_NUM);
-  int ic4 = UP_DIV(in_channel, C4NUM);
+#ifdef ENABLE_ARM32
+  const int cal_num = C4NUM;
+#else
+  const int cal_num = C12NUM;
+#endif
+  int output_tile_count = UP_DIV(output_count, cal_num);
int kernel_plane = kernel_h * kernel_w;
-  int unit_size = kernel_plane * ic4 * C4NUM;
-  bool relu = conv_param->act_type_ == ActType_Relu;
-  bool relu6 = conv_param->act_type_ == ActType_Relu6;
-
-  // we accumulate 4 channels per time for input blocks
-  int conv_depth = kernel_h * kernel_w;
-  // bytes from one output's i-th channel to the next output's i-th channel
-  // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward
-  size_t output_offset = out_channel * sizeof(float);
+  int unit_size = kernel_plane * in_channel;
+  int deep = in_channel * kernel_plane;

for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * in_channel * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
-      int start_index = thread_id * TILE_NUM;
-      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
-      float *gemm_input = packed_input + task_id * unit_size * TILE_NUM;
-      size_t packed_input_size = unit_size * TILE_NUM * sizeof(float);
+      int start_index = thread_id * cal_num;
+      int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
+      float *gemm_input = packed_input + task_id * unit_size * cal_num;
+      float *col_major_gemm_input = col_major_input + task_id * unit_size * cal_num;
+      size_t packed_input_size = unit_size * cal_num * sizeof(float);
       memset(gemm_input, 0, packed_input_size);
+      memset(col_major_gemm_input, 0, packed_input_size);
Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

-      int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset;
-      if (real_cal_num == TILE_NUM) {
-        float *gemm_output = output_data + out_offset;
-        gemm_func(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
-                  relu, relu6);
-      } else {
-        // res part
-        float *tmp_out_ptr = tmp_out_block + task_id * TILE_NUM * out_channel;
-        gemm_func(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
-                  relu, relu6);
-        memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float));
-      }
+      int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
+      float *gemm_output = output_data + out_offset;
+#ifdef ENABLE_ARM32
+      RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#else
+      RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#endif
+      MatMulOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num,
+                out_channel, out_channel, OutType_Nhwc);
}
}
}
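
Note: the rewritten common path is now plain im2col + packed GEMM: each tile of up to 12 output pixels (4 on ARM32) is im2col-packed row-major, repacked column-major, and multiplied against 8-wide weight panels, with MatMulOpt fusing bias and activation and writing NHWC directly; that is why tmp_out_block and the TILE_NUM remainder copy could be dropped. A hedged reference of one tile's math, with a naive row-major GEMM standing in for MatMulOpt and for the packed layouts (hp_* names are hypothetical):

#include <stddef.h>

/* One output tile: rows = up to 12 output pixels, deep = kh * kw * ic.
 * In the real code the A tile is first repacked by RowMajor2Col12Major so
 * the assembly kernel can load 12 rows contiguously. */
static void hp_tile_gemm(const float *im2col_rows, /* [rows x deep], row-major */
                         const float *weight,      /* [deep x oc], row-major   */
                         const float *bias,        /* [oc], may be NULL        */
                         float *out,               /* [rows x oc], NHWC slice  */
                         int rows, int deep, int oc) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < oc; ++c) {
      float acc = bias != NULL ? bias[c] : 0.0f;
      for (int d = 0; d < deep; ++d) {
        acc += im2col_rows[r * deep + d] * weight[d * oc + c];
      }
      out[r * oc + c] = acc; /* MatMulOpt would also fuse ReLU/ReLU6 here */
    }
  }
}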
@@ -321,62 +134,3 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
}
}

// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param) {
int thread_count = conv_param->thread_num_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
#ifdef ENABLE_ARM32
const int tile_num = 4;
#else
const int tile_num = 12;
#endif
int output_tile_count = UP_DIV(output_count, tile_num);
const int input_unit_square = 4 * 4;

float *tile_buffer = buffer_list[0];
float *block_unit_buffer = buffer_list[1];
float *tmp_dst_buffer = buffer_list[2];
float *nc4hw4_out = buffer_list[3];
float *col_buffer = buffer_list[4];
int tile_buffer_offset = tile_num * input_unit_square * ic4 * C4NUM;
int block_unit_buffer_offset = input_unit_square * C4NUM;
int tmp_dst_buffer_offset = tile_num * input_unit_square * oc8 * C8NUM;
int col_buffer_offset = tile_num * ic4 * C4NUM;

int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
int in_batch_offset = batch * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;

for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_num;
int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);

float *src_ptr = tile_buffer + task_id * tile_buffer_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM32
RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
#else
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
#endif
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}
Conv3x3Fp32OutputTransform(dst_ptr, nc4hw4_out + nc4hw4_buffer_offset, bias_data, start_index, real_cal_num,
out_w_block, conv_param);
}
}
}
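
Note: the commit title's "12x8 block" is visible in both paths above: tiles of C12NUM = 12 output pixels (C4NUM = 4 on ARM32) are multiplied against C8NUM = 8-channel weight panels. A small sketch of the resulting tile counts, assuming UP_DIV is nnacl's usual round-up division:

#define UP_DIV(x, y) (((x) + (y) - 1) / (y)) /* assumed to match nnacl's macro */

typedef struct { int pixel_tiles; int oc_panels; } hp_TileCounts;

static hp_TileCounts hp_count_tiles(int out_h, int out_w, int out_channel) {
  hp_TileCounts t;
  t.pixel_tiles = UP_DIV(out_h * out_w, 12); /* 12-row GEMM tiles (4 on ARM32) */
  t.oc_panels = UP_DIV(out_channel, 8);      /* 8-column weight panels          */
  return t;
}
/* e.g. a 56x56 output with 64 channels: UP_DIV(3136, 12) = 262 tiles,
 * each run against UP_DIV(64, 8) = 8 weight panels. */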

mindspore/lite/nnacl/fp32/conv.h  (+1 -28)

@@ -28,46 +28,19 @@
#include "nnacl/fp32/conv_depthwise.h"

typedef float *TmpBufferAddress;
typedef float *Matrices;
typedef void (*GEMM_FUNC_FP32)(float *output, const float *input, const float *weight, const float *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4,
size_t relu, size_t relu6);

#ifdef __cplusplus
extern "C" {
#endif
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding);

void SWCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, int kernel_h,
int kernel_w, int out_h_step, int block_channel, int ic4, int in_sh_step, int in_sw_step, int in_kh_step,
int in_kw_step, bool is_relu, bool is_relu6);

// fp32 sliding window
void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *tmp_out_block,
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param);

// fp32 convolution common (im2col+gemm)
void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data,
-             float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
-             GEMM_FUNC_FP32 gemm_func);
+             float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param);

// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
OutputTransFunc out_func);

void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);

void UnPackWinogradReluOutput(const float *src, float *dst, int batch, int height, int width, int channel,
int output_unit);

void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int height, int width, int channel,
int output_unit);

// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif
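
Note: with the GEMM_FUNC_FP32 hook gone, callers only pass the two scratch areas and a task_id; each worker strides over output tiles by thread_num_. A hedged sketch of a serial driver, assuming the header above is included (in-tree the per-task calls are dispatched through ParallelLaunch instead of a loop):

void hp_run_conv_serial(float *input, float *packed_input, float *packed_weight,
                        const float *bias, float *col_major_input, float *output,
                        ConvParameter *conv_param) {
  for (int task_id = 0; task_id < conv_param->thread_num_; ++task_id) {
    ConvFp32(input, packed_input, packed_weight, bias, col_major_input,
             output, task_id, conv_param);
  }
}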


mindspore/lite/nnacl/pack.c  (+19 -70)

@@ -18,50 +18,6 @@
#include <string.h>
#include <stdlib.h>

void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block,
int oc_block_num) {
// original weight format : ohwi
if (oc_block_num == 0) {
return;
}
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc_block * oc_block_num * ic4 * C4NUM * kernel_plane;

int unit_size = oc_block * C4NUM;
const int block_size = pack_weight_size / oc_block_num;

for (int m = 0; m < kernel_plane; m++) {
int kernel_plane_stride = m * in_channel;
int packed_kernel_plane_stride = m * unit_size * ic4;
for (int i = 0; i < ic4; i++) {
int channel_block_stride = kernel_plane_stride + i * C4NUM;
int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
int ic_remainder = in_channel - i * C4NUM;
int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM;
for (int h = 0; h < real_ic_num; h++) {
int block_stride = channel_block_stride + h;
int packed_block_stride = packed_channel_block_size + h * oc_block;
for (int j = 0; j < oc_block_num; j++) {
int kernel_block_stride = block_stride + j * oc_block * kernel_plane * in_channel;
int packed_kernel_block_size = packed_block_stride + j * block_size;
int oc_remainder = out_channel - j * oc_block;
int real_oc_num = oc_remainder < oc_block ? oc_remainder : oc_block;
for (int k = 0; k < real_oc_num; k++) {
float *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
float *packed_data_ptr = packed_weight + packed_kernel_block_size + k;
*packed_data_ptr = *origin_data_ptr;
}
} // kernel block loop
} // inchannel block loop
} // channel block loop
} // kernel plane loop
}

void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
}
@@ -301,6 +257,7 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int kernel_plane = kernel_h * kernel_w;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_u_;
@@ -311,8 +268,6 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int ic4_minus = in_channel / C4NUM;
int ic4 = UP_DIV(in_channel, C4NUM);

for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
@@ -323,31 +278,25 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
-    for (int j = kh_s; j < kh_e; j++) {
-      int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
-      for (int n = kw_s; n < kw_e; n++) {
-        int input_x_stride = input_y_stride + n * dilation_w * in_channel;
-        int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM;
-        for (int m = 0; m < ic4_minus; m++) {
-          int channel_block_stride = input_x_stride + m * C4NUM;
-          int channel_block_offset = input_plane_offset + m * C8NUM * C4NUM;
-#ifdef ENABLE_NEON
-          vst1q_f32(packed_input + channel_block_offset, vld1q_f32(input_data + channel_block_stride));
-#else
-          for (int k = 0; k < C4NUM; ++k) {
-            (packed_input + channel_block_offset)[k] = (input_data + channel_block_stride)[k];
-          }
-#endif
-        }  // channel_block loop
-        int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
-        for (int l = 0; l < ic_res; ++l) {
-          int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
-          int channel_block_offset = input_plane_offset + ic4_minus * C8NUM * C4NUM + l;
-          packed_input[channel_block_offset] = input_data[channel_block_stride];
-        }
-      }  // kernel_w loop
-    }  // kernel_h loop
+    if (dilation_w == 1 && dilation_h == 1) {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * in_w * in_channel + input_stride;
+        int input_x_stride = input_y_stride + kw_s * in_channel;
+        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
+        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
+               (kw_e - kw_s) * in_channel * sizeof(float));
+      }  // kernel_h loop
+    } else {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
+        for (int k = kw_s; k < kw_e; ++k) {
+          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
+          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
+          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
+        }  // kernel_w loop
+      }  // kernel_h loop
+    }
  }  // tile num loop
}
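
Note: the repacked layout is now a plain row-major im2col row of kernel_h * kernel_w * in_channel floats per output pixel, which is what makes the single-memcpy fast path possible once dilation is 1. A self-contained hedged reference for a single output pixel (hp_* names are hypothetical, not the in-tree API):

#include <string.h>

/* Hedged reference of the new row-major im2col for one output pixel.
 * Layout of `row`: [kernel_h * kernel_w * in_channel], zero-padded where
 * the kernel hangs over the input border. */
static void hp_im2col_pixel(const float *in, float *row, int in_h, int in_w,
                            int ic, int kh, int kw, int oh, int ow,
                            int stride, int pad, int dilation) {
  memset(row, 0, (size_t)kh * kw * ic * sizeof(float));
  int ih0 = oh * stride - pad;
  int iw0 = ow * stride - pad;
  for (int j = 0; j < kh; ++j) {
    int ih = ih0 + j * dilation;
    if (ih < 0 || ih >= in_h) continue;
    if (dilation == 1) {
      int w_s = iw0 < 0 ? -iw0 : 0;                /* first in-bounds tap   */
      int w_e = in_w - iw0 < kw ? in_w - iw0 : kw; /* one past the last tap */
      if (w_e > w_s) {
        memcpy(row + (j * kw + w_s) * ic,
               in + (ih * in_w + iw0 + w_s) * ic,
               (size_t)(w_e - w_s) * ic * sizeof(float)); /* fast path */
      }
    } else {
      for (int n = 0; n < kw; ++n) {
        int iw = iw0 + n * dilation;
        if (iw < 0 || iw >= in_w) continue;
        memcpy(row + (j * kw + n) * ic, in + (ih * in_w + iw) * ic,
               (size_t)ic * sizeof(float));
      }
    }
  }
}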

void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,


mindspore/lite/nnacl/pack.h  (+0 -3)

@@ -51,9 +51,6 @@ void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);

void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);

void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block,
int oc_block_num);

void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);

void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);


mindspore/lite/nnacl/winograd_transform.c  (+0 -605)

@@ -142,611 +142,6 @@ void WinogradOutputTransform(const float *gemm_out, float *out_data, const float
}
}

// fp32 conv3x3
void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step) {
#ifdef ENABLE_ARM
float32x4_t d00 = vld1q_f32(tmp_data);
float32x4_t d01 = vld1q_f32(tmp_data + 4);
float32x4_t d02 = vld1q_f32(tmp_data + 2 * 4);
float32x4_t d03 = vld1q_f32(tmp_data + 3 * 4);

float32x4_t d10 = vld1q_f32(tmp_data + 4 * 4);
float32x4_t d11 = vld1q_f32(tmp_data + 5 * 4);
float32x4_t d12 = vld1q_f32(tmp_data + 6 * 4);
float32x4_t d13 = vld1q_f32(tmp_data + 7 * 4);

float32x4_t d20 = vld1q_f32(tmp_data + 8 * 4);
float32x4_t d21 = vld1q_f32(tmp_data + 9 * 4);
float32x4_t d22 = vld1q_f32(tmp_data + 10 * 4);
float32x4_t d23 = vld1q_f32(tmp_data + 11 * 4);

float32x4_t d30 = vld1q_f32(tmp_data + 12 * 4);
float32x4_t d31 = vld1q_f32(tmp_data + 13 * 4);
float32x4_t d32 = vld1q_f32(tmp_data + 14 * 4);
float32x4_t d33 = vld1q_f32(tmp_data + 15 * 4);

float32x4_t t00 = vsubq_f32(d00, d20);
float32x4_t t01 = vsubq_f32(d01, d21);
float32x4_t t02 = vsubq_f32(d02, d22);
float32x4_t t03 = vsubq_f32(d03, d23);

float32x4_t t10 = vaddq_f32(d10, d20);
float32x4_t t11 = vaddq_f32(d11, d21);
float32x4_t t12 = vaddq_f32(d12, d22);
float32x4_t t13 = vaddq_f32(d13, d23);

float32x4_t t20 = vsubq_f32(d20, d10);
float32x4_t t21 = vsubq_f32(d21, d11);
float32x4_t t22 = vsubq_f32(d22, d12);
float32x4_t t23 = vsubq_f32(d23, d13);

float32x4_t t30 = vsubq_f32(d10, d30);
float32x4_t t31 = vsubq_f32(d11, d31);
float32x4_t t32 = vsubq_f32(d12, d32);
float32x4_t t33 = vsubq_f32(d13, d33);

float32x4_t m00 = vsubq_f32(t00, t02);
float32x4_t m01 = vaddq_f32(t01, t02);
float32x4_t m02 = vsubq_f32(t02, t01);
float32x4_t m03 = vsubq_f32(t01, t03);

float32x4_t m10 = vsubq_f32(t10, t12);
float32x4_t m11 = vaddq_f32(t11, t12);
float32x4_t m12 = vsubq_f32(t12, t11);
float32x4_t m13 = vsubq_f32(t11, t13);

float32x4_t m20 = vsubq_f32(t20, t22);
float32x4_t m21 = vaddq_f32(t21, t22);
float32x4_t m22 = vsubq_f32(t22, t21);
float32x4_t m23 = vsubq_f32(t21, t23);

float32x4_t m30 = vsubq_f32(t30, t32);
float32x4_t m31 = vaddq_f32(t31, t32);
float32x4_t m32 = vsubq_f32(t32, t31);
float32x4_t m33 = vsubq_f32(t31, t33);

vst1q_f32(trans_input_data, m00);
vst1q_f32(trans_input_data + step, m01);
vst1q_f32(trans_input_data + 2 * step, m02);
vst1q_f32(trans_input_data + 3 * step, m03);

vst1q_f32(trans_input_data + 4 * step, m10);
vst1q_f32(trans_input_data + 5 * step, m11);
vst1q_f32(trans_input_data + 6 * step, m12);
vst1q_f32(trans_input_data + 7 * step, m13);

vst1q_f32(trans_input_data + 8 * step, m20);
vst1q_f32(trans_input_data + 9 * step, m21);
vst1q_f32(trans_input_data + 10 * step, m22);
vst1q_f32(trans_input_data + 11 * step, m23);

vst1q_f32(trans_input_data + 12 * step, m30);
vst1q_f32(trans_input_data + 13 * step, m31);
vst1q_f32(trans_input_data + 14 * step, m32);
vst1q_f32(trans_input_data + 15 * step, m33);
#else
for (int i = 0; i < C4NUM; i++) {
const float *local_ptr = tmp_data + i;
float d00 = local_ptr[0];
float d01 = (local_ptr + C4NUM)[0];
float d02 = (local_ptr + 2 * C4NUM)[0];
float d03 = (local_ptr + 3 * C4NUM)[0];

float d10 = (local_ptr + 4 * C4NUM)[0];
float d11 = (local_ptr + 5 * C4NUM)[0];
float d12 = (local_ptr + 6 * C4NUM)[0];
float d13 = (local_ptr + 7 * C4NUM)[0];

float d20 = (local_ptr + 8 * C4NUM)[0];
float d21 = (local_ptr + 9 * C4NUM)[0];
float d22 = (local_ptr + 10 * C4NUM)[0];
float d23 = (local_ptr + 11 * C4NUM)[0];

float d30 = (local_ptr + 12 * C4NUM)[0];
float d31 = (local_ptr + 13 * C4NUM)[0];
float d32 = (local_ptr + 14 * C4NUM)[0];
float d33 = (local_ptr + 15 * C4NUM)[0];

float t00 = d00 - d20;
float t01 = d01 - d21;
float t02 = d02 - d22;
float t03 = d03 - d23;

float t10 = d10 + d20;
float t11 = d11 + d21;
float t12 = d12 + d22;
float t13 = d13 + d23;

float t20 = d20 - d10;
float t21 = d21 - d11;
float t22 = d22 - d12;
float t23 = d23 - d13;

float t30 = d10 - d30;
float t31 = d11 - d31;
float t32 = d12 - d32;
float t33 = d13 - d33;

float m00 = t00 - t02;
float m01 = t01 + t02;
float m02 = t02 - t01;
float m03 = t01 - t03;

float m10 = t10 - t12;
float m11 = t11 + t12;
float m12 = t12 - t11;
float m13 = t11 - t13;

float m20 = t20 - t22;
float m21 = t21 + t22;
float m22 = t22 - t21;
float m23 = t21 - t23;

float m30 = t30 - t32;
float m31 = t31 + t32;
float m32 = t32 - t31;
float m33 = t31 - t33;

(trans_input_data + i)[0] = m00;
(trans_input_data + i + step)[0] = m01;
(trans_input_data + i + 2 * step)[0] = m02;
(trans_input_data + i + 3 * step)[0] = m03;

(trans_input_data + i + 4 * step)[0] = m10;
(trans_input_data + i + 5 * step)[0] = m11;
(trans_input_data + i + 6 * step)[0] = m12;
(trans_input_data + i + 7 * step)[0] = m13;

(trans_input_data + i + 8 * step)[0] = m20;
(trans_input_data + i + 9 * step)[0] = m21;
(trans_input_data + i + 10 * step)[0] = m22;
(trans_input_data + i + 11 * step)[0] = m23;

(trans_input_data + i + 12 * step)[0] = m30;
(trans_input_data + i + 13 * step)[0] = m31;
(trans_input_data + i + 14 * step)[0] = m32;
(trans_input_data + i + 15 * step)[0] = m33;
}
#endif
}
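
Note: the scalar branch above is an unrolled Winograd F(2x2,3x3) input transform B^T d B; the removed filter and output transforms below apply G and A^T the same way. For reference, the standard matrices, which the scalar arithmetic here matches term for term:

Y = A^T \left[ (G g G^T) \odot (B^T d B) \right] A, \qquad
B^T = \begin{pmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & 1 & 0 & -1 \end{pmatrix}, \quad
G = \begin{pmatrix} 1 & 0 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & \tfrac{1}{2} \\ \tfrac{1}{2} & -\tfrac{1}{2} & \tfrac{1}{2} \\ 0 & 0 & 1 \end{pmatrix}, \quad
A^T = \begin{pmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{pmatrix}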

void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
// input data format : nhwc
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_l_;
int pad_h = conv_param->pad_u_;
int ic4 = UP_DIV(input_channel, C4NUM);
const int input_unit = 4;
if (out_w_block == 0) {
return;
}
for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
int x_id = start_index + cal_id;
int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w;
int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h;
int real_x_start = origin_x > 0 ? 0 : -origin_x;
int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x);
int real_y_start = origin_y > 0 ? 0 : -origin_y;
int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);

int src_plane_offset = input_channel * (origin_y * input_width + origin_x);
int dst_plane_offset = cal_id * C4NUM * ic4;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));
int real_c = input_channel - ic * C4NUM;
real_c = real_c > C4NUM ? C4NUM : real_c;

// get real input block with padding
int src_ic4_offset = src_plane_offset + ic * C4NUM;
if (real_c == C4NUM) {
for (int interval = real_y_start; interval < real_y_end; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * input_channel;
int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM;
for (int j = 0; j < (real_x_end - real_x_start); j++) {
int src_x_offset = src_y_offset + j * input_channel;
int dst_x_offset = dst_y_offset + j * C4NUM;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
vst1q_f32(dst_addr, vld1q_f32(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
}
}
} else {
for (int interval = real_y_start; interval < real_y_end; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * input_channel;
int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM;
for (int j = 0; j < (real_x_end - real_x_start); j++) {
int src_x_offset = src_y_offset + j * input_channel;
int dst_x_offset = dst_y_offset + j * C4NUM;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
for (int k = 0; k < real_c; k++) {
dst_addr[k] = src_addr[k];
}
}
}
}

// input transform
#ifdef ENABLE_ARM32
const int tile_num = 4;
#else
const int tile_num = 12;
#endif
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = tile_num * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step);
}
}
}

void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane,
int oc_block) {
if (oc_block == 0) {
return;
}
int oc_plane_block = UP_DIV(output_channel, oc_block);
int dst_step = iC4 * C4NUM * oc_block * oc_plane_block;
for (int o = 0; o < output_channel; o++) {
int oc_block_num = o / oc_block;
int oc_block_rem = o % oc_block;
int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
int dst_oc_offset = oc_block_num * oc_block * iC4 * C4NUM + oc_block_rem;
for (int i = 0; i < iC4; i++) {
float *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
float *dst_ic4_ptr = trans_weight + dst_oc_offset + i * oc_block * C4NUM;
#ifdef ENABLE_ARM
float32x4_t g00 = vld1q_f32(src_ic4_ptr);
float32x4_t g01 = vld1q_f32(src_ic4_ptr + 4);
float32x4_t g02 = vld1q_f32(src_ic4_ptr + 2 * 4);
float32x4_t g10 = vld1q_f32(src_ic4_ptr + 3 * 4);
float32x4_t g11 = vld1q_f32(src_ic4_ptr + 4 * 4);
float32x4_t g12 = vld1q_f32(src_ic4_ptr + 5 * 4);
float32x4_t g20 = vld1q_f32(src_ic4_ptr + 6 * 4);
float32x4_t g21 = vld1q_f32(src_ic4_ptr + 7 * 4);
float32x4_t g22 = vld1q_f32(src_ic4_ptr + 8 * 4);

float32x4_t dst00 = g00;
float32x4_t dst01 = g01;
float32x4_t dst02 = g02;

float32x4_t dst10 = vaddq_f32(vaddq_f32(g00, g10), g20);
dst10 = vmulq_n_f32(dst10, 0.5);
float32x4_t dst11 = vaddq_f32(vaddq_f32(g01, g11), g21);
dst11 = vmulq_n_f32(dst11, 0.5);
float32x4_t dst12 = vaddq_f32(vaddq_f32(g02, g12), g22);
dst12 = vmulq_n_f32(dst12, 0.5);

float32x4_t dst20 = vaddq_f32(vsubq_f32(g00, g10), g20);
dst20 = vmulq_n_f32(dst20, 0.5);
float32x4_t dst21 = vaddq_f32(vsubq_f32(g01, g11), g21);
dst21 = vmulq_n_f32(dst21, 0.5);
float32x4_t dst22 = vaddq_f32(vsubq_f32(g02, g12), g22);
dst22 = vmulq_n_f32(dst22, 0.5);

float32x4_t dst30 = g20;
float32x4_t dst31 = g21;
float32x4_t dst32 = g22;

float32x4_t m00 = dst00;
float32x4_t m01 = vaddq_f32(vaddq_f32(dst00, dst01), dst02);
m01 = vmulq_n_f32(m01, 0.5);
float32x4_t m02 = vaddq_f32(vsubq_f32(dst00, dst01), dst02);
m02 = vmulq_n_f32(m02, 0.5);
float32x4_t m03 = dst02;

float32x4_t m10 = dst10;
float32x4_t m11 = vaddq_f32(vaddq_f32(dst10, dst11), dst12);
m11 = vmulq_n_f32(m11, 0.5);
float32x4_t m12 = vaddq_f32(vsubq_f32(dst10, dst11), dst12);
m12 = vmulq_n_f32(m12, 0.5);
float32x4_t m13 = dst12;

float32x4_t m20 = dst20;
float32x4_t m21 = vaddq_f32(vaddq_f32(dst20, dst21), dst22);
m21 = vmulq_n_f32(m21, 0.5);
float32x4_t m22 = vaddq_f32(vsubq_f32(dst20, dst21), dst22);
m22 = vmulq_n_f32(m22, 0.5);
float32x4_t m23 = dst22;

float32x4_t m30 = dst30;
float32x4_t m31 = vaddq_f32(vaddq_f32(dst30, dst31), dst32);
m31 = vmulq_n_f32(m31, 0.5);
float32x4_t m32 = vaddq_f32(vsubq_f32(dst30, dst31), dst32);
m32 = vmulq_n_f32(m32, 0.5);
float32x4_t m33 = dst32;

dst_ic4_ptr[0] = m00[0];
dst_ic4_ptr[8] = m00[1];
dst_ic4_ptr[16] = m00[2];
dst_ic4_ptr[24] = m00[3];

dst_ic4_ptr[0 + dst_step] = m01[0];
dst_ic4_ptr[8 + dst_step] = m01[1];
dst_ic4_ptr[16 + dst_step] = m01[2];
dst_ic4_ptr[24 + dst_step] = m01[3];

dst_ic4_ptr[0 + 2 * dst_step] = m02[0];
dst_ic4_ptr[8 + 2 * dst_step] = m02[1];
dst_ic4_ptr[16 + 2 * dst_step] = m02[2];
dst_ic4_ptr[24 + 2 * dst_step] = m02[3];

dst_ic4_ptr[0 + 3 * dst_step] = m03[0];
dst_ic4_ptr[8 + 3 * dst_step] = m03[1];
dst_ic4_ptr[16 + 3 * dst_step] = m03[2];
dst_ic4_ptr[24 + 3 * dst_step] = m03[3];

dst_ic4_ptr[0 + 4 * dst_step] = m10[0];
dst_ic4_ptr[8 + 4 * dst_step] = m10[1];
dst_ic4_ptr[16 + 4 * dst_step] = m10[2];
dst_ic4_ptr[24 + 4 * dst_step] = m10[3];

dst_ic4_ptr[0 + 5 * dst_step] = m11[0];
dst_ic4_ptr[8 + 5 * dst_step] = m11[1];
dst_ic4_ptr[16 + 5 * dst_step] = m11[2];
dst_ic4_ptr[24 + 5 * dst_step] = m11[3];

dst_ic4_ptr[0 + 6 * dst_step] = m12[0];
dst_ic4_ptr[8 + 6 * dst_step] = m12[1];
dst_ic4_ptr[16 + 6 * dst_step] = m12[2];
dst_ic4_ptr[24 + 6 * dst_step] = m12[3];

dst_ic4_ptr[0 + 7 * dst_step] = m13[0];
dst_ic4_ptr[8 + 7 * dst_step] = m13[1];
dst_ic4_ptr[16 + 7 * dst_step] = m13[2];
dst_ic4_ptr[24 + 7 * dst_step] = m13[3];

dst_ic4_ptr[0 + 8 * dst_step] = m20[0];
dst_ic4_ptr[8 + 8 * dst_step] = m20[1];
dst_ic4_ptr[16 + 8 * dst_step] = m20[2];
dst_ic4_ptr[24 + 8 * dst_step] = m20[3];

dst_ic4_ptr[0 + 9 * dst_step] = m21[0];
dst_ic4_ptr[8 + 9 * dst_step] = m21[1];
dst_ic4_ptr[16 + 9 * dst_step] = m21[2];
dst_ic4_ptr[24 + 9 * dst_step] = m21[3];

dst_ic4_ptr[0 + 10 * dst_step] = m22[0];
dst_ic4_ptr[8 + 10 * dst_step] = m22[1];
dst_ic4_ptr[16 + 10 * dst_step] = m22[2];
dst_ic4_ptr[24 + 10 * dst_step] = m22[3];

dst_ic4_ptr[0 + 11 * dst_step] = m23[0];
dst_ic4_ptr[8 + 11 * dst_step] = m23[1];
dst_ic4_ptr[16 + 11 * dst_step] = m23[2];
dst_ic4_ptr[24 + 11 * dst_step] = m23[3];

dst_ic4_ptr[0 + 12 * dst_step] = m30[0];
dst_ic4_ptr[8 + 12 * dst_step] = m30[1];
dst_ic4_ptr[16 + 12 * dst_step] = m30[2];
dst_ic4_ptr[24 + 12 * dst_step] = m30[3];

dst_ic4_ptr[0 + 13 * dst_step] = m31[0];
dst_ic4_ptr[8 + 13 * dst_step] = m31[1];
dst_ic4_ptr[16 + 13 * dst_step] = m31[2];
dst_ic4_ptr[24 + 13 * dst_step] = m31[3];

dst_ic4_ptr[0 + 14 * dst_step] = m32[0];
dst_ic4_ptr[8 + 14 * dst_step] = m32[1];
dst_ic4_ptr[16 + 14 * dst_step] = m32[2];
dst_ic4_ptr[24 + 14 * dst_step] = m32[3];

dst_ic4_ptr[0 + 15 * dst_step] = m33[0];
dst_ic4_ptr[8 + 15 * dst_step] = m33[1];
dst_ic4_ptr[16 + 15 * dst_step] = m33[2];
dst_ic4_ptr[24 + 15 * dst_step] = m33[3];
#else
for (int j = 0; j < C4NUM; j++) {
float *local_ptr = src_ic4_ptr + j;
float dst00 = local_ptr[0];
float dst01 = (local_ptr + 4)[0];
float dst02 = (local_ptr + 8)[0];

const float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
const float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
const float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];

const float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
const float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
const float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];

float dst30 = (local_ptr + 24)[0];
float dst31 = (local_ptr + 28)[0];
float dst32 = (local_ptr + 32)[0];

float m00 = dst00;
const float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02;
const float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02;
float m03 = dst02;

float m10 = dst10;
const float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12;
const float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12;
float m13 = dst12;

float m20 = dst20;
const float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22;
const float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22;
float m23 = dst22;

float m30 = dst30;
const float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32;
const float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32;
float m33 = dst32;

*(dst_ic4_ptr + j * 8) = m00;
*(dst_ic4_ptr + j * 8 + dst_step) = m01;
*(dst_ic4_ptr + j * 8 + 2 * dst_step) = m02;
*(dst_ic4_ptr + j * 8 + 3 * dst_step) = m03;

*(dst_ic4_ptr + j * 8 + 4 * dst_step) = m10;
*(dst_ic4_ptr + j * 8 + 5 * dst_step) = m11;
*(dst_ic4_ptr + j * 8 + 6 * dst_step) = m12;
*(dst_ic4_ptr + j * 8 + 7 * dst_step) = m13;

*(dst_ic4_ptr + j * 8 + 8 * dst_step) = m20;
*(dst_ic4_ptr + j * 8 + 9 * dst_step) = m21;
*(dst_ic4_ptr + j * 8 + 10 * dst_step) = m22;
*(dst_ic4_ptr + j * 8 + 11 * dst_step) = m23;

*(dst_ic4_ptr + j * 8 + 12 * dst_step) = m30;
*(dst_ic4_ptr + j * 8 + 13 * dst_step) = m31;
*(dst_ic4_ptr + j * 8 + 14 * dst_step) = m32;
*(dst_ic4_ptr + j * 8 + 15 * dst_step) = m33;
}
#endif
}
}
}

void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound,
bool w_not_bound, int output_w) {
#ifdef ENABLE_ARM
float32x4_t bias_ptr = vld1q_f32(bias_data);

float32x4_t s00 = vld1q_f32(gemm_out);
float32x4_t s01 = vld1q_f32(gemm_out + 8);
float32x4_t s02 = vld1q_f32(gemm_out + 16);
float32x4_t s03 = vld1q_f32(gemm_out + 24);

float32x4_t s10 = vld1q_f32(gemm_out + 32);
float32x4_t s11 = vld1q_f32(gemm_out + 40);
float32x4_t s12 = vld1q_f32(gemm_out + 48);
float32x4_t s13 = vld1q_f32(gemm_out + 56);

float32x4_t s20 = vld1q_f32(gemm_out + 64);
float32x4_t s21 = vld1q_f32(gemm_out + 72);
float32x4_t s22 = vld1q_f32(gemm_out + 80);
float32x4_t s23 = vld1q_f32(gemm_out + 88);

float32x4_t s30 = vld1q_f32(gemm_out + 96);
float32x4_t s31 = vld1q_f32(gemm_out + 104);
float32x4_t s32 = vld1q_f32(gemm_out + 112);
float32x4_t s33 = vld1q_f32(gemm_out + 120);

float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20);
float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21);
float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22);
float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23);

float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30);
float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31);
float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32);
float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33);

float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr);
float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr);
float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr);
float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr);

vst1q_f32(output_data, d00);
if (w_not_bound) {
vst1q_f32(output_data + 4, d01);
}
if (h_not_bound) {
vst1q_f32(output_data + output_w * 4, d10);
if (w_not_bound) {
vst1q_f32(output_data + output_w * 4 + 4, d11);
}
}
#else
for (int i = 0; i < C4NUM; i++) {
const float *local_ptr = gemm_out + i;
const float *bias_ptr = bias_data + i;

float s00 = local_ptr[0];
float s01 = (local_ptr + 8)[0];
float s02 = (local_ptr + 16)[0];
float s03 = (local_ptr + 24)[0];

float s10 = (local_ptr + 32)[0];
float s11 = (local_ptr + 40)[0];
float s12 = (local_ptr + 48)[0];
float s13 = (local_ptr + 56)[0];

float s20 = (local_ptr + 64)[0];
float s21 = (local_ptr + 72)[0];
float s22 = (local_ptr + 80)[0];
float s23 = (local_ptr + 88)[0];

float s30 = (local_ptr + 96)[0];
float s31 = (local_ptr + 104)[0];
float s32 = (local_ptr + 112)[0];
float s33 = (local_ptr + 120)[0];

float t00 = s00 + s10 + s20;
float t01 = s01 + s11 + s21;
float t02 = s02 + s12 + s22;
float t03 = s03 + s13 + s23;

float t10 = s10 - s20 - s30;
float t11 = s11 - s21 - s31;
float t12 = s12 - s22 - s32;
float t13 = s13 - s23 - s33;

float d00 = t00 + t01 + t02 + bias_ptr[0];
float d01 = t01 - t02 - t03 + bias_ptr[0];
float d10 = t10 + t11 + t12 + bias_ptr[0];
float d11 = t11 - t12 - t13 + bias_ptr[0];

(output_data + i)[0] = d00;
if (w_not_bound) {
(output_data + i + C4NUM)[0] = d01;
}
if (h_not_bound) {
(output_data + i + output_w * C4NUM)[0] = d10;
if (w_not_bound) {
(output_data + i + output_w * C4NUM + C4NUM)[0] = d11;
}
}
}
#endif
}

void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
const int input_unit = 4;
if (out_w_block == 0) {
return;
}
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w);

for (int j = 0; j < oc4; j++) {
int c8_block = j / 2;
int c8_res = j % 2;
int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM;
int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w;
const float *src_ptr = gemm_out + src_oc4_offset;
const float *bias_ptr = bias_data + j * C4NUM;
float *dst_ptr = out_data + dst_oc4_offset;

// output transform
bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w;
bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h;
Conv3x3Fp32OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w);
}
}
}

// int8 conv3x3
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
#ifdef ENABLE_ARM


mindspore/lite/nnacl/winograd_transform.h  (+0 -15)

@@ -38,21 +38,6 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func);

// for fp32 convolution 3x3 filter/input/output transform
void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step);

void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);

void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane,
int oc_block);

void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound,
bool w_not_bound, int output_w);

void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);

// for int8 convolution 3x3 filter/input/output transform
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp);



mindspore/lite/nnacl/winograd_utils.c  (+5 -0)

@@ -25,6 +25,7 @@ static InputTransFunc InputTransFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};

static OutputTransFunc OutputTransFuncList4[] = {NULL, NULL, OutputTransform4x2Unit, OutputTransform4x3Unit};

static OutputTransFunc OutputTransFuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnit,
OutputTransform4x3ReluUnit};
static OutputTransFunc OutputTransFuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6Unit,
@@ -32,12 +33,14 @@ static OutputTransFunc OutputTransFuncRelu6List4[] = {NULL, NULL, OutputTransfor

static OutputTransFunc OutputTransFuncList6[] = {
NULL, NULL, OutputTransform6x2Unit, OutputTransform6x3Unit, OutputTransform6x4Unit, OutputTransform6x5Unit};

static OutputTransFunc OutputTransFuncReluList6[] = {NULL,
NULL,
OutputTransform6x2ReluUnit,
OutputTransform6x3ReluUnit,
OutputTransform6x4ReluUnit,
OutputTransform6x5ReluUnit};
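
Note: each table is indexed directly by the output unit, with NULL in the slots that have no kernel. A hedged sketch of the selection logic for the 6x6 input unit, assuming the surrounding winograd_utils.c context (the in-tree selector may differ):

/* Hedged sketch: pick an output transform for input unit 6 by output unit
 * and activation; returns NULL for unsupported sizes, mirroring the NULL
 * slots in the tables above. */
static OutputTransFunc hp_pick_out_func6(int output_unit, ActType act_type) {
  if (output_unit < 2 || output_unit > 5) {
    return NULL;
  }
  if (act_type == ActType_Relu) {
    return OutputTransFuncReluList6[output_unit];
  }
  if (act_type == ActType_Relu6) {
    return OutputTransFuncRelu6List6[output_unit];
  }
  return OutputTransFuncList6[output_unit];
}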

static OutputTransFunc OutputTransFuncRelu6List6[] = {NULL,
NULL,
OutputTransform6x2Relu6Unit,
@@ -53,6 +56,7 @@ static OutputTransFunc OutputTransFuncList8[] = {NULL,
OutputTransform8x5Unit,
OutputTransform8x6Unit,
OutputTransform8x7Unit};

static OutputTransFunc OutputTransFuncReluList8[] = {NULL,
NULL,
OutputTransform8x2ReluUnit,
@@ -61,6 +65,7 @@ static OutputTransFunc OutputTransFuncReluList8[] = {NULL,
OutputTransform8x5ReluUnit,
OutputTransform8x6ReluUnit,
OutputTransform8x7ReluUnit};

static OutputTransFunc OutputTransFuncRelu6List8[] = {NULL,
NULL,
OutputTransform8x2Relu6Unit,


mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc  (+15 -52)

@@ -15,9 +15,7 @@
*/

#include "src/runtime/kernel/arm/fp32/convolution.h"
#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h"
#include "nnacl/fp32/conv.h"
#include "nnacl/common_func.h"
@@ -42,17 +40,10 @@ int ConvolutionCPUKernel::InitWeightBias() {
int out_channel = filter_tensor->Batch();
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
-  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
-  int oc_block, oc_block_num;
-#ifdef ENABLE_ARM32
-  oc_block = C4NUM;
-  oc_block_num = UP_DIV(out_channel, C4NUM);
-#else
-  oc_block = C8NUM;
-  oc_block_num = UP_DIV(out_channel, C8NUM);
-#endif
-  int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
+  const int oc_block = C8NUM;
+  int oc_block_num = UP_DIV(out_channel, C8NUM);
+  int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane;

auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData());
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
@@ -61,7 +52,7 @@ int ConvolutionCPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
-  PackWeightFp32(origin_weight, conv_param_, packed_weight_, oc_block, oc_block_num);
+  RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);

bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
@@ -80,38 +71,28 @@ int ConvolutionCPUKernel::InitWeightBias() {
}

int ConvolutionCPUKernel::InitTmpBuffer() {
-  int out_channel = conv_param_->output_channel_;
+  int in_channel = conv_param_->input_channel_;
  MS_ASSERT(ctx_->allocator != nullptr);

-  int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
-  int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * ic4 * C4NUM * TILE_NUM * thread_count_;
+#ifdef ENABLE_ARM32
+  int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * in_channel * C4NUM * thread_count_;
+#else
+  int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * in_channel * C12NUM * thread_count_;
+#endif
packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}

-  tmp_output_block_ =
-    reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * TILE_NUM * out_channel * sizeof(float)));
-  if (tmp_output_block_ == nullptr) {
-    MS_LOG(ERROR) << "malloc tmp output block failed.";
+  col_major_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
+  if (col_major_input_ == nullptr) {
+    MS_LOG(ERROR) << "malloc col_major_input_ failed.";
    return RET_ERROR;
  }
return RET_OK;
}
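
Note: the weight-side half of the change is in InitWeightBias above: the OHWI filter is treated as a row-major [out_channel x in_channel * kernel_plane] matrix and repacked into 8-row column-major panels. A hedged reference of what such an 8-major repack computes (the in-tree RowMajor2Col8Major may handle the remainder panel differently):

#include <string.h>

/* Hedged reference: repack a row-major [row x col] matrix into panels of
 * 8 rows stored column-major, zero-padding the tail panel. */
static void hp_rowmajor_to_col8major(const float *src, float *dst,
                                     int row, int col) {
  int row8 = ((row + 7) / 8) * 8; /* rows rounded up to a multiple of 8 */
  memset(dst, 0, (size_t)row8 * col * sizeof(float));
  for (int r = 0; r < row; ++r) {
    int panel = r / 8, lane = r % 8;
    for (int c = 0; c < col; ++c) {
      /* panel stride = 8 * col; inside a panel, each column is 8 contiguous lanes */
      dst[(size_t)panel * 8 * col + (size_t)c * 8 + lane] = src[(size_t)r * col + c];
    }
  }
}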

-void ConvolutionCPUKernel::ConfigInputOutput() {
-  // set output format
-  auto output_tensor = out_tensors_.at(kOutputIndex);
-  output_tensor->SetFormat(schema::Format::Format_NHWC);
-
-#ifdef ENABLE_ARM32
-  gemm_func_ = IndirectGemmFp32_8x4;
-#else
-  gemm_func_ = IndirectGemmFp32_8x8;
-#endif
-}

int ConvolutionCPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
@@ -121,7 +102,6 @@ int ConvolutionCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
-  ConfigInputOutput();
return ReSize();
}

@@ -141,15 +121,11 @@ int ConvolutionCPUKernel::ReSize() {
}

int ConvolutionCPUKernel::RunImpl(int task_id) {
-  if (gemm_func_ == nullptr) {
-    MS_LOG(ERROR) << "gemm_func is nullptr.";
-    return RET_ERROR;
-  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
-  ConvFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), tmp_output_block_,
-           output_addr, task_id, conv_param_, gemm_func_);
+  ConvFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
+           output_addr, task_id, conv_param_);
return RET_OK;
}

@@ -186,19 +162,6 @@ int ConvolutionCPUKernel::Run() {
return RET_OK;
}

-bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
-  int in_channel = conv_param->input_channel_;
-  int out_h = conv_param->output_h_;
-  int out_w = conv_param->output_w_;
-  int out_channel = conv_param->output_channel_;
-  int ic4 = UP_DIV(in_channel, C4NUM);
-  int oc4 = UP_DIV(out_channel, C4NUM);
-  if (out_h * out_w <= 32 || ic4 < 4 || oc4 < 4) {
-    return true;
-  }
-  return false;
-}

kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const kernel::KernelKey &desc,


mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h  (+5 -7)

@@ -43,23 +43,21 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
-  void ConfigInputOutput();

private:
  void FreeTmpBuffer() {
-    if (tmp_output_block_ != nullptr) {
-      ctx_->allocator->Free(tmp_output_block_);
-      tmp_output_block_ = nullptr;
-    }
    if (packed_input_ != nullptr) {
      ctx_->allocator->Free(packed_input_);
      packed_input_ = nullptr;
    }
+    if (col_major_input_ != nullptr) {
+      ctx_->allocator->Free(col_major_input_);
+      col_major_input_ = nullptr;
+    }
  }
  float *packed_input_ = nullptr;
  float *packed_weight_ = nullptr;
-  float *tmp_output_block_ = nullptr;
-  GEMM_FUNC_FP32 gemm_func_ = nullptr;
+  float *col_major_input_ = nullptr;
};
} // namespace mindspore::kernel



mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc  (+0 -245)

@@ -1,245 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
#include "nnacl/fp32/conv.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

namespace mindspore::kernel {
void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num) {
auto input_channel = conv_param->input_channel_;
auto output_channel = conv_param->output_channel_;
auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
int iC4 = UP_DIV(input_channel, C4NUM);

size_t tmp_size = oc_block_num * oc_block * iC4 * C4NUM * kernel_plane * sizeof(float);
auto tmp_addr = reinterpret_cast<float *>(malloc(tmp_size));
if (tmp_addr == nullptr) {
MS_LOG(ERROR) << "malloc tmp_addr failed.";
return;
}
memset(tmp_addr, 0, tmp_size);

PackNHWCToNC4HW4Fp32(origin_weight, tmp_addr, output_channel, kernel_plane, input_channel);
Conv3x3Fp32FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane, oc_block);
free(tmp_addr);
}

int Convolution3x3CPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
conv_param_->input_channel_ = input_channel;
conv_param_->output_channel_ = output_channel;
int iC4 = UP_DIV(input_channel, C4NUM);
int oC4 = UP_DIV(output_channel, C4NUM);
int oc_block, oc_block_num;
oc_block = C8NUM;
oc_block_num = UP_DIV(output_channel, C8NUM);
const int k_plane = 16;
// init weight
size_t transformed_size = iC4 * C4NUM * oc_block_num * oc_block * k_plane * sizeof(float);
transformed_filter_addr_ = reinterpret_cast<float *>(malloc(transformed_size));
if (transformed_filter_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc transformed filter addr failed.";
return RET_ERROR;
}
memset(transformed_filter_addr_, 0, transformed_size);
auto weight_data = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->MutableData());
ProcessFilter(weight_data, transformed_filter_addr_, conv_param_, oc_block, oc_block_num);

// init bias
size_t new_bias_size = oC4 * C4NUM * sizeof(float);
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias data failed.";
return RET_ERROR;
}
memset(bias_data_, 0, new_bias_size);
if (in_tensors_.size() == kInputSize2) {
auto ori_bias_addr = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(float));
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}

int Convolution3x3CPUKernel::InitTmpBuffer() {
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM);
const int k_plane = 16;
MS_ASSERT(ctx_->allocator != nullptr);

#ifdef ENABLE_ARM32
const int tile_num = 4;
#else
const int tile_num = 12;
#endif

size_t tile_buffer_size = thread_count_ * tile_num * C16NUM * ic4 * C4NUM * sizeof(float);
tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile buffer failed.";
return RET_ERROR;
}

size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float);
block_unit_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc block_unit_buffer_ failed.";
return RET_ERROR;
}

size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float);
tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";
return RET_ERROR;
}

size_t col_buffer_size = thread_count_ * tile_num * C4NUM * ic4 * sizeof(float);
col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
}

size_t nc4hw4_out_size =
oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float);
nc4hw4_out_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(nc4hw4_out_size));
if (nc4hw4_out_ == nullptr) {
MS_LOG(ERROR) << "malloc nc4hw4_out_ failed.";
return RET_ERROR;
}

tmp_buffer_address_list_[0] = tile_buffer_;
tmp_buffer_address_list_[1] = block_unit_buffer_;
tmp_buffer_address_list_[2] = tmp_dst_buffer_;
tmp_buffer_address_list_[3] = nc4hw4_out_;
tmp_buffer_address_list_[4] = col_buffer_;
return RET_OK;
}
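Every allocation above is sized thread_count_ times a per-thread chunk. How Conv3x3Fp32 carves out a worker's slice is internal to nnacl, but it presumably follows the usual task_id-offset pattern; a sketch under that assumption, with illustrative names:

// Assumption: each worker offsets into a shared scratch buffer by whole
// task_id-sized chunks (not verified against nnacl/fp32/conv.c).
inline float *ThreadSlice(float *base, int task_id, size_t per_thread_elems) {
  return base + static_cast<size_t>(task_id) * per_thread_elems;
}
// e.g. a worker's tile region would be
// ThreadSlice(tile_buffer_, task_id, tile_num * C16NUM * ic4 * C4NUM).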

int Convolution3x3CPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.ret: " << ret;
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int Convolution3x3CPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize is invalid.";
return ret;
}

ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.ret: " << ret;
return RET_ERROR;
}
return RET_OK;
}

int Convolution3x3CPUKernel::RunImpl(int task_id) {
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
Conv3x3Fp32(ori_input_data, transformed_filter_addr_, reinterpret_cast<float *>(bias_data_), tmp_buffer_address_list_,
task_id, conv_param_);
return RET_OK;
}

int Convolution3x3Impl(void *cdata, int task_id) {
auto conv3x3 = reinterpret_cast<Convolution3x3CPUKernel *>(cdata);
auto error_code = conv3x3->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution3x3 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int Convolution3x3CPUKernel::PostProcess() {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu:
PackNC4HW4ToNHWCReluFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu6:
PackNC4HW4ToNHWCRelu6Fp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}
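PostProcess only selects among three fused unpack routines; the underlying transform is NC4HW4, i.e. [N][C4][H*W][4], back to NHWC. A plain scalar reference of that transform under the stated layout assumption (the production routines in nnacl/pack.c are the authoritative versions and additionally fuse the activation):

// Scalar NC4HW4 -> NHWC unpack, assuming layout [N][C4][plane][4].
void UnpackNC4HW4ToNHWCSketch(const float *src, float *dst, int batch, int plane, int channel) {
  const int c4num = 4;  // C4NUM
  const int c4 = (channel + c4num - 1) / c4num;
  for (int b = 0; b < batch; ++b) {
    const float *src_b = src + b * c4 * plane * c4num;
    float *dst_b = dst + b * plane * channel;
    for (int p = 0; p < plane; ++p) {
      for (int c = 0; c < channel; ++c) {
        dst_b[p * channel + c] = src_b[(c / c4num) * plane * c4num + p * c4num + (c % c4num)];
      }
    }
  }
}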

int Convolution3x3CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}

auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed. ret: " << ret;
FreeTmpBuffer();
return RET_ERROR;
}

int error_code = ParallelLaunch(this->context_->thread_pool_, Convolution3x3Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}

ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
FreeTmpBuffer();
return ret;
}
FreeTmpBuffer();
return RET_OK;
}
} // namespace mindspore::kernel

+ 0
- 80
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h View File

@@ -1,80 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/winograd_transform.h"

namespace mindspore::kernel {
class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
public:
Convolution3x3CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~Convolution3x3CPUKernel() override {
if (transformed_filter_addr_ != nullptr) {
free(transformed_filter_addr_);
}
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
int PostProcess();

private:
void FreeTmpBuffer() {
if (tile_buffer_ != nullptr) {
ctx_->allocator->Free(tile_buffer_);
tile_buffer_ = nullptr;
}
if (block_unit_buffer_ != nullptr) {
ctx_->allocator->Free(block_unit_buffer_);
block_unit_buffer_ = nullptr;
}
if (tmp_dst_buffer_ != nullptr) {
ctx_->allocator->Free(tmp_dst_buffer_);
tmp_dst_buffer_ = nullptr;
}
if (nc4hw4_out_ != nullptr) {
ctx_->allocator->Free(nc4hw4_out_);
nc4hw4_out_ = nullptr;
}
if (col_buffer_ != nullptr) {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
}

float *transformed_filter_addr_ = nullptr;
float *tile_buffer_ = nullptr;
float *block_unit_buffer_ = nullptr;
float *tmp_dst_buffer_ = nullptr;
float *col_buffer_ = nullptr;
float *nc4hw4_out_ = nullptr;
TmpBufferAddress tmp_buffer_address_list_[5];
};
int ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num);
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_

+ 0
- 201
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc View File

@@ -1,201 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
#include "nnacl/common_func.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

namespace mindspore::kernel {
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

int ConvolutionSWCPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
int kernel_h = filter_tensor->Height();
int kernel_w = filter_tensor->Width();
conv_param_->input_channel_ = input_channel;
conv_param_->output_channel_ = output_channel;
int ic4 = UP_DIV(input_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int oc_block = C4NUM;
int oc_block_num = UP_DIV(output_channel, C4NUM);
int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;

auto origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->MutableData());
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed weight failed.";
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
for (int oc = 0; oc < output_channel; ++oc) {
int src_oc_offset = oc * kernel_h * kernel_w * input_channel;
int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
for (int i = 0; i < kernel_h * kernel_w; ++i) {
const float *src = origin_weight + src_oc_offset + i * input_channel;
float *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
memcpy(dst, src, input_channel * sizeof(float));
}
}

bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
memcpy(bias_data_, ori_bias, output_channel * sizeof(float));
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}
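The copy loop above only pads channels: each kernel tap keeps its input_channel floats and the remaining ic4 * C4NUM - input_channel lanes stay zero from the memset. A tiny worked case with a hypothetical 3-channel tap:

#include <cstring>

// Hypothetical tap with input_channel = 3 -> ic4 = 1, four packed lanes.
void PackTapSketch() {
  const float src_tap[3] = {0.1f, 0.2f, 0.3f};
  float dst_tap[4] = {0.0f, 0.0f, 0.0f, 0.0f};  // pad lane pre-zeroed
  memcpy(dst_tap, src_tap, sizeof(src_tap));
  // dst_tap is now {0.1f, 0.2f, 0.3f, 0.0f}; kernel_h * kernel_w such taps
  // follow per output channel at stride ic4 * C4NUM, matching dst_oc_offset.
}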

int ConvolutionSWCPUKernel::InitTmpBuffer() {
int out_channel = conv_param_->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
MS_ASSERT(ctx_->allocator != nullptr);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
size_t nhwc4_input_size =
ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4 input failed.";
return RET_ERROR;
}

tmp_output_block_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(
conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc4 * C4NUM * sizeof(float)));
if (tmp_output_block_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp output block failed.";
return RET_ERROR;
}
return RET_OK;
}
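Both buffers here are whole-tensor staging areas rather than per-thread scratch: nhwc4_input_ holds the NHWC input with channels padded up to a multiple of C4NUM, tmp_output_block_ the NHWC4 output. Illustrative arithmetic for a hypothetical 1x16x16x3 input:

// UP_DIV(3, C4NUM) = 1, so the staging tensor is 1 x 16 x 16 x 4.
constexpr size_t kStagingFloats = 1 * 16 * 16 * 4;  // 1024 floats = 4 KiB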

void ConvolutionSWCPUKernel::ConfigInputOutput() {
// set output format
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format::Format_NHWC);
}

int ConvolutionSWCPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
// config input output
ConfigInputOutput();
return ReSize();
}

int ConvolutionSWCPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize is invalid.";
return ret;
}

if (slidingWindow_param_ != nullptr) {
delete slidingWindow_param_;
slidingWindow_param_ = nullptr;
}

ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return RET_ERROR;
}

// init sliding window param
slidingWindow_param_ = new (std::nothrow) SlidingWindowParam;
if (slidingWindow_param_ == nullptr) {
MS_LOG(ERROR) << "new SlidingWindowParam fail!";
return RET_ERROR;
}
InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);

return RET_OK;
}

int ConvolutionSWCPUKernel::RunImpl(int task_id) {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
ConvSWFp32(reinterpret_cast<float *>(nhwc4_input_), packed_weight_, reinterpret_cast<float *>(bias_data_),
tmp_output_block_, output_addr, task_id, conv_param_, slidingWindow_param_);
return RET_OK;
}

int ConvolutionSWImpl(void *cdata, int task_id) {
auto conv = reinterpret_cast<ConvolutionSWCPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution Sliding Window Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int ConvolutionSWCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}

// init tmp input, output
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed. ret: " << ret;
FreeTmpBuffer();
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = input_tensor->MutableData();
PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);

int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionSWImpl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}

auto out_tensor = out_tensors_.front();
auto out_data = reinterpret_cast<float *>(out_tensor->MutableData());
int oc4_res = conv_param_->output_channel_ % C4NUM;
if (oc4_res != 0) {
PackNHWC4ToNHWCFp32(tmp_output_block_, out_data, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
}
FreeTmpBuffer();
return RET_OK;
}
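A closing note on the oc4_res branch above: the repack runs only when output_channel is not a multiple of C4NUM, which suggests (an inference from this call site, not verified against nnacl/fp32/conv.c) that ConvSWFp32 writes NHWC directly into output_addr in the aligned case and fills tmp_output_block_ with NHWC4 data only otherwise:

// Inferred dispatch rule (assumption only; the name is illustrative):
inline bool NeedsChannelUnpad(int output_channel) {
  return output_channel % C4NUM != 0;  // unaligned -> strip NHWC4 padding
}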
} // namespace mindspore::kernel

+ 0
- 70
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.h View File

@@ -1,70 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv.h"
#include "nnacl/fp32/conv_depthwise.h"

namespace mindspore::kernel {
class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionSWCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}

~ConvolutionSWCPUKernel() override {
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (slidingWindow_param_ != nullptr) {
delete slidingWindow_param_;
slidingWindow_param_ = nullptr;
}
}

int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();

private:
void FreeTmpBuffer() {
if (nhwc4_input_ != nullptr) {
ctx_->allocator->Free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (tmp_output_block_ != nullptr) {
ctx_->allocator->Free(tmp_output_block_);
tmp_output_block_ = nullptr;
}
}
float *packed_weight_ = nullptr;
float *tmp_output_block_ = nullptr;
SlidingWindowParam *slidingWindow_param_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_

+ 0
- 29
mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc View File

@@ -107,35 +107,6 @@ TEST_F(TestPack, PackInputFp32) {
MS_LOG(INFO) << "TestPackInputFp32 passed";
}

TEST_F(TestPack, PackWeightFp32) {
auto conv_param = new ConvParameter;
InitConvParamPack(conv_param);

int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);

size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
ASSERT_NE(weight_data, nullptr);
auto packed_weight = reinterpret_cast<float *>(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float)));
ASSERT_NE(packed_weight, nullptr);
PackWeightFp32(weight_data, conv_param, packed_weight, C8NUM, oc8);

printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
std::cout << packed_weight[i] << " ,";
}
std::cout << std::endl;

free(packed_weight);
delete conv_param;

MS_LOG(INFO) << "TestPackWeightFp32 passed";
}
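For the 32x3x3x3 weight this test reads, the packed buffer allocated above works out as follows (a size check only, using the same macros):

// ic4 = UP_DIV(3, C4NUM) = 1, oc8 = UP_DIV(32, C8NUM) = 4.
constexpr size_t kPackedFloats = 3 * 3 * 1 * C4NUM * 4 * C8NUM;  // 1152 floats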

#ifdef ENABLE_FP16
TEST_F(TestPack, PackInputFp16) {
size_t input_size;

