Browse Source

!9980 [MS][LITE][Develop]optimization for winograd transorm functions on x86

From: @lx0095
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f098ce614d
5 changed files with 1353 additions and 1307 deletions
  1. +1
    -1
      mindspore/lite/nnacl/assembly/avx/MatmulAvx.S
  2. +24
    -23
      mindspore/lite/nnacl/fp32/pooling_fp32.c
  3. +3
    -2
      mindspore/lite/nnacl/winograd_transform.c
  4. +1148
    -1104
      mindspore/lite/nnacl/winograd_utils.c
  5. +177
    -177
      mindspore/lite/nnacl/winograd_utils.h

+ 1
- 1
mindspore/lite/nnacl/assembly/avx/MatmulAvx.S View File

@@ -840,7 +840,7 @@ LoopRow:
vmovups %ymm12, (%rdx)
addq %r12, %rdx
vmovups %ymm14, (%rdx)
cmpq $-8, %rbx
cmpq $8, %rbx
je WriteEnd
movq %rax, %rdx
addq %r13, %rax


+ 24
- 23
mindspore/lite/nnacl/fp32/pooling_fp32.c View File

@@ -17,6 +17,7 @@
#include "nnacl/fp32/pooling_fp32.h"
#include <float.h>
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"

int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
float minf, float maxf) {
@@ -32,9 +33,9 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int window = win_w * win_h;

#ifdef ENABLE_NEON
float32x4_t min_value = vdupq_n_f32(minf);
float32x4_t max_value = vdupq_n_f32(maxf);
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
#endif

for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
@@ -61,8 +62,8 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
for (int ci = 0; ci < c4; ci++) {
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_avg = vdupq_n_f32(0);
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0);
#else
float tmp_avg1 = 0;
float tmp_avg2 = 0;
@@ -73,8 +74,8 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
for (int h = real_win_h_start; h < real_win_h_end; h++) {
for (int w = real_win_w_start; w < real_win_w_end; w++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(src_win_ptr));
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(src_win_ptr));
#else
tmp_avg1 += src_win_ptr[0];
tmp_avg2 += src_win_ptr[1];
@@ -90,11 +91,11 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
if (real_count == 0) {
return NNACL_ERR;
}
#ifdef ENABLE_NEON
tmp_avg = tmp_avg / vdupq_n_f32(real_count);
tmp_avg = vmaxq_f32(tmp_avg, min_value);
tmp_avg = vminq_f32(tmp_avg, max_value);
vst1q_f32(dst_c_ptr, tmp_avg);
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
tmp_avg = tmp_avg / MS_MOVQ_F32(real_count);
tmp_avg = MS_MAXQ_F32(tmp_avg, min_value);
tmp_avg = MS_MINQ_F32(tmp_avg, max_value);
MS_STQ_F32(dst_c_ptr, tmp_avg);
#else
tmp_avg1 /= (float)real_count;
tmp_avg2 /= (float)real_count;
@@ -158,9 +159,9 @@ void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParamete
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int c4 = channel / C4NUM; /* oc && ic */

#ifdef ENABLE_NEON
float32x4_t min_value = vdupq_n_f32(minf);
float32x4_t max_value = vdupq_n_f32(maxf);
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
#endif

for (int batch = 0; batch < output_batch; batch++) {
@@ -187,8 +188,8 @@ void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParamete
for (int ci = 0; ci < c4; ci++) {
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX);
#else
float tmp_max1 = -FLT_MAX;
float tmp_max2 = -FLT_MAX;
@@ -199,8 +200,8 @@ void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParamete
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(src_win_ptr));
#else
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
@@ -209,10 +210,10 @@ void MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParamete
#endif
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, min_value);
tmp_max = vminq_f32(tmp_max, max_value);
vst1q_f32(dst_c_ptr, tmp_max);
#if defined(ENABLE_NEON) || defined(ENALBE_SSE)
tmp_max = MS_MAXQ_F32(tmp_max, min_value);
tmp_max = MS_MINQ_F32(tmp_max, max_value);
MS_STQ_F32(dst_c_ptr, tmp_max);
#else
tmp_max1 = fmax(tmp_max1, minf);
tmp_max2 = fmax(tmp_max2, minf);


