| @@ -31,10 +31,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { | |||||
| free(pack_input_); | free(pack_input_); | ||||
| pack_input_ = nullptr; | pack_input_ = nullptr; | ||||
| } | } | ||||
| if (pack_output_ != nullptr) { | |||||
| free(pack_output_); | |||||
| pack_output_ = nullptr; | |||||
| } | |||||
| if (pre_trans_input_ && input_ptr_ != nullptr) { | if (pre_trans_input_ && input_ptr_ != nullptr) { | ||||
| free(input_ptr_); | free(input_ptr_); | ||||
| input_ptr_ = nullptr; | input_ptr_ = nullptr; | ||||
| @@ -112,13 +108,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { | |||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| } | } | ||||
| memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)); | memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)); | ||||
| pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float))); | |||||
| if (pack_output_ == nullptr) { | |||||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -157,7 +146,7 @@ int Convolution1x1CPUKernel::Init() { | |||||
| } | } | ||||
| int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | ||||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_); | |||||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||||
| if (cur_oc <= 0) { | if (cur_oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -165,23 +154,12 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | |||||
| auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id; | auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id; | ||||
| MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, | MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, | ||||
| pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_, | |||||
| matmul_param_->deep_, matmul_param_->row_8_, cur_oc); | |||||
| output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_, | |||||
| matmul_param_->row_, cur_oc, matmul_param_->col_, true); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) { | |||||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||||
| if (cur_oc <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_; | |||||
| float *dst = output_ptr_ + task_id * thread_stride_; | |||||
| Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_); | |||||
| return RET_OK; | |||||
| } | |||||
| int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | ||||
| auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata); | auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata); | ||||
| auto error_code = conv1x1->DoConv1x1(task_id); | auto error_code = conv1x1->DoConv1x1(task_id); | ||||
| @@ -192,12 +170,6 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata); | |||||
| conv1x1->DoConv1x1Post(task_id); | |||||
| return RET_OK; | |||||
| } | |||||
| int Convolution1x1CPUKernel::Run() { | int Convolution1x1CPUKernel::Run() { | ||||
| auto prepare_ret = Prepare(); | auto prepare_ret = Prepare(); | ||||
| if (prepare_ret != RET_OK) { | if (prepare_ret != RET_OK) { | ||||
| @@ -216,8 +188,6 @@ int Convolution1x1CPUKernel::Run() { | |||||
| MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -46,7 +46,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| public: | public: | ||||
| int DoConv1x1(int task_id); | int DoConv1x1(int task_id); | ||||
| int DoConv1x1Post(int task_id); | |||||
| private: | private: | ||||
| int InitConv1x1Param(); | int InitConv1x1Param(); | ||||
| @@ -61,7 +60,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int thread_stride_ = 0; | int thread_stride_ = 0; | ||||
| float *weight_ptr_ = nullptr; | float *weight_ptr_ = nullptr; | ||||
| float *pack_input_ = nullptr; | float *pack_input_ = nullptr; | ||||
| float *pack_output_ = nullptr; | |||||
| float *input_ptr_ = nullptr; | float *input_ptr_ = nullptr; | ||||
| float *output_ptr_ = nullptr; | float *output_ptr_ = nullptr; | ||||
| }; | }; | ||||
| @@ -152,7 +152,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { | |||||
| MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | ||||
| tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No, | tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No, | ||||
| matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_); | |||||
| matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -104,7 +104,7 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) { | |||||
| MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_, | MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_, | ||||
| c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_, | c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_, | ||||
| bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_, | bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_, | ||||
| cur_oc * 8); | |||||
| cur_oc * 8, 0, false); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -77,7 +77,7 @@ int MatmulCPUKernel::RunImpl(int task_id) { | |||||
| } | } | ||||
| auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; | auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; | ||||
| auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; | auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; | ||||
| MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8); | |||||
| MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -640,7 +640,7 @@ IndirectGemmStart: | |||||
| add x15, x15, x7 | add x15, x15, x7 | ||||
| str s30, [x15] | str s30, [x15] | ||||
| add x0, x0, #4 | add x0, x0, #4 | ||||
| b WriteEnd | |||||
| b WriteEndHalf | |||||
| Write2: | Write2: | ||||
| dup s17, v16.s[1] | dup s17, v16.s[1] | ||||
| stp s16, s17, [x15] | stp s16, s17, [x15] | ||||
| @@ -666,7 +666,7 @@ IndirectGemmStart: | |||||
| dup s31, v30.s[1] | dup s31, v30.s[1] | ||||
| stp s30, s31, [x15] | stp s30, s31, [x15] | ||||
| add x0, x0, #8 | add x0, x0, #8 | ||||
| b WriteEnd | |||||
| b WriteEndHalf | |||||
| Write3: | Write3: | ||||
| add x17, x15, #8 | add x17, x15, #8 | ||||
| dup s17, v16.s[1] | dup s17, v16.s[1] | ||||
| @@ -27,7 +27,7 @@ | |||||
| // accumulators 8x8 block | // accumulators 8x8 block | ||||
| // | // | ||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||
| //OptLoopMul4 RM 1x8 block | |||||
| //OptLoopMul4 RM 4x8 block | |||||
| // /--------------------------------------------\ | // /--------------------------------------------\ | ||||
| // |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | | // |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | | ||||
| // |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]| | // |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]| | ||||
| @@ -46,7 +46,8 @@ | |||||
| // accumulators 8x8 block | // accumulators 8x8 block | ||||
| ///////////////////////////////////////////////////////////////////////////////// | ///////////////////////////////////////////////////////////////////////////////// | ||||
| // | // | ||||
| // void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col) | |||||
| // void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, | |||||
| // int row, int col, size_t stride, bool write_nhwc) | |||||
| // x0: a | // x0: a | ||||
| // x1: b | // x1: b | ||||
| // x2: c | // x2: c | ||||
| @@ -55,30 +56,30 @@ | |||||
| // w5: depth | // w5: depth | ||||
| // w6: row | // w6: row | ||||
| // w7: col | // w7: col | ||||
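| // x17: stride, loaded from the stack at [sp] and multiplied by 4 (sizeof(float)) to get a byte stride | |||||
| // w13: write_nhwc, loaded from the stack at [sp, #8]; zero selects the WriteC8 path | |||||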
| MatmulFloatNeon64: | MatmulFloatNeon64: | ||||
| sub sp, sp, #128 | sub sp, sp, #128 | ||||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ||||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ||||
| mov w9, #0 // rm col offset | |||||
| mov w10, #0 // lm row offset | |||||
| mov w18, #32 // sizeof(float)*8 | |||||
| mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth | |||||
| mov x11, x3 // bias flag | |||||
| mov w18, #32 // sizeof(float) * 8 | |||||
| mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth | |||||
| mov x11, x3 // bias flag | |||||
| mov x18, #4 | |||||
| ldr x17, [sp] | |||||
| mul x17, x17, x18 | |||||
| L1: | L1: | ||||
| cmp w9, w7 | |||||
| beq End1 | |||||
| mov w10, w6 // reload lhs row | |||||
| mov x12, x0 // reload lhs ptr | |||||
| mov x18, x2 // reload dst ptr | |||||
| mov w10, #0 // reset lm row offset | |||||
| mov x12, x0 // reload lm ptr | |||||
| L2: | L2: | ||||
| cmp w10, w6 | |||||
| beq End2 | |||||
| mov x16, x1 // reload rm ptr | |||||
| mov w13, w5 // reload depth | |||||
| mov x14, x3 // reload bias ptr | |||||
| mov x16, x1 // reload rhs ptr | |||||
| mov w13, w5 // reload depth | |||||
| mov x14, x3 // reload bias ptr | |||||
| dup v16.4s, wzr | dup v16.4s, wzr | ||||
| dup v17.4s, wzr | dup v17.4s, wzr | ||||
| dup v18.4s, wzr | dup v18.4s, wzr | ||||
| @@ -96,10 +97,10 @@ L2: | |||||
| dup v30.4s, wzr | dup v30.4s, wzr | ||||
| dup v31.4s, wzr | dup v31.4s, wzr | ||||
| OptLoopMul4: | |||||
| cmp w13, #4 | cmp w13, #4 | ||||
| blt CommLoopMul | blt CommLoopMul | ||||
| OptLoopMul4: | |||||
| ld1 {v0.4s, v1.4s}, [x12], #32 | ld1 {v0.4s, v1.4s}, [x12], #32 | ||||
| ld1 {v8.4s, v9.4s}, [x16], #32 | ld1 {v8.4s, v9.4s}, [x16], #32 | ||||
| fmla v16.4s, v8.4s, v0.s[0] | fmla v16.4s, v8.4s, v0.s[0] | ||||
| @@ -172,13 +173,14 @@ OptLoopMul4: | |||||
| fmla v29.4s, v15.4s, v7.s[2] | fmla v29.4s, v15.4s, v7.s[2] | ||||
| fmla v30.4s, v14.4s, v7.s[3] | fmla v30.4s, v14.4s, v7.s[3] | ||||
| fmla v31.4s, v15.4s, v7.s[3] | fmla v31.4s, v15.4s, v7.s[3] | ||||
| subs w13, w13, #4 | |||||
| b OptLoopMul4 | |||||
| CommLoopMul: | |||||
| cmp w13, #1 | |||||
| blt Bias | |||||
| sub w13, w13, #4 | |||||
| cmp w13, #0 | |||||
| ble Bias | |||||
| cmp w13, #4 | |||||
| bge OptLoopMul4 | |||||
| CommLoopMul: | |||||
| ld1 {v0.4s, v1.4s}, [x12], #32 | ld1 {v0.4s, v1.4s}, [x12], #32 | ||||
| ld1 {v2.4s, v3.4s}, [x16], #32 | ld1 {v2.4s, v3.4s}, [x16], #32 | ||||
| fmla v16.4s, v2.4s, v0.s[0] | fmla v16.4s, v2.4s, v0.s[0] | ||||
| @@ -197,8 +199,9 @@ CommLoopMul: | |||||
| fmla v29.4s, v3.4s, v1.s[2] | fmla v29.4s, v3.4s, v1.s[2] | ||||
| fmla v30.4s, v2.4s, v1.s[3] | fmla v30.4s, v2.4s, v1.s[3] | ||||
| fmla v31.4s, v3.4s, v1.s[3] | fmla v31.4s, v3.4s, v1.s[3] | ||||
| subs w13, w13, #1 | subs w13, w13, #1 | ||||
| b CommLoopMul | |||||
| bgt CommLoopMul | |||||
| Bias: | Bias: | ||||
| cbz x11, Activation | cbz x11, Activation | ||||
| @@ -226,7 +229,8 @@ Activation: | |||||
| beq Relu6 | beq Relu6 | ||||
| cmp w4, #1 | cmp w4, #1 | ||||
| beq Relu | beq Relu | ||||
| b TransToOut | |||||
| b Write | |||||
| Relu6: | Relu6: | ||||
| mov w8, #6 | mov w8, #6 | ||||
| dup v15.4s, w8 | dup v15.4s, w8 | ||||
| @@ -247,6 +251,7 @@ Relu6: | |||||
| fmin v29.4s, v29.4s, v15.4s | fmin v29.4s, v29.4s, v15.4s | ||||
| fmin v30.4s, v30.4s, v15.4s | fmin v30.4s, v30.4s, v15.4s | ||||
| fmin v31.4s, v31.4s, v15.4s | fmin v31.4s, v31.4s, v15.4s | ||||
| Relu: | Relu: | ||||
| dup v14.4s, wzr | dup v14.4s, wzr | ||||
| fmax v16.4s, v16.4s, v14.4s | fmax v16.4s, v16.4s, v14.4s | ||||
| @@ -266,7 +271,317 @@ Relu: | |||||
| fmax v30.4s, v30.4s, v14.4s | fmax v30.4s, v30.4s, v14.4s | ||||
| fmax v31.4s, v31.4s, v14.4s | fmax v31.4s, v31.4s, v14.4s | ||||
| TransToOut: | |||||
| Write: | |||||
| ldrb w13, [sp, #8] | |||||
| cbz w13, WriteC8 | |||||
| cmp w7, #1 | |||||
| beq Write1 | |||||
| cmp w7, #2 | |||||
| beq Write2 | |||||
| cmp w7, #3 | |||||
| beq Write3 | |||||
| cmp w7, #4 | |||||
| beq Write4 | |||||
| cmp w7, #5 | |||||
| beq Write5 | |||||
| cmp w7, #6 | |||||
| beq Write6 | |||||
| cmp w7, #7 | |||||
| beq Write7 | |||||
| b Write8 | |||||
| Write1: | |||||
| str s16, [x18] | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s18, [x18] | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s20, [x18] | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s22, [x18] | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s24, [x18] | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s26, [x18] | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s28, [x18] | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| str s30, [x18] | |||||
| add x18, x18, x17 | |||||
| b WriteEnd | |||||
| Write2: | |||||
| dup s17, v16.s[1] | |||||
| stp s16, s17, [x18] | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s19, v18.s[1] | |||||
| stp s18, s19, [x18] | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s21, v20.s[1] | |||||
| stp s20, s21, [x18] | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s23, v22.s[1] | |||||
| stp s22, s23, [x18] | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s25, v24.s[1] | |||||
| stp s24, s25, [x18] | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s27, v26.s[1] | |||||
| stp s26, s27, [x18] | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s29, v28.s[1] | |||||
| stp s28, s29, [x18] | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| add x18, x18, x17 | |||||
| dup s31, v30.s[1] | |||||
| stp s30, s31, [x18] | |||||
| add x18, x18, x17 | |||||
| b WriteEnd | |||||
| Write3: | |||||
| add x13, x18, #8 | |||||
| dup s17, v16.s[1] | |||||
| stp s16, s17, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v16.s}[2], [x13], x17 | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| dup s19, v18.s[1] | |||||
| stp s18, s19, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v18.s}[2], [x13], x17 | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| dup s21, v20.s[1] | |||||
| stp s20, s21, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v20.s}[2], [x13], x17 | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| dup s23, v22.s[1] | |||||
| stp s22, s23, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v22.s}[2], [x13], x17 | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| dup s25, v24.s[1] | |||||
| stp s24, s25, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v24.s}[2], [x13], x17 | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| dup s27, v26.s[1] | |||||
| stp s26, s27, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v26.s}[2], [x13], x17 | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| dup s29, v28.s[1] | |||||
| stp s28, s29, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v28.s}[2], [x13], x17 | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| dup s31, v30.s[1] | |||||
| stp s30, s31, [x18] | |||||
| add x18, x18, x17 | |||||
| st1 {v30.s}[2], [x13] | |||||
| b WriteEnd | |||||
| Write4: | |||||
| st1 {v16.4s}, [x18], x17 | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| st1 {v18.4s}, [x18], x17 | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| st1 {v20.4s}, [x18], x17 | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| st1 {v22.4s}, [x18], x17 | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| st1 {v24.4s}, [x18], x17 | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| st1 {v26.4s}, [x18], x17 | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| st1 {v28.4s}, [x18], x17 | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| st1 {v30.4s}, [x18], x17 | |||||
| b WriteEnd | |||||
| Write5: | |||||
| add x13, x18, #16 | |||||
| st1 {v16.4s}, [x18], x17 | |||||
| str s17, [x13] | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v18.4s}, [x18], x17 | |||||
| str s19, [x13] | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v20.4s}, [x18], x17 | |||||
| str s21, [x13] | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v22.4s}, [x18], x17 | |||||
| str s23, [x13] | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v24.4s}, [x18], x17 | |||||
| str s25, [x13] | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v26.4s}, [x18], x17 | |||||
| str s27, [x13] | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v28.4s}, [x18], x17 | |||||
| str s29, [x13] | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v30.4s}, [x18], x17 | |||||
| str s31, [x13] | |||||
| b WriteEnd | |||||
| Write6: | |||||
| add x13, x18, #16 | |||||
| st1 {v16.4s}, [x18], x17 | |||||
| dup s16, v17.s[1] | |||||
| stp s17, s16, [x13] | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v18.4s}, [x18], x17 | |||||
| dup s18, v19.s[1] | |||||
| stp s19, s18, [x13] | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v20.4s}, [x18], x17 | |||||
| dup s20, v21.s[1] | |||||
| stp s21, s20, [x13] | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v22.4s}, [x18], x17 | |||||
| dup s22, v23.s[1] | |||||
| stp s23, s22, [x13] | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v24.4s}, [x18], x17 | |||||
| dup s24, v25.s[1] | |||||
| stp s25, s24, [x13] | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v26.4s}, [x18], x17 | |||||
| dup s26, v27.s[1] | |||||
| stp s27, s26, [x13] | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v28.4s}, [x18], x17 | |||||
| dup s28, v29.s[1] | |||||
| stp s29, s28, [x13] | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| add x13, x13, x17 | |||||
| st1 {v30.4s}, [x18], x17 | |||||
| dup s30, v31.s[1] | |||||
| stp s31, s30, [x13] | |||||
| b WriteEnd | |||||
| Write7: | |||||
| add x13, x18, #16 | |||||
| add x16, x18, #24 | |||||
| st1 {v16.4s}, [x18], x17 | |||||
| dup s16, v17.s[1] | |||||
| stp s17, s16, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v17.s}[2], [x16], x17 | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| st1 {v18.4s}, [x18], x17 | |||||
| dup s18, v19.s[1] | |||||
| stp s19, s18, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v19.s}[2], [x16], x17 | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| st1 {v20.4s}, [x18], x17 | |||||
| dup s20, v21.s[1] | |||||
| stp s21, s20, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v21.s}[2], [x16], x17 | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| st1 {v22.4s}, [x18], x17 | |||||
| dup s22, v23.s[1] | |||||
| stp s23, s22, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v23.s}[2], [x16], x17 | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| st1 {v24.4s}, [x18], x17 | |||||
| dup s24, v25.s[1] | |||||
| stp s25, s24, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v25.s}[2], [x16], x17 | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| st1 {v26.4s}, [x18], x17 | |||||
| dup s26, v27.s[1] | |||||
| stp s27, s26, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v27.s}[2], [x16], x17 | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| st1 {v28.4s}, [x18], x17 | |||||
| dup s28, v29.s[1] | |||||
| stp s29, s28, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v29.s}[2], [x16], x17 | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| st1 {v30.4s}, [x18], x17 | |||||
| dup s30, v31.s[1] | |||||
| stp s31, s30, [x13] | |||||
| add x13, x13, x17 | |||||
| st1 {v31.s}[2], [x16], x17 | |||||
| b WriteEnd | |||||
| WriteC8: | |||||
| st1 {v16.4s}, [x2], #16 | st1 {v16.4s}, [x2], #16 | ||||
| st1 {v17.4s}, [x2], #16 | st1 {v17.4s}, [x2], #16 | ||||
| st1 {v18.4s}, [x2], #16 | st1 {v18.4s}, [x2], #16 | ||||
| @@ -283,19 +598,48 @@ TransToOut: | |||||
| st1 {v29.4s}, [x2], #16 | st1 {v29.4s}, [x2], #16 | ||||
| st1 {v30.4s}, [x2], #16 | st1 {v30.4s}, [x2], #16 | ||||
| st1 {v31.4s}, [x2], #16 | st1 {v31.4s}, [x2], #16 | ||||
| b WriteEnd | |||||
| Write8: | |||||
| st1 {v16.4s, v17.4s}, [x18], x17 | |||||
| cmp w10, #1 | |||||
| beq WriteEnd | |||||
| st1 {v18.4s, v19.4s}, [x18], x17 | |||||
| cmp w10, #2 | |||||
| beq WriteEnd | |||||
| st1 {v20.4s, v21.4s}, [x18], x17 | |||||
| cmp w10, #3 | |||||
| beq WriteEnd | |||||
| st1 {v22.4s, v23.4s}, [x18], x17 | |||||
| cmp w10, #4 | |||||
| beq WriteEnd | |||||
| st1 {v24.4s, v25.4s}, [x18], x17 | |||||
| cmp w10, #5 | |||||
| beq WriteEnd | |||||
| st1 {v26.4s, v27.4s}, [x18], x17 | |||||
| cmp w10, #6 | |||||
| beq WriteEnd | |||||
| st1 {v28.4s, v29.4s}, [x18], x17 | |||||
| cmp w10, #7 | |||||
| beq WriteEnd | |||||
| st1 {v30.4s, v31.4s}, [x18], x17 | |||||
| add w10, w10, #8 // lm row offset + 8 | |||||
| b L2 | |||||
| WriteEnd: | |||||
| subs w10, w10, #8 // lhs row - 8 | |||||
| bgt L2 | |||||
| End2: | End2: | ||||
| add w9, w9, #8 // rm col offset + 8 | |||||
| add x1, x1, x15 // rm ptr + stride | |||||
| add x3, x3, x18 // bias ptr + stride | |||||
| b L1 | |||||
| subs w7, w7, #8 // rhs col - 8 | |||||
| add x1, x1, x15 // rhs ptr + stride | |||||
| add x3, x3, #32 // bias ptr + stride | |||||
| ldrb w13, [sp, #8] | |||||
| cbz w13, NoDstStep | |||||
| add x2, x2, #32 // dst ptr + stride | |||||
| NoDstStep: | |||||
| bgt L1 | |||||
| End1: | End1: | ||||
| sub sp, sp, #128 | sub sp, sp, #128 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | ||||
| ret | ret | ||||
| #endif | |||||
| #endif | |||||
| @@ -221,34 +221,57 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col | |||||
| return; | return; | ||||
| } | } | ||||
| void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, | |||||
| int col_8_) { | |||||
| /* col8-major * row8-major => col8x8-major */ | |||||
| for (int row = 0; row < row_8_; row++) { | |||||
| for (int col = 0; col < col_8_; col++) { | |||||
| int r8div = row / 8, r8mod = row % 8; | |||||
| int c8div = col / 8, c8mod = col % 8; | |||||
| size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod; | |||||
| float value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r8div * deep * 8 + d * 8 + r8mod; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | |||||
| int col, int stride, bool write_nhwc) { | |||||
| if (write_nhwc) { | |||||
| /* col8-major * row8-major => col-major */ | |||||
| for (int r = 0; r < row; r++) { | |||||
| for (int c = 0; c < col; c++) { | |||||
| int r8div = r / 8, r8mod = r % 8; | |||||
| int c8div = c / 8, c8mod = c % 8; | |||||
| size_t ci = r * stride + c; | |||||
| float value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r8div * deep * 8 + d * 8 + r8mod; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| } | |||||
| if (bias != NULL) value += bias[c]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| dst[ci] = value; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| /* col8-major * row8-major => col8x8-major */ | |||||
| int col_8 = UP_ROUND(col, C8NUM); | |||||
| int row_8 = UP_ROUND(row, C8NUM); | |||||
| for (int r = 0; r < row_8; r++) { | |||||
| for (int c = 0; c < col_8; c++) { | |||||
| int r8div = r / 8, r8mod = r % 8; | |||||
| int c8div = c / 8, c8mod = c % 8; | |||||
| size_t ci = c8div * row_8 * 8 + r * 8 + c8mod; | |||||
| float value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r8div * deep * 8 + d * 8 + r8mod; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| } | |||||
| if (bias != NULL) value += bias[c]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| dst[ci] = value; | |||||
| } | } | ||||
| if (bias != NULL) value += bias[col]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| c[ci] = value; | |||||
| } | } | ||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, | |||||
| int col_8_) { | |||||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, | |||||
| int stride, bool write_nhwc) { | |||||
| #ifdef __aarch64__ | #ifdef __aarch64__ | ||||
| MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_); | |||||
| MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc); | |||||
| #else | #else | ||||
| MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_); | |||||
| MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc); | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -26,13 +26,14 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col); | |||||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col, | |||||
| int stride, bool write_nhwc); | |||||
| void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); | void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); | ||||
| #ifdef __aarch64__ | #ifdef __aarch64__ | ||||
| void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col); | |||||
| int col, size_t stride, bool write_nhwc); | |||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -370,26 +370,35 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) { | |||||
| conv1x1->Run(); | conv1x1->Run(); | ||||
| CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001); | CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001); | ||||
| /* running warm up */ | |||||
| for (int i = 0; i < 0; i++) { | |||||
| conv1x1->Run(); | |||||
| auto ptr = reinterpret_cast<float *>(outputs_[0]->Data()); | |||||
| bool first = true; | |||||
| for (int i = 0; i < total_size; i++) { | |||||
| if (fabs(ptr[i] - correct[i]) > 0.001 && first) { | |||||
| printf("%d %f %f\n", i, ptr[i], correct[i]); | |||||
| first = false; | |||||
| } | |||||
| } | } | ||||
| /* running time cost */ | |||||
| int loop_count = 1; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| conv1x1->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| uint64_t time_avg = cost / loop_count; | |||||
| printf("1x1 average time : %f ms\n", time_avg / 1000.0f); | |||||
| delete conv_param; | |||||
| delete conv1x1; | |||||
| for (auto t : inputs_) delete t; | |||||
| for (auto t : outputs_) delete t; | |||||
| free(correct); | |||||
| // /* running warm up */ | |||||
| // for (int i = 0; i < 0; i++) { | |||||
| // conv1x1->Run(); | |||||
| // } | |||||
| // | |||||
| // /* running time cost */ | |||||
| // int loop_count = 1; | |||||
| // auto time_start = mindspore::lite::GetTimeUs(); | |||||
| // for (int i = 0; i < loop_count; i++) { | |||||
| // conv1x1->Run(); | |||||
| // } | |||||
| // auto time_end = mindspore::lite::GetTimeUs(); | |||||
| // auto cost = time_end - time_start; | |||||
| // uint64_t time_avg = cost / loop_count; | |||||
| // printf("1x1 average time : %f ms\n", time_avg / 1000.0f); | |||||
| // | |||||
| // delete conv_param; | |||||
| // delete conv1x1; | |||||
| // for (auto t : inputs_) delete t; | |||||
| // for (auto t : outputs_) delete t; | |||||
| // free(correct); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||