Browse Source

!9972 [MSLITE] Optimize fp32 matmul for Arm v7a

From: @zhanyuan1
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
dbdf15cca9
11 changed files with 481 additions and 60 deletions
  1. +405
    -0
      mindspore/lite/nnacl/assembly/arm32/MatmulFp32Opt12x4.S
  2. +2
    -2
      mindspore/lite/nnacl/fp32/conv_fp32.c
  3. +2
    -0
      mindspore/lite/nnacl/fp32/matmul_fp32.c
  4. +2
    -0
      mindspore/lite/nnacl/fp32/matmul_fp32.h
  5. +4
    -0
      mindspore/lite/nnacl/fp32_grad/gemm.c
  6. +2
    -0
      mindspore/lite/nnacl/matmul_parameter.h
  7. +11
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc
  8. +5
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
  9. +13
    -7
      mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
  10. +34
    -46
      mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
  11. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.h

+ 405
- 0
mindspore/lite/nnacl/assembly/arm32/MatmulFp32Opt12x4.S View File

@@ -0,0 +1,405 @@
#ifdef ENABLE_ARM32
.text
.align 5
.global MatmulFloatNeon32Opt12x4
#ifndef __APPLE__
.type MatmulFloatNeon32Opt12x4, %function
#endif

// void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
//                               int row, int col, size_t stride, size_t writeMode)
//
// fp32 matmul micro-kernel: lhs packed into 12-row tiles, rhs packed into 4-col tiles.
// Register arguments:
//   r0: a     (lhs, 12 x depth tiles)
//   r1: b     (rhs, depth x 4 tiles)
//   r2: c     (dst)
//   r3: bias  (may be NULL)
// Stack arguments (sp is restored to its entry value after the prologue, so
// these offsets hold throughout the body):
//   [sp, #0]  act_type  (0 = none, 1 = relu, 3 = relu6)
//   [sp, #4]  depth
//   [sp, #8]  row
//   [sp, #12] col
//   [sp, #16] stride    (in elements; scaled to bytes below)
//   [sp, #20] writeMode (OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2)
//             NOTE(review): writeMode is never read — this kernel always writes
//             Nhwc layout; callers must only pass it for OutType_Nhwc.
// Saved registers live just below sp:
//   [sp, #-48]..[sp, #-4]  pushed r0-r8, r10, r11, lr (r0 at -48, r1 at -44,
//                          r2 at -40, r3 at -36 are reloaded as loop state)
//   [sp, #-112]..[sp, #-52] pushed q4-q7 (d8-d15)

MatmulFloatNeon32Opt12x4:
// AAPCS (https://static.docs.arm.com/ihi0042/i/aapcs32.pdf): r4-r11 AND d8-d15
// (q4-q7) are callee-saved. q4-q15 are used as accumulators below, so q4-q7
// must be spilled here as well, not just the GPRs.
push {r0-r8, r10, r11, lr}
vpush {q4-q7}
// Re-bias sp back to its entry value: 48 bytes of GPRs + 64 bytes of q-regs.
// Incoming stack args stay at [sp, #0...]; saved regs sit at negative offsets.
add sp, sp, #112