+ 3
- 2
mindspore/lite/nnacl/winograd_transform.c View File

@@ -15,6 +15,7 @@
*/

#include "nnacl/winograd_transform.h"
#include "nnacl/op_base.h"

// fp32 conv winograd
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
@@ -61,8 +62,8 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int dst_x_offset = dst_y_offset + j * C4NUM;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
vst1q_f32(dst_addr, vld1q_f32(src_addr));
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
dst_addr[k] = src_addr[k];


+ 1148
- 1104
mindspore/lite/nnacl/winograd_utils.c
File diff suppressed because it is too large
View File


+ 177
- 177
mindspore/lite/nnacl/winograd_utils.h View File

@@ -39,127 +39,127 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo
void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const float *bias_data, const float *matrix_a,
const float *matrix_at, int src_step, int dst_step, int in_unit, int out_unit);

#define Load16Data \
src[0] = vld1q_f32(src_data + 0 * src_step); \
src[1] = vld1q_f32(src_data + 1 * src_step); \
src[2] = vld1q_f32(src_data + 2 * src_step); \
src[3] = vld1q_f32(src_data + 3 * src_step); \
src[4] = vld1q_f32(src_data + 4 * src_step); \
src[5] = vld1q_f32(src_data + 5 * src_step); \
src[6] = vld1q_f32(src_data + 6 * src_step); \
src[7] = vld1q_f32(src_data + 7 * src_step); \
src[8] = vld1q_f32(src_data + 8 * src_step); \
src[9] = vld1q_f32(src_data + 9 * src_step); \
src[10] = vld1q_f32(src_data + 10 * src_step); \
src[11] = vld1q_f32(src_data + 11 * src_step); \
src[12] = vld1q_f32(src_data + 12 * src_step); \
src[13] = vld1q_f32(src_data + 13 * src_step); \
src[14] = vld1q_f32(src_data + 14 * src_step); \
src[15] = vld1q_f32(src_data + 15 * src_step);
#define Load16Data \
src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
src[1] = MS_LDQ_F32(src_data + 1 * src_step); \
src[2] = MS_LDQ_F32(src_data + 2 * src_step); \
src[3] = MS_LDQ_F32(src_data + 3 * src_step); \
src[4] = MS_LDQ_F32(src_data + 4 * src_step); \
src[5] = MS_LDQ_F32(src_data + 5 * src_step); \
src[6] = MS_LDQ_F32(src_data + 6 * src_step); \
src[7] = MS_LDQ_F32(src_data + 7 * src_step); \
src[8] = MS_LDQ_F32(src_data + 8 * src_step); \
src[9] = MS_LDQ_F32(src_data + 9 * src_step); \
src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
src[15] = MS_LDQ_F32(src_data + 15 * src_step);

