diff --git a/src/layer/arm/convolution_1x1.h b/src/layer/arm/convolution_1x1.h
index cafdf55b6..226326d2b 100644
--- a/src/layer/arm/convolution_1x1.h
+++ b/src/layer/arm/convolution_1x1.h
@@ -16,6 +16,1256 @@
 #include <arm_neon.h>
 #endif // __ARM_NEON
 
+#if !__aarch64__
+// TODO HACK drop them
+// Stand-ins, compiled only when __aarch64__ is not set, for the intrinsics of
+// the same names; each is built from vmla/vadd/vpadd NEON intrinsics.
+
+// Returns _s + _a * _b per lane, via vmlaq_f32.
+static inline float32x4_t vfmaq_f32(float32x4_t _s, float32x4_t _a, float32x4_t _b)
+{
+    return vmlaq_f32(_s, _a, _b);
+}
+
+// Returns _s + _a * _b[i] for a lane index i in 0..3.
+// vmlaq_lane_f32 reads its lane from a 64-bit half with a literal index, so the
+// runtime index is dispatched to the matching half (low: 0-1, high: 2-3).
+// Any other index returns _s unchanged, so no path falls off the end of the
+// function (which would be undefined behaviour).
+static inline float32x4_t vfmaq_laneq_f32(float32x4_t _s, float32x4_t _a, float32x4_t _b, int i)
+{
+    switch (i)
+    {
+    case 0: return vmlaq_lane_f32(_s, _a, vget_low_f32(_b), 0);
+    case 1: return vmlaq_lane_f32(_s, _a, vget_low_f32(_b), 1);
+    case 2: return vmlaq_lane_f32(_s, _a, vget_high_f32(_b), 0);
+    case 3: return vmlaq_lane_f32(_s, _a, vget_high_f32(_b), 1);
+    default: return _s;
+    }
+}
+
+// Returns the sum of the four lanes of _s.
+static inline float vaddvq_f32(float32x4_t _s)
+{
+    float32x2_t _ss = vadd_f32(vget_low_f32(_s), vget_high_f32(_s));
+    float32x2_t _ss2 = vpadd_f32(_ss, _ss);
+    return vget_lane_f32(_ss2, 0);
+}
+#endif
+
+static void conv1x1s1_sgemm_transform_kernel_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch)
+{
+    const float* kernel = _kernel;
+
+    // interleave
+#if __aarch64__
+    kernel_tm.create(4*8, inch/4 + inch%4, outch/8 + (outch%8)/4 + outch%4);
+#else
+    kernel_tm.create(4*4, inch/4 + inch%4, outch/4 + outch%4);
+#endif // __aarch64__
+
+    int p = 0;
+#if __aarch64__
+    for (; p+7= 16 && num_output >= 16)
+            use_sgemm1x1 = true;
+    }
+#endif // __aarch64__
+
     return 0;
 }
 
@@ -57,6 +69,12 @@ int Convolution_arm::load_model(const ModelBin& mb)
         conv3x3s1_winograd64_transform_kernel_neon5(weight_data, weight_3x3_winograd64_data, num_input, num_output);
     }
 
+    if (use_sgemm1x1)
+    {
+        int num_input = weight_data_size / num_output;
+        conv1x1s1_sgemm_transform_kernel_neon(weight_data, weight_1x1_sgemm_data, num_input, num_output);
+    }
+
     return 0;
 }
 
@@ -297,6 +315,11 @@ int Convolution_arm::forward(const Mat& bottom_blob, Mat& top_blob) const
 //             conv3x3s1_winograd64_neon4(bottom_blob_bordered, top_blob, weight_3x3_winograd64_data, bias_data);
             conv3x3s1_winograd64_neon5(bottom_blob_bordered, top_blob, weight_3x3_winograd64_data, bias_data);
         }
+        else if (use_sgemm1x1 && w <= 120 && h <= 120)
+        {
+            // TODO assume more proper condition
+            conv1x1s1_sgemm_neon(bottom_blob_bordered, top_blob, weight_1x1_sgemm_data, bias_data);
+        }
         else
             conv(bottom_blob_bordered, top_blob, weight_data, bias_data);
 
diff --git a/src/layer/arm/convolution_arm.h b/src/layer/arm/convolution_arm.h
index ede6b3fad..6a47fff51 100644
--- a/src/layer/arm/convolution_arm.h
+++ b/src/layer/arm/convolution_arm.h
@@ -33,7 +33,9 @@ public:
 
 public:
     bool use_winograd3x3;
+    bool use_sgemm1x1;
     Mat weight_3x3_winograd64_data;
+    Mat weight_1x1_sgemm_data;
 };
 
 } // namespace ncnn