| @@ -31,10 +31,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { | |||
| free(pack_input_); | |||
| pack_input_ = nullptr; | |||
| } | |||
| if (pack_output_ != nullptr) { | |||
| free(pack_output_); | |||
| pack_output_ = nullptr; | |||
| } | |||
| if (pre_trans_input_ && input_ptr_ != nullptr) { | |||
| free(input_ptr_); | |||
| input_ptr_ = nullptr; | |||
| @@ -112,13 +108,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)); | |||
| pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float))); | |||
| if (pack_output_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| @@ -157,7 +146,7 @@ int Convolution1x1CPUKernel::Init() { | |||
| } | |||
| int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | |||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_); | |||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| @@ -165,23 +154,12 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | |||
| auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id; | |||
| MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, | |||
| pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_, | |||
| matmul_param_->deep_, matmul_param_->row_8_, cur_oc); | |||
| output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_, | |||
| matmul_param_->row_, cur_oc, matmul_param_->col_, true); | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) { | |||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_; | |||
| float *dst = output_ptr_ + task_id * thread_stride_; | |||
| Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_); | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata); | |||
| auto error_code = conv1x1->DoConv1x1(task_id); | |||
| @@ -192,12 +170,6 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata); | |||
| conv1x1->DoConv1x1Post(task_id); | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1CPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| @@ -216,8 +188,6 @@ int Convolution1x1CPUKernel::Run() { | |||
| MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -46,7 +46,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||
| public: | |||
| int DoConv1x1(int task_id); | |||
| int DoConv1x1Post(int task_id); | |||
| private: | |||
| int InitConv1x1Param(); | |||
| @@ -61,7 +60,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||
| int thread_stride_ = 0; | |||
| float *weight_ptr_ = nullptr; | |||
| float *pack_input_ = nullptr; | |||
| float *pack_output_ = nullptr; | |||
| float *input_ptr_ = nullptr; | |||
| float *output_ptr_ = nullptr; | |||
| }; | |||
| @@ -152,7 +152,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { | |||
| MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | |||
| tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No, | |||
| matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_); | |||
| matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false); | |||
| return RET_OK; | |||
| } | |||
| @@ -104,7 +104,7 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) { | |||
| MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_, | |||
| c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_, | |||
| bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_, | |||
| cur_oc * 8); | |||
| cur_oc * 8, 0, false); | |||
| return RET_OK; | |||
| } | |||
| @@ -77,7 +77,7 @@ int MatmulCPUKernel::RunImpl(int task_id) { | |||
| } | |||
| auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; | |||
| auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; | |||
| MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8); | |||
| MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false); | |||
| return RET_OK; | |||
| } | |||
| @@ -640,7 +640,7 @@ IndirectGemmStart: | |||
| add x15, x15, x7 | |||
| str s30, [x15] | |||
| add x0, x0, #4 | |||
| b WriteEnd | |||
| b WriteEndHalf | |||
| Write2: | |||
| dup s17, v16.s[1] | |||
| stp s16, s17, [x15] | |||
| @@ -666,7 +666,7 @@ IndirectGemmStart: | |||
| dup s31, v30.s[1] | |||
| stp s30, s31, [x15] | |||
| add x0, x0, #8 | |||
| b WriteEnd | |||
| b WriteEndHalf | |||
| Write3: | |||
| add x17, x15, #8 | |||
| dup s17, v16.s[1] | |||
| @@ -27,7 +27,7 @@ | |||
| // accumulators 8x8 block | |||
| // | |||
| /////////////////////////////////////////////////////////////////////////////// | |||
| //OptLoopMul4 RM 1x8 block | |||
| //OptLoopMul4 RM 4x8 block | |||
| // /--------------------------------------------\ | |||
| // |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | | |||
| // |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]| | |||
| @@ -46,7 +46,8 @@ | |||
| // accumulators 8x8 block | |||
| ///////////////////////////////////////////////////////////////////////////////// | |||
| // | |||
| // void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col) | |||
| // void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth | |||
| // int row, int col, int stride, bool write_nhwc) | |||
| // x0: a | |||
| // x1: b | |||
| // x2: c | |||
| @@ -55,30 +56,30 @@ | |||
| // w5: depth | |||
| // w6: row | |||
| // w7: col | |||
| // w17: stride | |||
| // w13: writeC8 | |||
| MatmulFloatNeon64: | |||
| sub sp, sp, #128 | |||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||
| mov w9, #0 // rm col offset | |||
| mov w10, #0 // lm row offset | |||
| mov w18, #32 // sizeof(float)*8 | |||
| mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth | |||
| mov x11, x3 // bias flag | |||
| mov w18, #32 // sizeof(float) * 8 | |||
| mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth | |||
| mov x11, x3 // bias flag | |||
| mov x18, #4 | |||
| ldr x17, [sp] | |||
| mul x17, x17, x18 | |||
| L1: | |||
| cmp w9, w7 | |||
| beq End1 | |||
| mov w10, w6 // reload lhs row | |||
| mov x12, x0 // reload lhs ptr | |||
| mov x18, x2 // reload dst ptr | |||
| mov w10, #0 // reset lm row offset | |||
| mov x12, x0 // reload lm ptr | |||
| L2: | |||
| cmp w10, w6 | |||
| beq End2 | |||
| mov x16, x1 // reload rm ptr | |||
| mov w13, w5 // reload depth | |||
| mov x14, x3 // reload bias ptr | |||
| mov x16, x1 // reload rhs ptr | |||
| mov w13, w5 // reload depth | |||
| mov x14, x3 // reload bias ptr | |||
| dup v16.4s, wzr | |||
| dup v17.4s, wzr | |||
| dup v18.4s, wzr | |||
| @@ -96,10 +97,10 @@ L2: | |||
| dup v30.4s, wzr | |||
| dup v31.4s, wzr | |||
| OptLoopMul4: | |||
| cmp w13, #4 | |||
| blt CommLoopMul | |||
| OptLoopMul4: | |||
| ld1 {v0.4s, v1.4s}, [x12], #32 | |||
| ld1 {v8.4s, v9.4s}, [x16], #32 | |||
| fmla v16.4s, v8.4s, v0.s[0] | |||
| @@ -172,13 +173,14 @@ OptLoopMul4: | |||
| fmla v29.4s, v15.4s, v7.s[2] | |||
| fmla v30.4s, v14.4s, v7.s[3] | |||
| fmla v31.4s, v15.4s, v7.s[3] | |||
| subs w13, w13, #4 | |||
| b OptLoopMul4 | |||
| CommLoopMul: | |||
| cmp w13, #1 | |||
| blt Bias | |||
| sub w13, w13, #4 | |||
| cmp w13, #0 | |||
| ble Bias | |||
| cmp w13, #4 | |||
| bge OptLoopMul4 | |||
| CommLoopMul: | |||
| ld1 {v0.4s, v1.4s}, [x12], #32 | |||
| ld1 {v2.4s, v3.4s}, [x16], #32 | |||
| fmla v16.4s, v2.4s, v0.s[0] | |||
| @@ -197,8 +199,9 @@ CommLoopMul: | |||
| fmla v29.4s, v3.4s, v1.s[2] | |||
| fmla v30.4s, v2.4s, v1.s[3] | |||
| fmla v31.4s, v3.4s, v1.s[3] | |||
| subs w13, w13, #1 | |||
| b CommLoopMul | |||
| bgt CommLoopMul | |||
| Bias: | |||
| cbz x11, Activation | |||
| @@ -226,7 +229,8 @@ Activation: | |||
| beq Relu6 | |||
| cmp w4, #1 | |||
| beq Relu | |||
| b TransToOut | |||
| b Write | |||
| Relu6: | |||
| mov w8, #6 | |||
| dup v15.4s, w8 | |||
| @@ -247,6 +251,7 @@ Relu6: | |||
| fmin v29.4s, v29.4s, v15.4s | |||
| fmin v30.4s, v30.4s, v15.4s | |||
| fmin v31.4s, v31.4s, v15.4s | |||
| Relu: | |||
| dup v14.4s, wzr | |||
| fmax v16.4s, v16.4s, v14.4s | |||
| @@ -266,7 +271,317 @@ Relu: | |||
| fmax v30.4s, v30.4s, v14.4s | |||
| fmax v31.4s, v31.4s, v14.4s | |||
| TransToOut: | |||
| Write: | |||
| ldrb w13, [sp, #8] | |||
| cbz w13, WriteC8 | |||
| cmp w7, #1 | |||
| beq Write1 | |||
| cmp w7, #2 | |||
| beq Write2 | |||
| cmp w7, #3 | |||
| beq Write3 | |||
| cmp w7, #4 | |||
| beq Write4 | |||
| cmp w7, #5 | |||
| beq Write5 | |||
| cmp w7, #6 | |||
| beq Write6 | |||
| cmp w7, #7 | |||
| beq Write7 | |||
| b Write8 | |||
| Write1: | |||
| str s16, [x18] | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s18, [x18] | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s20, [x18] | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s22, [x18] | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s24, [x18] | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s26, [x18] | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s28, [x18] | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| str s30, [x18] | |||
| add x18, x18, x17 | |||
| b WriteEnd | |||
| Write2: | |||
| dup s17, v16.s[1] | |||
| stp s16, s17, [x18] | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s19, v18.s[1] | |||
| stp s18, s19, [x18] | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s21, v20.s[1] | |||
| stp s20, s21, [x18] | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s23, v22.s[1] | |||
| stp s22, s23, [x18] | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s25, v24.s[1] | |||
| stp s24, s25, [x18] | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s27, v26.s[1] | |||
| stp s26, s27, [x18] | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s29, v28.s[1] | |||
| stp s28, s29, [x18] | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| add x18, x18, x17 | |||
| dup s31, v30.s[1] | |||
| stp s30, s31, [x18] | |||
| add x18, x18, x17 | |||
| b WriteEnd | |||
| Write3: | |||
| add x13, x18, #8 | |||
| dup s17, v16.s[1] | |||
| stp s16, s17, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v16.s}[2], [x13], x17 | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| dup s19, v18.s[1] | |||
| stp s18, s19, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v18.s}[2], [x13], x17 | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| dup s21, v20.s[1] | |||
| stp s20, s21, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v20.s}[2], [x13], x17 | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| dup s23, v22.s[1] | |||
| stp s22, s23, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v22.s}[2], [x13], x17 | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| dup s25, v24.s[1] | |||
| stp s24, s25, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v24.s}[2], [x13], x17 | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| dup s27, v26.s[1] | |||
| stp s26, s27, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v26.s}[2], [x13], x17 | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| dup s29, v28.s[1] | |||
| stp s28, s29, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v28.s}[2], [x13], x17 | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| dup s31, v30.s[1] | |||
| stp s30, s31, [x18] | |||
| add x18, x18, x17 | |||
| st1 {v30.s}[2], [x13] | |||
| b WriteEnd | |||
| Write4: | |||
| st1 {v16.4s}, [x18], x17 | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| st1 {v18.4s}, [x18], x17 | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| st1 {v20.4s}, [x18], x17 | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| st1 {v22.4s}, [x18], x17 | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| st1 {v24.4s}, [x18], x17 | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| st1 {v26.4s}, [x18], x17 | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| st1 {v28.4s}, [x18], x17 | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| st1 {v30.4s}, [x18], x17 | |||
| b WriteEnd | |||
| Write5: | |||
| add x13, x18, #16 | |||
| st1 {v16.4s}, [x18], x17 | |||
| str s17, [x13] | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v18.4s}, [x18], x17 | |||
| str s19, [x13] | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v20.4s}, [x18], x17 | |||
| str s21, [x13] | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v22.4s}, [x18], x17 | |||
| str s23, [x13] | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v24.4s}, [x18], x17 | |||
| str s25, [x13] | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v26.4s}, [x18], x17 | |||
| str s27, [x13] | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v28.4s}, [x18], x17 | |||
| str s29, [x13] | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v30.4s}, [x18], x17 | |||
| str s31, [x13] | |||
| b WriteEnd | |||
| Write6: | |||
| add x13, x18, #16 | |||
| st1 {v16.4s}, [x18], x17 | |||
| dup s16, v17.s[1] | |||
| stp s17, s16, [x13] | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v18.4s}, [x18], x17 | |||
| dup s18, v19.s[1] | |||
| stp s19, s18, [x13] | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v20.4s}, [x18], x17 | |||
| dup s20, v21.s[1] | |||
| stp s21, s20, [x13] | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v22.4s}, [x18], x17 | |||
| dup s22, v23.s[1] | |||
| stp s23, s22, [x13] | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v24.4s}, [x18], x17 | |||
| dup s24, v25.s[1] | |||
| stp s25, s24, [x13] | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v26.4s}, [x18], x17 | |||
| dup s26, v27.s[1] | |||
| stp s27, s26, [x13] | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v28.4s}, [x18], x17 | |||
| dup s28, v29.s[1] | |||
| stp s29, s28, [x13] | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| add x13, x13, x17 | |||
| st1 {v30.4s}, [x18], x17 | |||
| dup s30, v31.s[1] | |||
| stp s31, s30, [x13] | |||
| b WriteEnd | |||
| Write7: | |||
| add x13, x18, #16 | |||
| add x16, x18, #24 | |||
| st1 {v16.4s}, [x18], x17 | |||
| dup s16, v17.s[1] | |||
| stp s17, s16, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v17.s}[2], [x16], x17 | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| st1 {v18.4s}, [x18], x17 | |||
| dup s18, v19.s[1] | |||
| stp s19, s18, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v19.s}[2], [x16], x17 | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| st1 {v20.4s}, [x18], x17 | |||
| dup s20, v21.s[1] | |||
| stp s21, s20, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v21.s}[2], [x16], x17 | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| st1 {v22.4s}, [x18], x17 | |||
| dup s22, v23.s[1] | |||
| stp s23, s22, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v23.s}[2], [x16], x17 | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| st1 {v24.4s}, [x18], x17 | |||
| dup s24, v25.s[1] | |||
| stp s25, s24, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v25.s}[2], [x16], x17 | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| st1 {v26.4s}, [x18], x17 | |||
| dup s26, v27.s[1] | |||
| stp s27, s26, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v27.s}[2], [x16], x17 | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| st1 {v28.4s}, [x18], x17 | |||
| dup s28, v29.s[1] | |||
| stp s29, s28, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v29.s}[2], [x16], x17 | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| st1 {v30.4s}, [x18], x17 | |||
| dup s30, v31.s[1] | |||
| stp s31, s30, [x13] | |||
| add x13, x13, x17 | |||
| st1 {v31.s}[2], [x16], x17 | |||
| b WriteEnd | |||
| WriteC8: | |||
| st1 {v16.4s}, [x2], #16 | |||
| st1 {v17.4s}, [x2], #16 | |||
| st1 {v18.4s}, [x2], #16 | |||
| @@ -283,19 +598,48 @@ TransToOut: | |||
| st1 {v29.4s}, [x2], #16 | |||
| st1 {v30.4s}, [x2], #16 | |||
| st1 {v31.4s}, [x2], #16 | |||
| b WriteEnd | |||
| Write8: | |||
| st1 {v16.4s, v17.4s}, [x18], x17 | |||
| cmp w10, #1 | |||
| beq WriteEnd | |||
| st1 {v18.4s, v19.4s}, [x18], x17 | |||
| cmp w10, #2 | |||
| beq WriteEnd | |||
| st1 {v20.4s, v21.4s}, [x18], x17 | |||
| cmp w10, #3 | |||
| beq WriteEnd | |||
| st1 {v22.4s, v23.4s}, [x18], x17 | |||
| cmp w10, #4 | |||
| beq WriteEnd | |||
| st1 {v24.4s, v25.4s}, [x18], x17 | |||
| cmp w10, #5 | |||
| beq WriteEnd | |||
| st1 {v26.4s, v27.4s}, [x18], x17 | |||
| cmp w10, #6 | |||
| beq WriteEnd | |||
| st1 {v28.4s, v29.4s}, [x18], x17 | |||
| cmp w10, #7 | |||
| beq WriteEnd | |||
| st1 {v30.4s, v31.4s}, [x18], x17 | |||
| add w10, w10, #8 // lm row offset + 8 | |||
| b L2 | |||
| WriteEnd: | |||
| subs w10, w10, #8 // lhs row - 8 | |||
| bgt L2 | |||
| End2: | |||
| add w9, w9, #8 // rm col offset + 8 | |||
| add x1, x1, x15 // rm ptr + stride | |||
| add x3, x3, x18 // bias ptr + stride | |||
| b L1 | |||
| subs w7, w7, #8 // rhs col - 8 | |||
| add x1, x1, x15 // rhs ptr + stride | |||
| add x3, x3, #32 // bias ptr + stride | |||
| ldrb w13, [sp, #8] | |||
| cbz w13, NoDstStep | |||
| add x2, x2, #32 // dst ptr + stride | |||
| NoDstStep: | |||
| bgt L1 | |||
| End1: | |||
| sub sp, sp, #128 | |||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||
| ret | |||
| #endif | |||
| #endif | |||
| @@ -221,34 +221,57 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col | |||
| return; | |||
| } | |||
| void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, | |||
| int col_8_) { | |||
| /* col8-major * row8-major => col8x8-major */ | |||
| for (int row = 0; row < row_8_; row++) { | |||
| for (int col = 0; col < col_8_; col++) { | |||
| int r8div = row / 8, r8mod = row % 8; | |||
| int c8div = col / 8, c8mod = col % 8; | |||
| size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod; | |||
| float value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r8div * deep * 8 + d * 8 + r8mod; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | |||
| int col, int stride, bool write_nhwc) { | |||
| if (write_nhwc) { | |||
| /* col8-major * row8-major => col-major */ | |||
| for (int r = 0; r < row; r++) { | |||
| for (int c = 0; c < col; c++) { | |||
| int r8div = r / 8, r8mod = r % 8; | |||
| int c8div = c / 8, c8mod = c % 8; | |||
| size_t ci = r * stride + c; | |||
| float value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r8div * deep * 8 + d * 8 + r8mod; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| if (bias != NULL) value += bias[c]; | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } else { | |||
| /* col8-major * row8-major => col8x8-major */ | |||
| int col_8 = UP_ROUND(col, C8NUM); | |||
| int row_8 = UP_ROUND(row, C8NUM); | |||
| for (int r = 0; r < row_8; r++) { | |||
| for (int c = 0; c < col_8; c++) { | |||
| int r8div = r / 8, r8mod = r % 8; | |||
| int c8div = c / 8, c8mod = c % 8; | |||
| size_t ci = c8div * row_8 * 8 + r * 8 + c8mod; | |||
| float value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r8div * deep * 8 + d * 8 + r8mod; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| if (bias != NULL) value += bias[c]; | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||
| dst[ci] = value; | |||
| } | |||
| if (bias != NULL) value += bias[col]; | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||
| c[ci] = value; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, | |||
| int col_8_) { | |||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, | |||
| int stride, bool write_nhwc) { | |||
| #ifdef __aarch64__ | |||
| MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_); | |||
| MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc); | |||
| #else | |||
| MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_); | |||
| MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc); | |||
| #endif | |||
| } | |||
| @@ -26,13 +26,14 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col); | |||
| void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col, | |||
| int stride, bool write_nhwc); | |||
| void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||
| void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); | |||
| #ifdef __aarch64__ | |||
| void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col); | |||
| int col, size_t stride, bool write_nhwc); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -370,26 +370,35 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) { | |||
| conv1x1->Run(); | |||
| CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001); | |||
| /* running warm up */ | |||
| for (int i = 0; i < 0; i++) { | |||
| conv1x1->Run(); | |||
| auto ptr = reinterpret_cast<float *>(outputs_[0]->Data()); | |||
| bool first = true; | |||
| for (int i = 0; i < total_size; i++) { | |||
| if (fabs(ptr[i] - correct[i]) > 0.001 && first) { | |||
| printf("%d %f %f\n", i, ptr[i], correct[i]); | |||
| first = false; | |||
| } | |||
| } | |||
| /* running time cost */ | |||
| int loop_count = 1; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| conv1x1->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| uint64_t time_avg = cost / loop_count; | |||
| printf("1x1 average time : %f ms\n", time_avg / 1000.0f); | |||
| delete conv_param; | |||
| delete conv1x1; | |||
| for (auto t : inputs_) delete t; | |||
| for (auto t : outputs_) delete t; | |||
| free(correct); | |||
| // /* running warm up */ | |||
| // for (int i = 0; i < 0; i++) { | |||
| // conv1x1->Run(); | |||
| // } | |||
| // | |||
| // /* running time cost */ | |||
| // int loop_count = 1; | |||
| // auto time_start = mindspore::lite::GetTimeUs(); | |||
| // for (int i = 0; i < loop_count; i++) { | |||
| // conv1x1->Run(); | |||
| // } | |||
| // auto time_end = mindspore::lite::GetTimeUs(); | |||
| // auto cost = time_end - time_start; | |||
| // uint64_t time_avg = cost / loop_count; | |||
| // printf("1x1 average time : %f ms\n", time_avg / 1000.0f); | |||
| // | |||
| // delete conv_param; | |||
| // delete conv1x1; | |||
| // for (auto t : inputs_) delete t; | |||
| // for (auto t : outputs_) delete t; | |||
| // free(correct); | |||
| } | |||
| } // namespace mindspore | |||