#define Load36Data \
src[0] = vld1q_f32(src_data + 0 * src_step); \
src[1] = vld1q_f32(src_data + 1 * src_step); \
src[2] = vld1q_f32(src_data + 2 * src_step); \
src[3] = vld1q_f32(src_data + 3 * src_step); \
src[4] = vld1q_f32(src_data + 4 * src_step); \
src[5] = vld1q_f32(src_data + 5 * src_step); \
src[6] = vld1q_f32(src_data + 6 * src_step); \
src[7] = vld1q_f32(src_data + 7 * src_step); \
src[8] = vld1q_f32(src_data + 8 * src_step); \
src[9] = vld1q_f32(src_data + 9 * src_step); \
src[10] = vld1q_f32(src_data + 10 * src_step); \
src[11] = vld1q_f32(src_data + 11 * src_step); \
src[12] = vld1q_f32(src_data + 12 * src_step); \
src[13] = vld1q_f32(src_data + 13 * src_step); \
src[14] = vld1q_f32(src_data + 14 * src_step); \
src[15] = vld1q_f32(src_data + 15 * src_step); \
src[16] = vld1q_f32(src_data + 16 * src_step); \
src[17] = vld1q_f32(src_data + 17 * src_step); \
src[18] = vld1q_f32(src_data + 18 * src_step); \
src[19] = vld1q_f32(src_data + 19 * src_step); \
src[20] = vld1q_f32(src_data + 20 * src_step); \
src[21] = vld1q_f32(src_data + 21 * src_step); \
src[22] = vld1q_f32(src_data + 22 * src_step); \
src[23] = vld1q_f32(src_data + 23 * src_step); \
src[24] = vld1q_f32(src_data + 24 * src_step); \
src[25] = vld1q_f32(src_data + 25 * src_step); \
src[26] = vld1q_f32(src_data + 26 * src_step); \
src[27] = vld1q_f32(src_data + 27 * src_step); \
src[28] = vld1q_f32(src_data + 28 * src_step); \
src[29] = vld1q_f32(src_data + 29 * src_step); \
src[30] = vld1q_f32(src_data + 30 * src_step); \
src[31] = vld1q_f32(src_data + 31 * src_step); \
src[32] = vld1q_f32(src_data + 32 * src_step); \
src[33] = vld1q_f32(src_data + 33 * src_step); \
src[34] = vld1q_f32(src_data + 34 * src_step); \
src[35] = vld1q_f32(src_data + 35 * src_step);
#define Load36Data \
src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
src[1] = MS_LDQ_F32(src_data + 1 * src_step); \
src[2] = MS_LDQ_F32(src_data + 2 * src_step); \
src[3] = MS_LDQ_F32(src_data + 3 * src_step); \
src[4] = MS_LDQ_F32(src_data + 4 * src_step); \
src[5] = MS_LDQ_F32(src_data + 5 * src_step); \
src[6] = MS_LDQ_F32(src_data + 6 * src_step); \
src[7] = MS_LDQ_F32(src_data + 7 * src_step); \
src[8] = MS_LDQ_F32(src_data + 8 * src_step); \
src[9] = MS_LDQ_F32(src_data + 9 * src_step); \
src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
src[15] = MS_LDQ_F32(src_data + 15 * src_step); \
src[16] = MS_LDQ_F32(src_data + 16 * src_step); \
src[17] = MS_LDQ_F32(src_data + 17 * src_step); \
src[18] = MS_LDQ_F32(src_data + 18 * src_step); \
src[19] = MS_LDQ_F32(src_data + 19 * src_step); \
src[20] = MS_LDQ_F32(src_data + 20 * src_step); \
src[21] = MS_LDQ_F32(src_data + 21 * src_step); \
src[22] = MS_LDQ_F32(src_data + 22 * src_step); \
src[23] = MS_LDQ_F32(src_data + 23 * src_step); \
src[24] = MS_LDQ_F32(src_data + 24 * src_step); \
src[25] = MS_LDQ_F32(src_data + 25 * src_step); \
src[26] = MS_LDQ_F32(src_data + 26 * src_step); \
src[27] = MS_LDQ_F32(src_data + 27 * src_step); \
src[28] = MS_LDQ_F32(src_data + 28 * src_step); \
src[29] = MS_LDQ_F32(src_data + 29 * src_step); \
src[30] = MS_LDQ_F32(src_data + 30 * src_step); \
src[31] = MS_LDQ_F32(src_data + 31 * src_step); \
src[32] = MS_LDQ_F32(src_data + 32 * src_step); \
src[33] = MS_LDQ_F32(src_data + 33 * src_step); \
src[34] = MS_LDQ_F32(src_data + 34 * src_step); \
src[35] = MS_LDQ_F32(src_data + 35 * src_step);