ldr r5, [sp, #4]  // depth
ldr r6, [sp, #8]  // row
ldr r7, [sp, #12] // col
ldr r8, [sp, #16] // stride (elements)

mov lr, #48       // sizeof(float) * 12
mul r12, r5, lr   // r12 = block stride of lhs: sizeof(float) * 12 * depth
mov lr, #4
mul r8, r8, lr    // r8 = stride * sizeof(float) (row stride of dst, in bytes)

LoopRow:          // one iteration per 12-row tile of lhs
ldr r1, [sp, #-44] // reload rhs ptr
ldr r7, [sp, #12]  // reload rhs col
ldr r3, [sp, #-36] // reload bias ptr

LoopCol:          // one iteration per 4-col tile of rhs
ldr r2, [sp, #-40] // reload dst ptr
ldr r0, [sp, #-48] // reload lhs ptr (top of current 12-row tile)
ldr r5, [sp, #4]   // reload depth
// First depth step uses vmul to initialize the 12x4 accumulators q4-q15,
// avoiding a separate zeroing pass.
vld1.32 {q3}, [r1]!     // q3 = 4 rhs values
vld1.32 {q0, q1}, [r0]! // q0,q1 = lhs rows 0-7 of this depth slice
vmul.f32 q4, q3, d0[0]
vmul.f32 q5, q3, d0[1]
vmul.f32 q6, q3, d1[0]
vld1.32 {q2}, [r0]!     // q2 = lhs rows 8-11
vmul.f32 q7, q3, d1[1]

vmul.f32 q8, q3, d2[0]
vmul.f32 q9, q3, d2[1]
vmul.f32 q10, q3, d3[0]
vmul.f32 q11, q3, d3[1]

vmul.f32 q12, q3, d4[0]
vmul.f32 q13, q3, d4[1]
vmul.f32 q14, q3, d5[0]
vmul.f32 q15, q3, d5[1]

subs r5, r5, #1
beq Bias

LoopDepth:        // remaining depth-1 steps accumulate with vmla
vld1.32 {q3}, [r1]!
vld1.32 {q0, q1}, [r0]!
vmla.f32 q4, q3, d0[0]
vmla.f32 q5, q3, d0[1]
vmla.f32 q6, q3, d1[0]
vld1.32 {q2}, [r0]!
vmla.f32 q7, q3, d1[1]

vmla.f32 q8, q3, d2[0]
vmla.f32 q9, q3, d2[1]
vmla.f32 q10, q3, d3[0]
vmla.f32 q11, q3, d3[1]

vmla.f32 q12, q3, d4[0]
vmla.f32 q13, q3, d4[1]
vmla.f32 q14, q3, d5[0]
vmla.f32 q15, q3, d5[1]

subs r5, r5, #1
bne LoopDepth

Bias:             // add per-column bias (4 floats) to all 12 row accumulators
cmp r3, #0
beq Activation
vld1.32 {q0}, [r3]!  // post-increment: next col tile gets the next 4 bias values
vadd.f32 q4, q4, q0
vadd.f32 q5, q5, q0
vadd.f32 q6, q6, q0
vadd.f32 q7, q7, q0
vadd.f32 q8, q8, q0
vadd.f32 q9, q9, q0
vadd.f32 q10, q10, q0
vadd.f32 q11, q11, q0
vadd.f32 q12, q12, q0
vadd.f32 q13, q13, q0
vadd.f32 q14, q14, q0
vadd.f32 q15, q15, q0

Activation:
ldr lr, [sp]      // act_type
cmp lr, #3
beq Relu6
cmp lr, #1
beq Relu
b Write

Relu6:            // clamp to 6.0f, then fall through into Relu for the 0 clamp
vmov.i32 q2, #6
vcvt.f32.s32 q2, q2
vmin.f32 q4, q4, q2
vmin.f32 q5, q5, q2
vmin.f32 q6, q6, q2
vmin.f32 q7, q7, q2
vmin.f32 q8, q8, q2
vmin.f32 q9, q9, q2
vmin.f32 q10, q10, q2
vmin.f32 q11, q11, q2
vmin.f32 q12, q12, q2
vmin.f32 q13, q13, q2
vmin.f32 q14, q14, q2
vmin.f32 q15, q15, q2

Relu:             // clamp to 0.0f
veor q3, q3, q3
vmax.f32 q4, q4, q3
vmax.f32 q5, q5, q3
vmax.f32 q6, q6, q3
vmax.f32 q7, q7, q3
vmax.f32 q8, q8, q3
vmax.f32 q9, q9, q3
vmax.f32 q10, q10, q3
vmax.f32 q11, q11, q3
vmax.f32 q12, q12, q3
vmax.f32 q13, q13, q3
vmax.f32 q14, q14, q3
vmax.f32 q15, q15, q3

Write:            // dispatch on remaining cols; each WriteN stops after `row` rows
cmp r7, #1
beq Write1
cmp r7, #2
beq Write2
cmp r7, #3
beq Write3
b Write4

Write1:           // 1 remaining col: store lane 0 of each accumulator
add lr, r2, #4
str lr, [sp, #-40] // advance saved dst by 1 float for the next col tile
vst1.32 d8[0], [r2]
cmp r6, #1
beq WriteEnd
add r2, r2, r8
vst1.32 d10[0], [r2]
cmp r6, #2
beq WriteEnd
add r2, r2, r8
vst1.32 d12[0], [r2]
cmp r6, #3
beq WriteEnd
add r2, r2, r8
vst1.32 d14[0], [r2]
cmp r6, #4
beq WriteEnd
add r2, r2, r8
vst1.32 d16[0], [r2]
cmp r6, #5
beq WriteEnd
add r2, r2, r8
vst1.32 d18[0], [r2]
cmp r6, #6
beq WriteEnd
add r2, r2, r8
vst1.32 d20[0], [r2]
cmp r6, #7
beq WriteEnd
add r2, r2, r8
vst1.32 d22[0], [r2]
cmp r6, #8
beq WriteEnd
add r2, r2, r8
vst1.32 d24[0], [r2]
cmp r6, #9
beq WriteEnd
add r2, r2, r8
vst1.32 d26[0], [r2]
cmp r6, #10
beq WriteEnd
add r2, r2, r8
vst1.32 d28[0], [r2]
cmp r6, #11
beq WriteEnd
add r2, r2, r8
vst1.32 d30[0], [r2]
add r2, r2, r8
add r2, r2, #4
b WriteEnd
Write2:           // 2 remaining cols: store low d-reg of each accumulator
add lr, r2, #8
str lr, [sp, #-40]
vst1.32 d8, [r2]
cmp r6, #1
beq WriteEnd
add r2, r2, r8
vst1.32 d10, [r2]
cmp r6, #2
beq WriteEnd
add r2, r2, r8
vst1.32 d12, [r2]
cmp r6, #3
beq WriteEnd
add r2, r2, r8
vst1.32 d14, [r2]
cmp r6, #4
beq WriteEnd
add r2, r2, r8
vst1.32 d16, [r2]
cmp r6, #5
beq WriteEnd
add r2, r2, r8
vst1.32 d18, [r2]
cmp r6, #6
beq WriteEnd
add r2, r2, r8
vst1.32 d20, [r2]
cmp r6, #7
beq WriteEnd
add r2, r2, r8
vst1.32 d22, [r2]
cmp r6, #8
beq WriteEnd
add r2, r2, r8
vst1.32 d24, [r2]
cmp r6, #9
beq WriteEnd
add r2, r2, r8
vst1.32 d26, [r2]
cmp r6, #10
beq WriteEnd
add r2, r2, r8
vst1.32 d28, [r2]
cmp r6, #11
beq WriteEnd
add r2, r2, r8
vst1.32 d30, [r2]
add r2, r2, r8
add r2, r2, #8
b WriteEnd
Write3:           // 3 remaining cols: low d-reg + lane 0 of high d-reg (r4 tracks the 3rd col)
add lr, r2, #12
str lr, [sp, #-40]
add r4, r2, #8
vst1.32 d8, [r2]
vst1.32 d9[0], [r4]
cmp r6, #1
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d10, [r2]
vst1.32 d11[0], [r4]
cmp r6, #2
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d12, [r2]
vst1.32 d13[0], [r4]
cmp r6, #3
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d14, [r2]
vst1.32 d15[0], [r4]
cmp r6, #4
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d16, [r2]
vst1.32 d17[0], [r4]
cmp r6, #5
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d18, [r2]
vst1.32 d19[0], [r4]
cmp r6, #6
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d20, [r2]
vst1.32 d21[0], [r4]
cmp r6, #7
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d22, [r2]
vst1.32 d23[0], [r4]
cmp r6, #8
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d24, [r2]
vst1.32 d25[0], [r4]
cmp r6, #9
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d26, [r2]
vst1.32 d27[0], [r4]
cmp r6, #10
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d28, [r2]
vst1.32 d29[0], [r4]
cmp r6, #11
beq WriteEnd
add r2, r2, r8
add r4, r4, r8
vst1.32 d30, [r2]
vst1.32 d31[0], [r4]
add r2, r2, r8
add r2, r2, #12
b WriteEnd
Write4:           // full 4-col tile: store whole q-reg per row
add lr, r2, #16
str lr, [sp, #-40]
vst1.32 q4, [r2]
cmp r6, #1
beq WriteEnd
add r2, r2, r8
vst1.32 q5, [r2]
cmp r6, #2
beq WriteEnd
add r2, r2, r8
vst1.32 q6, [r2]
cmp r6, #3
beq WriteEnd
add r2, r2, r8
vst1.32 q7, [r2]
cmp r6, #4
beq WriteEnd
add r2, r2, r8
vst1.32 q8, [r2]
cmp r6, #5
beq WriteEnd
add r2, r2, r8
vst1.32 q9, [r2]
cmp r6, #6
beq WriteEnd
add r2, r2, r8
vst1.32 q10, [r2]
cmp r6, #7
beq WriteEnd
add r2, r2, r8
vst1.32 q11, [r2]
cmp r6, #8
beq WriteEnd
add r2, r2, r8
vst1.32 q12, [r2]
cmp r6, #9
beq WriteEnd
add r2, r2, r8
vst1.32 q13, [r2]
cmp r6, #10
beq WriteEnd
add r2, r2, r8
vst1.32 q14, [r2]
cmp r6, #11
beq WriteEnd
add r2, r2, r8
vst1.32 q15, [r2]
add r2, r2, r8
add r2, r2, #16
b WriteEnd
WriteEnd:
cmp r7, #4
ble LoopColEnd
sub r7, r7, #4    // rhs col -= 4
b LoopCol

LoopColEnd:       // advance lhs to next 12-row tile; rewind dst to next row block
ldr r0, [sp, #-48]
add r0, r0, r12   // lhs ptr += 12 * depth floats
str r0, [sp, #-48]
mov lr, #4
ldr r7, [sp, #12] // reload rhs col
mul lr, lr, r7
sub r2, r2, lr    // r2 ended 12 row-strides + col bytes past the row base
str r2, [sp, #-40]
cmp r6, #12
ble LoopRowEnd
sub r6, r6, #12   // lhs row -= 12
b LoopRow

LoopRowEnd:
// Undo the sp re-bias, then restore callee-saved q4-q7 and the GPRs.
sub sp, sp, #112
vpop {q4-q7}
pop {r0-r8, r10, r11, pc}
#endif

+ 2
- 2
mindspore/lite/nnacl/fp32/conv_fp32.c View File

@@ -28,7 +28,7 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
int output_count = conv_param->output_h_ * conv_param->output_w_;
#ifdef ENABLE_AVX
const int cal_num = C6NUM;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
const int cal_num = C4NUM;
#else
const int cal_num = C12NUM;
@@ -52,7 +52,7 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
float *gemm_output = output_data + out_offset;
#ifdef ENABLE_AVX
RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
#else
RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);


+ 2
- 0
mindspore/lite/nnacl/fp32/matmul_fp32.c View File

@@ -874,6 +874,8 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
#elif ENABLE_ARM32
if (out_type == OutType_C8) {
MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
} else if (out_type == OutType_Nhwc) {
MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
} else {
MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
}


+ 2
- 0
mindspore/lite/nnacl/fp32/matmul_fp32.h View File

@@ -53,6 +53,8 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi
int col, int stride, size_t writeNhwc, size_t WriteWino);
void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
int row, int col, int stride, int write_mode);
#elif ENABLE_SSE
#include <x86intrin.h>
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,


+ 4
- 0
mindspore/lite/nnacl/fp32_grad/gemm.c View File

@@ -544,7 +544,11 @@ void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const floa
}
}
if (incremental) output = fworkspace;
#ifdef ENABLE_ARM32
MatmulFloatNeon32Opt(mat_a_input, mat_b_input, output, gcb->bias, (int)gcb->atype, K, M, N, ldc, 1);
#else
MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc);
#endif
if (incremental) addv(output, mat_c, beta, M, N, ldc);
gcb->mat_a = mat_a_input;
gcb->mat_b = mat_b_input;


+ 2
- 0
mindspore/lite/nnacl/matmul_parameter.h View File

@@ -46,10 +46,12 @@ typedef struct MatMulParameter {
int row_8_;
int row_12_;
int row_16_;
int row_align_;
int col_2_;
int col_4_;
int col_8_;
int col_16_;
int col_align_;
int deep_;
int deep_4_;
int deep_16_;


+ 11
- 4
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc View File

@@ -74,6 +74,8 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {

#ifdef ENABLE_AVX
int col_tile = C16NUM;
#elif defined(ENABLE_ARM32)
int col_tile = C4NUM;
#else
int col_tile = C8NUM;
#endif
@@ -100,6 +102,9 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
#ifdef ENABLE_AVX
RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#elif defined(ENABLE_ARM32)
RowMajor2Col4Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#else
RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
@@ -111,7 +116,7 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
int hw_tile = C12NUM;
#ifdef ENABLE_AVX
hw_tile = C6NUM;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
hw_tile = C4NUM;
#endif
if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
@@ -121,6 +126,8 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
} else {
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#elif defined(ENABLE_ARM32)
int col_tile = C4NUM;
#else
int col_tile = C8NUM;
#endif
@@ -195,7 +202,7 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) {

#if ENABLE_AVX
RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#else
RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
@@ -225,7 +232,7 @@ int Convolution1x1CPUKernel::Run() {
#ifdef ENABLE_AVX
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float)));
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else
@@ -251,7 +258,7 @@ int Convolution1x1CPUKernel::Run() {
} else {
#ifdef ENABLE_AVX
RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);


+ 5
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc View File

@@ -44,6 +44,8 @@ int ConvolutionCPUKernel::InitWeightBias() {
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
#ifdef ENABLE_AVX
const int oc_block = C16NUM;
#elif ENABLE_ARM32
const int oc_block = C4NUM;
#else
const int oc_block = C8NUM;
#endif
@@ -59,6 +61,8 @@ int ConvolutionCPUKernel::InitWeightBias() {
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
#ifdef ENABLE_AVX
RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#elif ENABLE_ARM32
RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#else
RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#endif
@@ -84,7 +88,7 @@ int ConvolutionCPUKernel::InitTmpBuffer() {

#ifdef ENABLE_AVX
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_;
#elif ENABLE_ARM32 || ENABLE_SSE
#elif ENABLE_SSE
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_;
#else
int unit_size =


+ 13
- 7
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc View File

@@ -53,16 +53,18 @@ int FullconnectionCPUKernel::ReSize() {

#ifdef ENABLE_AVX
int col_tile = C16NUM;
#elif defined(ENABLE_ARM32)
int col_tile = C4NUM;
#else
int col_tile = C8NUM;
#endif
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile);
fc_param_->col_align_ = UP_ROUND(fc_param_->col_, col_tile);
fc_param_->row_6_ = UP_ROUND(fc_param_->col_, C6NUM);
fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);

thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_align_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_align_, col_tile), thread_count_);

#ifdef ENABLE_ARM
if (fc_param_->row_ == 1) {
@@ -72,7 +74,7 @@ int FullconnectionCPUKernel::ReSize() {
}
#endif
if (in_tensors_.size() == 3) {
int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_8_;
int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_align_;
bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
@@ -83,7 +85,7 @@ int FullconnectionCPUKernel::ReSize() {

#ifdef ENABLE_AVX
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_;
#else
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_;
@@ -94,7 +96,7 @@ int FullconnectionCPUKernel::ReSize() {
}
memset(a_pack_ptr_, 0, row_tmp * fc_param_->deep_ * sizeof(float));

int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_8_;
int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_align_;
b_pack_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * fc_param_->deep_ * sizeof(float)));
if (b_pack_ptr_ == nullptr) {
FreeBuf();
@@ -130,7 +132,7 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr)

#ifdef ENABLE_AVX
RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
#elif defined(ENABLE_SSE)
RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
#else
RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
@@ -144,6 +146,8 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr)
}
#ifdef ENABLE_AVX
RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#elif defined(ENABLE_ARM32)
RowMajor2Col4Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#else
RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#endif
@@ -162,6 +166,8 @@ int FcFp32MatmulRun(void *cdata, int task_id) {
int FullconnectionCPUKernel::DoMatmul(int task_id) {
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#elif defined(ENABLE_ARM32)
int col_tile = C4NUM;
#else
int col_tile = C8NUM;
#endif


+ 34
- 46
mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc View File

@@ -74,17 +74,15 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
}
#endif
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
params_->row_6_ = UP_ROUND(params_->row_, C6NUM);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);

#ifdef ENABLE_AVX
int row_tmp = is_vector_a_ ? 1 : params_->row_6_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
int row_tmp = is_vector_a_ ? 1 : params_->row_4_;
params_->row_align_ = UP_ROUND(params_->row_, C6NUM);
#elif defined(ENABLE_SSE)
params_->row_align_ = UP_ROUND(params_->row_, C4NUM);
#else
int row_tmp = is_vector_a_ ? 1 : params_->row_12_;
params_->row_align_ = UP_ROUND(params_->row_, C12NUM);
#endif

int row_tmp = is_vector_a_ ? 1 : params_->row_align_;
if (params_->a_const_) {
a_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * row_tmp * params_->deep_ * sizeof(float)));
} else {
@@ -109,17 +107,12 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
for (size_t i = 0; i < b_shape.size() - 2; ++i) {
batch *= b_shape[i];
}
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
params_->batch = batch;
params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
params_->col_8_ = UP_ROUND(params_->col_, col_tile);
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];

int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
int col_tmp = is_vector_a_ ? params_->col_ : params_->col_align_;
if (params_->b_const_) {
b_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * col_tmp * params_->deep_ * sizeof(float)));
} else {
@@ -131,8 +124,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
return RET_MEMORY_FAILED;
}

thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_align_, col_tile_));
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
return RET_OK;
}

@@ -142,13 +135,8 @@ int MatmulCPUKernel::InitBias() {
params_->col_ = params_->b_const_
? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1))
: (c_shape.at(c_shape.size() - 1));
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
params_->col_8_ = UP_ROUND(params_->col_, col_tile);
auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_align_;
if (bias_ptr_ == nullptr) {
bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
if (bias_ptr_ == nullptr) {
@@ -184,22 +172,20 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) {

for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->row_;
float *dst = dst_ptr + i * params_->deep_ * params_->row_align_;
#ifdef ENABLE_AVX
float *dst = dst_ptr + i * params_->deep_ * params_->row_6_;
if (params_->a_transpose_) {
RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
}
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
#elif defined(ENABLE_SSE)
if (params_->a_transpose_) {
RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
}
#else
float *dst = dst_ptr + i * params_->deep_ * params_->row_12_;
if (params_->a_transpose_) {
RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
} else {
@@ -226,13 +212,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {

for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->col_;
float *dst = dst_ptr + i * params_->deep_ * params_->col_8_;
float *dst = dst_ptr + i * params_->deep_ * params_->col_align_;
#ifdef ENABLE_AVX
if (params_->b_transpose_) {
RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
}
#elif defined(ENABLE_ARM32)
if (params_->b_transpose_) {
RowMajor2Col4Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row4Major(src, dst, params_->deep_, params_->col_);
}
#else
if (params_->b_transpose_) {
RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
@@ -245,6 +237,13 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
}

int MatmulCPUKernel::Init() {
#ifdef ENABLE_AVX
col_tile_ = C16NUM;
#elif defined(ENABLE_ARM32)
col_tile_ = C4NUM;
#else
col_tile_ = C8NUM;
#endif
params_->a_const_ = (in_tensors_.at(0)->data_c() != nullptr);
params_->b_const_ = (in_tensors_.at(1)->data_c() != nullptr);
if (params_->a_const_) {
@@ -275,18 +274,13 @@ int MatmulCPUKernel::Init() {
}

int MatmulCPUKernel::RunImpl(int task_id) {
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile);
int cur_oc = MSMIN(thread_stride_ * col_tile_, params_->col_ - task_id * thread_stride_ * col_tile_);
if (cur_oc <= 0) {
return RET_OK;
}
auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_;
auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile;
auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL;
auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile_ * params_->deep_;
auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile_;
auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile_ : NULL;
MS_ASSERT(cur_a_ptr_);
MS_ASSERT(b);
MS_ASSERT(c);
@@ -356,14 +350,8 @@ int MatmulCPUKernel::Run() {
cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_;
cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
} else {
#ifdef ENABLE_AVX
cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_;
#else
cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;
#endif
cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_8_;
cur_a_ptr_ = a_ptr_ + i * params_->row_align_ * params_->deep_;
cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_align_;
cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
}
auto ret = ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.h View File

@@ -54,6 +54,7 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
float *cur_b_ptr_ = nullptr;
float *cur_c_ptr_ = nullptr;
bool is_vector_a_ = false;
int col_tile_ = 0;
};
} // namespace mindspore::kernel



Loading…
Cancel
Save