#define Load64Data \
src[0] = vld1q_f32(src_data + 0 * src_step); \
src[1] = vld1q_f32(src_data + 1 * src_step); \
src[2] = vld1q_f32(src_data + 2 * src_step); \
src[3] = vld1q_f32(src_data + 3 * src_step); \
src[4] = vld1q_f32(src_data + 4 * src_step); \
src[5] = vld1q_f32(src_data + 5 * src_step); \
src[6] = vld1q_f32(src_data + 6 * src_step); \
src[7] = vld1q_f32(src_data + 7 * src_step); \
src[8] = vld1q_f32(src_data + 8 * src_step); \
src[9] = vld1q_f32(src_data + 9 * src_step); \
src[10] = vld1q_f32(src_data + 10 * src_step); \
src[11] = vld1q_f32(src_data + 11 * src_step); \
src[12] = vld1q_f32(src_data + 12 * src_step); \
src[13] = vld1q_f32(src_data + 13 * src_step); \
src[14] = vld1q_f32(src_data + 14 * src_step); \
src[15] = vld1q_f32(src_data + 15 * src_step); \
src[16] = vld1q_f32(src_data + 16 * src_step); \
src[17] = vld1q_f32(src_data + 17 * src_step); \
src[18] = vld1q_f32(src_data + 18 * src_step); \
src[19] = vld1q_f32(src_data + 19 * src_step); \
src[20] = vld1q_f32(src_data + 20 * src_step); \
src[21] = vld1q_f32(src_data + 21 * src_step); \
src[22] = vld1q_f32(src_data + 22 * src_step); \
src[23] = vld1q_f32(src_data + 23 * src_step); \
src[24] = vld1q_f32(src_data + 24 * src_step); \
src[25] = vld1q_f32(src_data + 25 * src_step); \
src[26] = vld1q_f32(src_data + 26 * src_step); \
src[27] = vld1q_f32(src_data + 27 * src_step); \
src[28] = vld1q_f32(src_data + 28 * src_step); \
src[29] = vld1q_f32(src_data + 29 * src_step); \
src[30] = vld1q_f32(src_data + 30 * src_step); \
src[31] = vld1q_f32(src_data + 31 * src_step); \
src[32] = vld1q_f32(src_data + 32 * src_step); \
src[33] = vld1q_f32(src_data + 33 * src_step); \
src[34] = vld1q_f32(src_data + 34 * src_step); \
src[35] = vld1q_f32(src_data + 35 * src_step); \
src[36] = vld1q_f32(src_data + 36 * src_step); \
src[37] = vld1q_f32(src_data + 37 * src_step); \
src[38] = vld1q_f32(src_data + 38 * src_step); \
src[39] = vld1q_f32(src_data + 39 * src_step); \
src[40] = vld1q_f32(src_data + 40 * src_step); \
src[41] = vld1q_f32(src_data + 41 * src_step); \
src[42] = vld1q_f32(src_data + 42 * src_step); \
src[43] = vld1q_f32(src_data + 43 * src_step); \
src[44] = vld1q_f32(src_data + 44 * src_step); \
src[45] = vld1q_f32(src_data + 45 * src_step); \
src[46] = vld1q_f32(src_data + 46 * src_step); \
src[47] = vld1q_f32(src_data + 47 * src_step); \
src[48] = vld1q_f32(src_data + 48 * src_step); \
src[49] = vld1q_f32(src_data + 49 * src_step); \
src[50] = vld1q_f32(src_data + 50 * src_step); \
src[51] = vld1q_f32(src_data + 51 * src_step); \
src[52] = vld1q_f32(src_data + 52 * src_step); \
src[53] = vld1q_f32(src_data + 53 * src_step); \
src[54] = vld1q_f32(src_data + 54 * src_step); \
src[55] = vld1q_f32(src_data + 55 * src_step); \
src[56] = vld1q_f32(src_data + 56 * src_step); \
src[57] = vld1q_f32(src_data + 57 * src_step); \
src[58] = vld1q_f32(src_data + 58 * src_step); \
src[59] = vld1q_f32(src_data + 59 * src_step); \
src[60] = vld1q_f32(src_data + 60 * src_step); \
src[61] = vld1q_f32(src_data + 61 * src_step); \
src[62] = vld1q_f32(src_data + 62 * src_step); \
src[63] = vld1q_f32(src_data + 63 * src_step);
#define Load64Data \
src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
src[1] = MS_LDQ_F32(src_data + 1 * src_step); \
src[2] = MS_LDQ_F32(src_data + 2 * src_step); \
src[3] = MS_LDQ_F32(src_data + 3 * src_step); \
src[4] = MS_LDQ_F32(src_data + 4 * src_step); \
src[5] = MS_LDQ_F32(src_data + 5 * src_step); \
src[6] = MS_LDQ_F32(src_data + 6 * src_step); \
src[7] = MS_LDQ_F32(src_data + 7 * src_step); \
src[8] = MS_LDQ_F32(src_data + 8 * src_step); \
src[9] = MS_LDQ_F32(src_data + 9 * src_step); \
src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
src[15] = MS_LDQ_F32(src_data + 15 * src_step); \
src[16] = MS_LDQ_F32(src_data + 16 * src_step); \
src[17] = MS_LDQ_F32(src_data + 17 * src_step); \
src[18] = MS_LDQ_F32(src_data + 18 * src_step); \
src[19] = MS_LDQ_F32(src_data + 19 * src_step); \
src[20] = MS_LDQ_F32(src_data + 20 * src_step); \
src[21] = MS_LDQ_F32(src_data + 21 * src_step); \
src[22] = MS_LDQ_F32(src_data + 22 * src_step); \
src[23] = MS_LDQ_F32(src_data + 23 * src_step); \
src[24] = MS_LDQ_F32(src_data + 24 * src_step); \
src[25] = MS_LDQ_F32(src_data + 25 * src_step); \
src[26] = MS_LDQ_F32(src_data + 26 * src_step); \
src[27] = MS_LDQ_F32(src_data + 27 * src_step); \
src[28] = MS_LDQ_F32(src_data + 28 * src_step); \
src[29] = MS_LDQ_F32(src_data + 29 * src_step); \
src[30] = MS_LDQ_F32(src_data + 30 * src_step); \
src[31] = MS_LDQ_F32(src_data + 31 * src_step); \
src[32] = MS_LDQ_F32(src_data + 32 * src_step); \
src[33] = MS_LDQ_F32(src_data + 33 * src_step); \
src[34] = MS_LDQ_F32(src_data + 34 * src_step); \
src[35] = MS_LDQ_F32(src_data + 35 * src_step); \
src[36] = MS_LDQ_F32(src_data + 36 * src_step); \
src[37] = MS_LDQ_F32(src_data + 37 * src_step); \
src[38] = MS_LDQ_F32(src_data + 38 * src_step); \
src[39] = MS_LDQ_F32(src_data + 39 * src_step); \
src[40] = MS_LDQ_F32(src_data + 40 * src_step); \
src[41] = MS_LDQ_F32(src_data + 41 * src_step); \
src[42] = MS_LDQ_F32(src_data + 42 * src_step); \
src[43] = MS_LDQ_F32(src_data + 43 * src_step); \
src[44] = MS_LDQ_F32(src_data + 44 * src_step); \
src[45] = MS_LDQ_F32(src_data + 45 * src_step); \
src[46] = MS_LDQ_F32(src_data + 46 * src_step); \
src[47] = MS_LDQ_F32(src_data + 47 * src_step); \
src[48] = MS_LDQ_F32(src_data + 48 * src_step); \
src[49] = MS_LDQ_F32(src_data + 49 * src_step); \
src[50] = MS_LDQ_F32(src_data + 50 * src_step); \
src[51] = MS_LDQ_F32(src_data + 51 * src_step); \
src[52] = MS_LDQ_F32(src_data + 52 * src_step); \
src[53] = MS_LDQ_F32(src_data + 53 * src_step); \
src[54] = MS_LDQ_F32(src_data + 54 * src_step); \
src[55] = MS_LDQ_F32(src_data + 55 * src_step); \
src[56] = MS_LDQ_F32(src_data + 56 * src_step); \
src[57] = MS_LDQ_F32(src_data + 57 * src_step); \
src[58] = MS_LDQ_F32(src_data + 58 * src_step); \
src[59] = MS_LDQ_F32(src_data + 59 * src_step); \
src[60] = MS_LDQ_F32(src_data + 60 * src_step); \
src[61] = MS_LDQ_F32(src_data + 61 * src_step); \
src[62] = MS_LDQ_F32(src_data + 62 * src_step); \
src[63] = MS_LDQ_F32(src_data + 63 * src_step);

InputTransFunc GetInputTransFunc(int input_unit);

@@ -171,67 +171,67 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,

OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);

#define Store4Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + dst_step * out_c, m[2]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[3]);
#define Store4Data \
MS_STQ_F32(dst_data, m[0]); \
MS_STQ_F32(dst_data + out_c, m[1]); \
MS_STQ_F32(dst_data + dst_step * out_c, m[2]); \
MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[3]);

#define Store9Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + 2 * out_c, m[2]); \
vst1q_f32(dst_data + dst_step * out_c, m[3]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[4]); \
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
vst1q_f32(dst_data + 2 * dst_step * out_c, m[6]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
#define Store9Data \
MS_STQ_F32(dst_data, m[0]); \
MS_STQ_F32(dst_data + out_c, m[1]); \
MS_STQ_F32(dst_data + 2 * out_c, m[2]); \
MS_STQ_F32(dst_data + dst_step * out_c, m[3]); \
MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[4]); \
MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[6]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);

#define Store16Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + 2 * out_c, m[2]); \
vst1q_f32(dst_data + 3 * out_c, m[3]); \
vst1q_f32(dst_data + dst_step * out_c, m[4]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[5]); \
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \
vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * out_c, m[8]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
vst1q_f32(dst_data + 3 * dst_step * out_c, m[12]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
#define Store16Data \
MS_STQ_F32(dst_data, m[0]); \
MS_STQ_F32(dst_data + out_c, m[1]); \
MS_STQ_F32(dst_data + 2 * out_c, m[2]); \
MS_STQ_F32(dst_data + 3 * out_c, m[3]); \
MS_STQ_F32(dst_data + dst_step * out_c, m[4]); \
MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[5]); \
MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \
MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[8]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[12]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);

#define Store25Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + 2 * out_c, m[2]); \
vst1q_f32(dst_data + 3 * out_c, m[3]); \
vst1q_f32(dst_data + 4 * out_c, m[4]); \
vst1q_f32(dst_data + dst_step * out_c, m[5]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[6]); \
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \
vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \
vst1q_f32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * out_c, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * out_c, m[15]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
vst1q_f32(dst_data + 4 * dst_step * out_c, m[20]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
#define Store25Data \
MS_STQ_F32(dst_data, m[0]); \
MS_STQ_F32(dst_data + out_c, m[1]); \
MS_STQ_F32(dst_data + 2 * out_c, m[2]); \
MS_STQ_F32(dst_data + 3 * out_c, m[3]); \
MS_STQ_F32(dst_data + 4 * out_c, m[4]); \
MS_STQ_F32(dst_data + dst_step * out_c, m[5]); \
MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[6]); \
MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \
MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \
MS_STQ_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[10]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
MS_STQ_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[15]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
MS_STQ_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
MS_STQ_F32(dst_data + 4 * dst_step * out_c, m[20]); \
MS_STQ_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \
MS_STQ_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
MS_STQ_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
MS_STQ_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);

void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);


Loading…
Cancel
Save