* arm deconv matmul use gemm * reduce gemm armv7 register usestags/20230517
| @@ -42,6 +42,7 @@ protected: | |||
| public: | |||
| Layer* activation; | |||
| Layer* gemm; | |||
| Mat weight_data_tm; | |||
| @@ -41,26 +41,81 @@ int Deconvolution_arm::create_pipeline_fp16s(const Option& opt) | |||
| out_elempack = opt.use_fp16_arithmetic && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; | |||
| } | |||
| Mat weight_data_transposed(weight_data.w); | |||
| if (opt.use_fp16_arithmetic && opt.use_sgemm_convolution) | |||
| { | |||
| float* pt = weight_data_transposed; | |||
| const float* p = weight_data; | |||
| for (int i = 0; i < num_input * num_output; i++) | |||
| const int maxk = kernel_w * kernel_h; | |||
| gemm = ncnn::create_layer(ncnn::LayerType::Gemm); | |||
| ncnn::ParamDict pd; | |||
| pd.set(2, 1); // transA | |||
| pd.set(3, 0); // transB | |||
| pd.set(4, 1); // constantA | |||
| pd.set(5, 0); // constantB | |||
| pd.set(6, 1); // constantC | |||
| pd.set(7, maxk * num_output); // M = maxk*num_output | |||
| pd.set(8, 0); // N = size | |||
| pd.set(9, num_input); // K = inch | |||
| pd.set(10, -1); // constant_broadcast_type_C = null | |||
| pd.set(11, 0); // output_N1M | |||
| pd.set(12, out_elempack); | |||
| gemm->load_param(pd); | |||
| // maxk-inch-outch to pa-maxk-outch/pa-inch | |||
| Mat tmp; | |||
| { | |||
| for (int k = 0; k < maxk; k++) | |||
| Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); | |||
| tmp.create(maxk * num_output, num_input); | |||
| for (int p = 0; p < num_input; p += 1) | |||
| { | |||
| pt[maxk - 1 - k] = p[k]; | |||
| } | |||
| float* g00 = tmp.row(p); | |||
| p += maxk; | |||
| pt += maxk; | |||
| for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) | |||
| { | |||
| for (int k = 0; k < maxk; k++) | |||
| { | |||
| for (int i = 0; i < out_elempack; i++) | |||
| { | |||
| const float* k00 = weight_data_r2.channel(q + i).row(p); | |||
| g00[0] = k00[k]; | |||
| g00++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // src = kw-kh-inch-outch | |||
| // dst = pb-pa-kw-kh-inch/pa-outch/pb | |||
| ncnn::Mat weights[1]; | |||
| weights[0] = tmp; | |||
| gemm->load_model(ModelBinFromMatArray(weights)); | |||
| gemm->create_pipeline(opt); | |||
| } | |||
| else | |||
| { | |||
| Mat weight_data_transposed(weight_data.w); | |||
| { | |||
| float* pt = weight_data_transposed; | |||
| const float* p = weight_data; | |||
| for (int i = 0; i < num_input * num_output; i++) | |||
| { | |||
| for (int k = 0; k < maxk; k++) | |||
| { | |||
| pt[maxk - 1 - k] = p[k]; | |||
| } | |||
| p += maxk; | |||
| pt += maxk; | |||
| } | |||
| } | |||
| // src = kw-kh-inch-outch | |||
| // dst = pb-pa-kw-kh-inch/pa-outch/pb | |||
| Mat weight_data_r2 = weight_data_transposed.reshape(maxk, num_input, num_output); | |||
| weight_data_tm.create(maxk, num_input / elempack, num_output / out_elempack, (size_t)2u * elempack * out_elempack, elempack * out_elempack); | |||
| @@ -475,27 +530,173 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| size_t out_elemsize = elemsize / elempack * out_elempack; | |||
| int out_channels = num_output / out_elempack; | |||
| Mat top_blob_bordered; | |||
| if (pad_left > 0 || pad_right > 0 || pad_top > 0 || pad_bottom > 0 || (output_w > 0 && output_h > 0)) | |||
| { | |||
| top_blob_bordered.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.workspace_allocator); | |||
| top_blob_bordered.create(outw, outh, out_channels, out_elemsize, out_elempack, opt.workspace_allocator); | |||
| } | |||
| else | |||
| { | |||
| top_blob_bordered = top_blob; | |||
| top_blob_bordered.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); | |||
| top_blob_bordered.create(outw, outh, out_channels, out_elemsize, out_elempack, opt.blob_allocator); | |||
| } | |||
| if (top_blob_bordered.empty()) | |||
| return -100; | |||
| const int maxk = kernel_w * kernel_h; | |||
| if (elempack == 8 && out_elempack == 8) | |||
| if (opt.use_sgemm_convolution) | |||
| { | |||
| // sgemm | |||
| Mat bottom_blob_2 = bottom_blob; | |||
| { | |||
| bottom_blob_2.w = bottom_blob.w * bottom_blob.h; | |||
| bottom_blob_2.h = 1; | |||
| } | |||
| Mat top_col2im; | |||
| Option opt_b = opt; | |||
| opt_b.blob_allocator = top_blob_bordered.allocator; | |||
| gemm->forward(bottom_blob_2, top_col2im, opt_b); | |||
| { | |||
| // col2im | |||
| const int gap = (outw * stride_h - w * stride_w) * out_elempack; | |||
| if (out_elempack == 8) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| const __fp16* sptr = top_col2im.row<const __fp16>(p * maxk); | |||
| Mat outm = top_blob_bordered.channel(p); | |||
| if (bias_data.empty()) | |||
| { | |||
| outm.fill(vdupq_n_f16(0.f)); | |||
| } | |||
| else | |||
| { | |||
| outm.fill(vld1q_f16((const __fp16*)bias_data_fp16 + p * 8)); | |||
| } | |||
| for (int u = 0; u < kernel_h; u++) | |||
| { | |||
| for (int v = 0; v < kernel_w; v++) | |||
| { | |||
| __fp16* ptr = outm.row<__fp16>(dilation_h * u) + dilation_w * v * 8; | |||
| for (int i = 0; i < h; i++) | |||
| { | |||
| for (int j = 0; j < w; j++) | |||
| { | |||
| float16x8_t _val = vld1q_f16(ptr); | |||
| float16x8_t _s = vld1q_f16(sptr); | |||
| _val = vaddq_f16(_val, _s); | |||
| vst1q_f16(ptr, _val); | |||
| ptr += stride_w * 8; | |||
| sptr += 8; | |||
| } | |||
| ptr += gap; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (out_elempack == 4) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| const __fp16* sptr = top_col2im.row<const __fp16>(p * maxk); | |||
| Mat outm = top_blob_bordered.channel(p); | |||
| if (bias_data.empty()) | |||
| { | |||
| outm.fill(vdup_n_f16(0.f)); | |||
| } | |||
| else | |||
| { | |||
| outm.fill(vld1_f16((const __fp16*)bias_data_fp16 + p * 4)); | |||
| } | |||
| for (int u = 0; u < kernel_h; u++) | |||
| { | |||
| for (int v = 0; v < kernel_w; v++) | |||
| { | |||
| __fp16* ptr = outm.row<__fp16>(dilation_h * u) + dilation_w * v * 4; | |||
| for (int i = 0; i < h; i++) | |||
| { | |||
| for (int j = 0; j < w; j++) | |||
| { | |||
| float16x4_t _val = vld1_f16(ptr); | |||
| float16x4_t _s = vld1_f16(sptr); | |||
| _val = vadd_f16(_val, _s); | |||
| vst1_f16(ptr, _val); | |||
| ptr += stride_w * 4; | |||
| sptr += 4; | |||
| } | |||
| ptr += gap; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (out_elempack == 1) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| const __fp16* sptr = top_col2im.row<const __fp16>(p * maxk); | |||
| Mat outm = top_blob_bordered.channel(p); | |||
| const __fp16 bias = bias_data_fp16.empty() ? 0.f : ((const __fp16*)bias_data_fp16)[p]; | |||
| outm.fill(bias); | |||
| for (int u = 0; u < kernel_h; u++) | |||
| { | |||
| for (int v = 0; v < kernel_w; v++) | |||
| { | |||
| __fp16* ptr = outm.row<__fp16>(dilation_h * u) + dilation_w * v; | |||
| for (int i = 0; i < h; i++) | |||
| { | |||
| for (int j = 0; j < w; j++) | |||
| { | |||
| ptr[0] += sptr[0]; | |||
| ptr += stride_w; | |||
| sptr += 1; | |||
| } | |||
| ptr += gap; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (activation) | |||
| { | |||
| activation->forward_inplace(top_blob_bordered, opt); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (elempack == 8 && out_elempack == 8) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -575,14 +776,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 1 && out_elempack == 8) | |||
| { | |||
| if (elempack == 1 && out_elempack == 8) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -648,14 +847,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 4 && out_elempack == 8) | |||
| { | |||
| if (elempack == 4 && out_elempack == 8) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -727,14 +924,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 8 && out_elempack == 1) | |||
| { | |||
| if (elempack == 8 && out_elempack == 1) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -803,14 +998,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 8 && out_elempack == 4) | |||
| { | |||
| if (elempack == 8 && out_elempack == 4) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -890,14 +1083,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 4 && out_elempack == 4) | |||
| { | |||
| if (elempack == 4 && out_elempack == 4) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -969,14 +1160,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 1 && out_elempack == 4) | |||
| { | |||
| if (elempack == 1 && out_elempack == 4) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -1042,14 +1231,12 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 4 && out_elempack == 1) | |||
| { | |||
| if (elempack == 4 && out_elempack == 1) | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output / out_elempack; p++) | |||
| for (int p = 0; p < out_channels; p++) | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| @@ -1117,86 +1304,86 @@ int Deconvolution_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, con | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (elempack == 1 && out_elempack == 1) | |||
| { | |||
| if (kernel_w == 4 && kernel_h == 4 && stride_w == 2 && stride_h == 2 && dilation_w == 1 && dilation_h == 1) | |||
| if (elempack == 1 && out_elempack == 1) | |||
| { | |||
| deconv4x4s2_fp16sa_neon(bottom_blob, top_blob_bordered, weight_data_tm, bias_data_fp16, opt); | |||
| if (activation) | |||
| if (kernel_w == 4 && kernel_h == 4 && stride_w == 2 && stride_h == 2 && dilation_w == 1 && dilation_h == 1) | |||
| { | |||
| activation->forward_inplace(top_blob_bordered, opt); | |||
| deconv4x4s2_fp16sa_neon(bottom_blob, top_blob_bordered, weight_data_tm, bias_data_fp16, opt); | |||
| if (activation) | |||
| { | |||
| activation->forward_inplace(top_blob_bordered, opt); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output; p++) | |||
| else | |||
| { | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| for (int i = 0; i < outh; i++) | |||
| // num_output | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int p = 0; p < num_output; p++) | |||
| { | |||
| for (int j = 0; j < outw; j++) | |||
| { | |||
| float sum = 0.f; | |||
| if (bias_term) | |||
| { | |||
| sum = bias_data[p]; | |||
| } | |||
| __fp16* outptr = top_blob_bordered.channel(p); | |||
| const __fp16* kptr = weight_data_tm.channel(p); | |||
| // channels | |||
| for (int q = 0; q < channels; q++) | |||
| for (int i = 0; i < outh; i++) | |||
| { | |||
| for (int j = 0; j < outw; j++) | |||
| { | |||
| const Mat m = bottom_blob.channel(q); | |||
| float sum = 0.f; | |||
| for (int y = 0; y < kernel_h; y++) | |||
| if (bias_term) | |||
| { | |||
| int sys = (i + y * dilation_h - (kernel_extent_h - 1)); | |||
| if (sys < 0 || sys % stride_h != 0) | |||
| continue; | |||
| sum = bias_data[p]; | |||
| } | |||
| int sy = sys / stride_h; | |||
| if (sy >= h) | |||
| continue; | |||
| const __fp16* kptr = weight_data_tm.channel(p); | |||
| const __fp16* sptr = m.row<const __fp16>(sy); | |||
| // channels | |||
| for (int q = 0; q < channels; q++) | |||
| { | |||
| const Mat m = bottom_blob.channel(q); | |||
| for (int x = 0; x < kernel_w; x++) | |||
| for (int y = 0; y < kernel_h; y++) | |||
| { | |||
| int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); | |||
| if (sxs < 0 || sxs % stride_w != 0) | |||
| int sys = (i + y * dilation_h - (kernel_extent_h - 1)); | |||
| if (sys < 0 || sys % stride_h != 0) | |||
| continue; | |||
| int sx = sxs / stride_w; | |||
| if (sx >= w) | |||
| int sy = sys / stride_h; | |||
| if (sy >= h) | |||
| continue; | |||
| __fp16 val = sptr[sx]; | |||
| const __fp16* sptr = m.row<const __fp16>(sy); | |||
| int k = y * kernel_w + x; | |||
| for (int x = 0; x < kernel_w; x++) | |||
| { | |||
| int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); | |||
| if (sxs < 0 || sxs % stride_w != 0) | |||
| continue; | |||
| __fp16 w = kptr[k]; | |||
| int sx = sxs / stride_w; | |||
| if (sx >= w) | |||
| continue; | |||
| sum += val * w; | |||
| __fp16 val = sptr[sx]; | |||
| int k = y * kernel_w + x; | |||
| __fp16 w = kptr[k]; | |||
| sum += val * w; | |||
| } | |||
| } | |||
| kptr += maxk; | |||
| } | |||
| kptr += maxk; | |||
| } | |||
| sum = activation_ss(sum, activation_type, activation_params); | |||
| sum = activation_ss(sum, activation_type, activation_params); | |||
| outptr[j] = (__fp16)sum; | |||
| } | |||
| outptr[j] = (__fp16)sum; | |||
| outptr += outw; | |||
| } | |||
| outptr += outw; | |||
| } | |||
| } | |||
| } | |||
| @@ -409,6 +409,7 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| if (elempack == 4) | |||
| @@ -519,6 +520,7 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max | |||
| } | |||
| } | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| if (elempack == 4) | |||
| @@ -715,6 +717,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| if (elempack == 4) | |||
| @@ -758,6 +761,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int | |||
| } | |||
| } | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| if (elempack == 4) | |||
| @@ -2174,6 +2178,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| } | |||
| int jj = 0; | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -2290,7 +2295,6 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| int kk = 0; | |||
| for (; kk < max_kk; kk += 1) | |||
| { | |||
| #if __aarch64__ | |||
| float32x4_t _pA = vld1q_f32(pA); | |||
| float32x4_t _pB0 = vld1q_f32(pB); | |||
| float32x4_t _pB1 = vld1q_f32(pB + 4); | |||
| @@ -2311,77 +2315,6 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| pA += 4; | |||
| pB += 12; | |||
| #else // __aarch64__ | |||
| #if NCNN_GNU_INLINE_ASM | |||
| asm volatile( | |||
| "pld [%0, #128] \n" | |||
| "pld [%1, #384] \n" | |||
| "vld1.f32 {d6-d7}, [%0 :128]! \n" | |||
| "vldm %1!, {d0-d5} \n" | |||
| "vmla.f32 %q2, q3, d0[0] \n" | |||
| "vmla.f32 %q3, q3, d0[1] \n" | |||
| "vmla.f32 %q4, q3, d1[0] \n" | |||
| "vmla.f32 %q5, q3, d1[1] \n" | |||
| "vmla.f32 %q6, q3, d2[0] \n" | |||
| "vmla.f32 %q7, q3, d2[1] \n" | |||
| "vmla.f32 %q8, q3, d3[0] \n" | |||
| "vmla.f32 %q9, q3, d3[1] \n" | |||
| "vmla.f32 %q10, q3, d4[0] \n" | |||
| "vmla.f32 %q11, q3, d4[1] \n" | |||
| "vmla.f32 %q12, q3, d5[0] \n" | |||
| "vmla.f32 %q13, q3, d5[1] \n" | |||
| : "=r"(pA), | |||
| "=r"(pB), | |||
| "=w"(_sum0), | |||
| "=w"(_sum1), | |||
| "=w"(_sum2), | |||
| "=w"(_sum3), | |||
| "=w"(_sum4), | |||
| "=w"(_sum5), | |||
| "=w"(_sum6), | |||
| "=w"(_sum7), | |||
| "=w"(_sum8), | |||
| "=w"(_sum9), | |||
| "=w"(_suma), | |||
| "=w"(_sumb) | |||
| : "0"(pA), | |||
| "1"(pB), | |||
| "2"(_sum0), | |||
| "3"(_sum1), | |||
| "4"(_sum2), | |||
| "5"(_sum3), | |||
| "6"(_sum4), | |||
| "7"(_sum5), | |||
| "8"(_sum6), | |||
| "9"(_sum7), | |||
| "10"(_sum8), | |||
| "11"(_sum9), | |||
| "12"(_suma), | |||
| "13"(_sumb) | |||
| : "memory", "q0", "q1", "q2", "q3"); | |||
| #else // NCNN_GNU_INLINE_ASM | |||
| float32x4_t _pA = vld1q_f32(pA); | |||
| float32x4_t _pB0 = vld1q_f32(pB); | |||
| float32x4_t _pB1 = vld1q_f32(pB + 4); | |||
| float32x4_t _pB2 = vld1q_f32(pB + 8); | |||
| _sum0 = vmlaq_lane_f32(_sum0, _pA, vget_low_f32(_pB0), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _pA, vget_low_f32(_pB0), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _pA, vget_high_f32(_pB0), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _pA, vget_high_f32(_pB0), 1); | |||
| _sum4 = vmlaq_lane_f32(_sum4, _pA, vget_low_f32(_pB1), 0); | |||
| _sum5 = vmlaq_lane_f32(_sum5, _pA, vget_low_f32(_pB1), 1); | |||
| _sum6 = vmlaq_lane_f32(_sum6, _pA, vget_high_f32(_pB1), 0); | |||
| _sum7 = vmlaq_lane_f32(_sum7, _pA, vget_high_f32(_pB1), 1); | |||
| _sum8 = vmlaq_lane_f32(_sum8, _pA, vget_low_f32(_pB2), 0); | |||
| _sum9 = vmlaq_lane_f32(_sum9, _pA, vget_low_f32(_pB2), 1); | |||
| _suma = vmlaq_lane_f32(_suma, _pA, vget_high_f32(_pB2), 0); | |||
| _sumb = vmlaq_lane_f32(_sumb, _pA, vget_high_f32(_pB2), 1); | |||
| pA += 4; | |||
| pB += 12; | |||
| #endif // NCNN_GNU_INLINE_ASM | |||
| #endif // __aarch64__ | |||
| } | |||
| if (k_end) | |||
| @@ -2439,6 +2372,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| outptr += 48; | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -2905,6 +2839,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| float32x4_t _sum00; | |||
| @@ -2990,21 +2925,13 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| float32x4_t _pB2 = vld1q_f32(pB + 8); | |||
| float32x2_t _pA = vld1_f32(pA); | |||
| #if __aarch64__ | |||
| _sum00 = vfmaq_lane_f32(_sum00, _pB0, _pA, 0); | |||
| _sum01 = vfmaq_lane_f32(_sum01, _pB1, _pA, 0); | |||
| _sum02 = vfmaq_lane_f32(_sum02, _pB2, _pA, 0); | |||
| _sum10 = vfmaq_lane_f32(_sum10, _pB0, _pA, 1); | |||
| _sum11 = vfmaq_lane_f32(_sum11, _pB1, _pA, 1); | |||
| _sum12 = vfmaq_lane_f32(_sum12, _pB2, _pA, 1); | |||
| #else | |||
| _sum00 = vmlaq_lane_f32(_sum00, _pB0, _pA, 0); | |||
| _sum01 = vmlaq_lane_f32(_sum01, _pB1, _pA, 0); | |||
| _sum02 = vmlaq_lane_f32(_sum02, _pB2, _pA, 0); | |||
| _sum10 = vmlaq_lane_f32(_sum10, _pB0, _pA, 1); | |||
| _sum11 = vmlaq_lane_f32(_sum11, _pB1, _pA, 1); | |||
| _sum12 = vmlaq_lane_f32(_sum12, _pB2, _pA, 1); | |||
| #endif | |||
| pA += 2; | |||
| pB += 12; | |||
| @@ -3041,6 +2968,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| outptr += 24; | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| float32x4_t _sum00; | |||
| @@ -3415,6 +3343,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -3460,15 +3389,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| float32x4_t _pB2 = vld1q_f32(pB + 8); | |||
| float32x4_t _pA0 = vdupq_n_f32(pA[0]); | |||
| #if __aarch64__ | |||
| _sum0 = vfmaq_f32(_sum0, _pA0, _pB0); | |||
| _sum1 = vfmaq_f32(_sum1, _pA0, _pB1); | |||
| _sum2 = vfmaq_f32(_sum2, _pA0, _pB2); | |||
| #else | |||
| _sum0 = vmlaq_f32(_sum0, _pA0, _pB0); | |||
| _sum1 = vmlaq_f32(_sum1, _pA0, _pB1); | |||
| _sum2 = vmlaq_f32(_sum2, _pA0, _pB2); | |||
| #endif | |||
| pA += 1; | |||
| pB += 12; | |||
| @@ -3493,6 +3417,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons | |||
| outptr += 12; | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -273,6 +273,7 @@ static void pack_B_tile_fp32_to_bf16(const Mat& B, Mat& BT, int j, int max_jj, i | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| const float* p0 = (const float*)B + (j + jj) * B_hstep + k; | |||
| @@ -363,6 +364,7 @@ static void pack_B_tile_fp32_to_bf16(const Mat& B, Mat& BT, int j, int max_jj, i | |||
| pb++; | |||
| } | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| const float* p0 = (const float*)B + (j + jj) * B_hstep + k; | |||
| @@ -582,6 +584,7 @@ static void transpose_pack_B_tile_fp32_to_bf16(const Mat& B, Mat& BT, int j, int | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| const float* p0 = (const float*)B + k * B_hstep + (j + jj); | |||
| @@ -596,6 +599,7 @@ static void transpose_pack_B_tile_fp32_to_bf16(const Mat& B, Mat& BT, int j, int | |||
| p0 += B_hstep; | |||
| } | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| const float* p0 = (const float*)B + k * B_hstep + (j + jj); | |||
| @@ -1868,6 +1872,7 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| } | |||
| int jj = 0; | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -1984,7 +1989,6 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| int kk = 0; | |||
| for (; kk < max_kk; kk += 1) | |||
| { | |||
| #if __aarch64__ | |||
| float32x4_t _pA = bfloat2float(vld1_u16(pA)); | |||
| float32x4_t _pB0 = bfloat2float(vld1_u16(pB)); | |||
| float32x4_t _pB1 = bfloat2float(vld1_u16(pB + 4)); | |||
| @@ -2005,81 +2009,6 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| pA += 4; | |||
| pB += 12; | |||
| #else // __aarch64__ | |||
| #if NCNN_GNU_INLINE_ASM | |||
| asm volatile( | |||
| "pld [%0, #64] \n" | |||
| "pld [%1, #192] \n" | |||
| "vld1.u16 {d6}, [%0 :64]! \n" | |||
| "vld1.u16 {d2-d4}, [%1 :64]! \n" | |||
| "vshll.u16 q3, d6, #16 \n" | |||
| "vshll.u16 q0, d2, #16 \n" | |||
| "vshll.u16 q1, d3, #16 \n" | |||
| "vshll.u16 q2, d4, #16 \n" | |||
| "vmla.f32 %q2, q3, d0[0] \n" | |||
| "vmla.f32 %q3, q3, d0[1] \n" | |||
| "vmla.f32 %q4, q3, d1[0] \n" | |||
| "vmla.f32 %q5, q3, d1[1] \n" | |||
| "vmla.f32 %q6, q3, d2[0] \n" | |||
| "vmla.f32 %q7, q3, d2[1] \n" | |||
| "vmla.f32 %q8, q3, d3[0] \n" | |||
| "vmla.f32 %q9, q3, d3[1] \n" | |||
| "vmla.f32 %q10, q3, d4[0] \n" | |||
| "vmla.f32 %q11, q3, d4[1] \n" | |||
| "vmla.f32 %q12, q3, d5[0] \n" | |||
| "vmla.f32 %q13, q3, d5[1] \n" | |||
| : "=r"(pA), | |||
| "=r"(pB), | |||
| "=w"(_sum0), | |||
| "=w"(_sum1), | |||
| "=w"(_sum2), | |||
| "=w"(_sum3), | |||
| "=w"(_sum4), | |||
| "=w"(_sum5), | |||
| "=w"(_sum6), | |||
| "=w"(_sum7), | |||
| "=w"(_sum8), | |||
| "=w"(_sum9), | |||
| "=w"(_suma), | |||
| "=w"(_sumb) | |||
| : "0"(pA), | |||
| "1"(pB), | |||
| "2"(_sum0), | |||
| "3"(_sum1), | |||
| "4"(_sum2), | |||
| "5"(_sum3), | |||
| "6"(_sum4), | |||
| "7"(_sum5), | |||
| "8"(_sum6), | |||
| "9"(_sum7), | |||
| "10"(_sum8), | |||
| "11"(_sum9), | |||
| "12"(_suma), | |||
| "13"(_sumb) | |||
| : "memory", "q0", "q1", "q2", "q3"); | |||
| #else | |||
| float32x4_t _pA = bfloat2float(vld1_u16(pA)); | |||
| float32x4_t _pB0 = bfloat2float(vld1_u16(pB)); | |||
| float32x4_t _pB1 = bfloat2float(vld1_u16(pB + 4)); | |||
| float32x4_t _pB2 = bfloat2float(vld1_u16(pB + 8)); | |||
| _sum0 = vmlaq_lane_f32(_sum0, _pA, vget_low_f32(_pB0), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _pA, vget_low_f32(_pB0), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _pA, vget_high_f32(_pB0), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _pA, vget_high_f32(_pB0), 1); | |||
| _sum4 = vmlaq_lane_f32(_sum4, _pA, vget_low_f32(_pB1), 0); | |||
| _sum5 = vmlaq_lane_f32(_sum5, _pA, vget_low_f32(_pB1), 1); | |||
| _sum6 = vmlaq_lane_f32(_sum6, _pA, vget_high_f32(_pB1), 0); | |||
| _sum7 = vmlaq_lane_f32(_sum7, _pA, vget_high_f32(_pB1), 1); | |||
| _sum8 = vmlaq_lane_f32(_sum8, _pA, vget_low_f32(_pB2), 0); | |||
| _sum9 = vmlaq_lane_f32(_sum9, _pA, vget_low_f32(_pB2), 1); | |||
| _suma = vmlaq_lane_f32(_suma, _pA, vget_high_f32(_pB2), 0); | |||
| _sumb = vmlaq_lane_f32(_sumb, _pA, vget_high_f32(_pB2), 1); | |||
| pA += 4; | |||
| pB += 12; | |||
| #endif | |||
| #endif // __aarch64__ | |||
| } | |||
| if (alpha != 1.f) | |||
| @@ -2154,6 +2083,7 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| outptr += 48; | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -2656,6 +2586,7 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| float32x4_t _sum00; | |||
| @@ -2743,21 +2674,13 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| float32x4_t _pA0 = bfloat2float(vdup_n_u16(pA[0])); | |||
| float32x4_t _pA1 = bfloat2float(vdup_n_u16(pA[1])); | |||
| #if __aarch64__ | |||
| _sum00 = vfmaq_f32(_sum00, _pB0, _pA0); | |||
| _sum01 = vfmaq_f32(_sum01, _pB1, _pA0); | |||
| _sum02 = vfmaq_f32(_sum02, _pB2, _pA0); | |||
| _sum10 = vfmaq_f32(_sum10, _pB0, _pA1); | |||
| _sum11 = vfmaq_f32(_sum11, _pB1, _pA1); | |||
| _sum12 = vfmaq_f32(_sum12, _pB2, _pA1); | |||
| #else | |||
| _sum00 = vmlaq_f32(_sum00, _pB0, _pA0); | |||
| _sum01 = vmlaq_f32(_sum01, _pB1, _pA0); | |||
| _sum02 = vmlaq_f32(_sum02, _pB2, _pA0); | |||
| _sum10 = vmlaq_f32(_sum10, _pB0, _pA1); | |||
| _sum11 = vmlaq_f32(_sum11, _pB1, _pA1); | |||
| _sum12 = vmlaq_f32(_sum12, _pB2, _pA1); | |||
| #endif | |||
| pA += 2; | |||
| pB += 12; | |||
| @@ -2805,6 +2728,7 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| outptr += 24; | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| float32x4_t _sum00; | |||
| @@ -3249,6 +3173,7 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -3295,15 +3220,10 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| float32x4_t _pB2 = bfloat2float(vld1_u16(pB + 8)); | |||
| float32x4_t _pA0 = bfloat2float(vdup_n_u16(pA[0])); | |||
| #if __aarch64__ | |||
| _sum0 = vfmaq_f32(_sum0, _pA0, _pB0); | |||
| _sum1 = vfmaq_f32(_sum1, _pA0, _pB1); | |||
| _sum2 = vfmaq_f32(_sum2, _pA0, _pB2); | |||
| #else | |||
| _sum0 = vmlaq_f32(_sum0, _pA0, _pB0); | |||
| _sum1 = vmlaq_f32(_sum1, _pA0, _pB1); | |||
| _sum2 = vmlaq_f32(_sum2, _pA0, _pB2); | |||
| #endif | |||
| pA += 1; | |||
| pB += 12; | |||
| @@ -3336,6 +3256,7 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile | |||
| outptr += 12; | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| float32x4_t _sum0; | |||
| @@ -469,6 +469,7 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| if (elempack == 8) | |||
| @@ -607,6 +608,7 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int | |||
| } | |||
| } | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| if (elempack == 8) | |||
| @@ -917,6 +919,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma | |||
| int jj = 0; | |||
| #if __ARM_NEON | |||
| #if __aarch64__ | |||
| for (; jj + 11 < max_jj; jj += 12) | |||
| { | |||
| if (elempack == 8) | |||
| @@ -992,6 +995,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma | |||
| } | |||
| } | |||
| } | |||
| #endif // __aarch64__ | |||
| for (; jj + 7 < max_jj; jj += 8) | |||
| { | |||
| if (elempack == 8) | |||
| @@ -0,0 +1,246 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "matmul_arm.h" | |||
| #include "layer_type.h" | |||
| #include "cpu.h" | |||
| namespace ncnn { | |||
| MatMul_arm::MatMul_arm() | |||
| { | |||
| #if __ARM_NEON | |||
| #if NCNN_ARM82 | |||
| support_fp16_storage = cpu_support_arm_asimdhp(); | |||
| #endif | |||
| #endif // __ARM_NEON | |||
| #if NCNN_BF16 | |||
| support_bf16_storage = true; | |||
| #endif | |||
| gemm = 0; | |||
| } | |||
| int MatMul_arm::create_pipeline(const Option& opt) | |||
| { | |||
| gemm = ncnn::create_layer(ncnn::LayerType::Gemm); | |||
| ncnn::ParamDict pd; | |||
| pd.set(2, 0); // transA | |||
| pd.set(3, transB); // transB | |||
| pd.set(4, 0); // constantA | |||
| pd.set(5, 0); // constantB | |||
| pd.set(6, 1); // constantC | |||
| pd.set(7, 0); // M = outch | |||
| pd.set(8, 0); // N = size | |||
| pd.set(9, 0); // K = maxk*inch | |||
| pd.set(10, -1); // constant_broadcast_type_C = null | |||
| pd.set(11, 0); // output_N1M | |||
| pd.set(12, 1); // output_elempack | |||
| gemm->load_param(pd); | |||
| gemm->load_model(ModelBinFromMatArray(0)); | |||
| gemm->create_pipeline(opt); | |||
| return 0; | |||
| } | |||
| int MatMul_arm::destroy_pipeline(const Option& opt) | |||
| { | |||
| if (gemm) | |||
| { | |||
| gemm->destroy_pipeline(opt); | |||
| delete gemm; | |||
| gemm = 0; | |||
| } | |||
| return 0; | |||
| } | |||
| int MatMul_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const | |||
| { | |||
| const Mat& A = bottom_blobs[0]; | |||
| const Mat& B = bottom_blobs[1]; | |||
| Mat& top_blob = top_blobs[0]; | |||
| const int Adims = A.dims; | |||
| const int Bdims = B.dims; | |||
| const int max_ABdims = std::max(Adims, Bdims); | |||
| const size_t elemsize = A.elemsize; | |||
| if (Adims == 1 && Bdims == 1) | |||
| { | |||
| // dot product | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A.reshape(A.w, 1); | |||
| _bottom_blobs[1] = transB ? B.reshape(B.w, 1) : B.reshape(1, B.w); | |||
| gemm->forward(_bottom_blobs, top_blobs, opt); | |||
| top_blob = top_blob.reshape(1, opt.blob_allocator); | |||
| } | |||
| else if (Adims == 2 && Bdims == 2) | |||
| { | |||
| // matrix multiply | |||
| gemm->forward(bottom_blobs, top_blobs, opt); | |||
| } | |||
| else if (Adims == 1 && Bdims == 2) | |||
| { | |||
| // matrix multiply | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A.reshape(A.w, 1); | |||
| _bottom_blobs[1] = B; | |||
| gemm->forward(_bottom_blobs, top_blobs, opt); | |||
| top_blob = top_blob.reshape(top_blob.w, opt.blob_allocator); | |||
| } | |||
| else if (Adims == 2 && Bdims == 1) | |||
| { | |||
| // matrix multiply | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A; | |||
| _bottom_blobs[1] = transB ? B.reshape(B.w, 1) : B.reshape(1, B.w); | |||
| gemm->forward(_bottom_blobs, top_blobs, opt); | |||
| top_blob = top_blob.reshape(top_blob.h, opt.blob_allocator); | |||
| } | |||
| else if (Adims == 1 && Bdims > 2) | |||
| { | |||
| // batched matrix multiply | |||
| const int N = transB == 0 ? B.w : B.h; | |||
| const int batch_size = B.d * B.c; | |||
| Mat top_blob1(N, 1, batch_size, elemsize, opt.blob_allocator); | |||
| if (top_blob1.empty()) | |||
| return -100; | |||
| Mat A1 = A.reshape(A.w, 1); | |||
| Mat B1 = B.reshape(B.w, B.h, batch_size); | |||
| for (int p = 0; p < batch_size; p++) | |||
| { | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A1; | |||
| _bottom_blobs[1] = B1.channel(p); | |||
| std::vector<Mat> _top_blobs(1); | |||
| _top_blobs[0] = top_blob1.channel(p); | |||
| gemm->forward(_bottom_blobs, _top_blobs, opt); | |||
| } | |||
| if (Bdims == 3) | |||
| top_blob = top_blob1.reshape(N, B.d * B.c, opt.blob_allocator); | |||
| else | |||
| top_blob = top_blob1.reshape(N, B.d, B.c, opt.blob_allocator); | |||
| } | |||
| else if (Adims > 2 && Bdims == 1) | |||
| { | |||
| // batched matrix multiply | |||
| const int M = A.h; | |||
| const int batch_size = A.d * A.c; | |||
| Mat top_blob1(1, M, batch_size, elemsize, opt.blob_allocator); | |||
| if (top_blob1.empty()) | |||
| return -100; | |||
| Mat A1 = A.reshape(A.w, A.h, batch_size); | |||
| Mat BT = transB ? B.reshape(B.w, 1) : B.reshape(1, B.w); | |||
| for (int p = 0; p < batch_size; p++) | |||
| { | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A1.channel(p); | |||
| _bottom_blobs[1] = BT; | |||
| std::vector<Mat> _top_blobs(1); | |||
| _top_blobs[0] = top_blob1.channel(p); | |||
| gemm->forward(_bottom_blobs, _top_blobs, opt); | |||
| } | |||
| if (Adims == 3) | |||
| top_blob = top_blob1.reshape(M, A.d * A.c, opt.blob_allocator); | |||
| else | |||
| top_blob = top_blob1.reshape(M, A.d, A.c, opt.blob_allocator); | |||
| } | |||
| else if (max_ABdims == 3) | |||
| { | |||
| Mat A1 = Adims == 2 ? A.reshape(A.w, A.h, 1) : A; | |||
| Mat B1 = Bdims == 2 ? B.reshape(B.w, B.h, 1) : B; | |||
| const int M = A1.h; | |||
| const int N = transB == 0 ? B1.w : B1.h; | |||
| const int batch_size = std::max(A1.c, B1.c); | |||
| top_blob.create(N, M, batch_size, elemsize, opt.blob_allocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| for (int p = 0; p < batch_size; p++) | |||
| { | |||
| int Ap = A1.c == 1 ? 0 : p; | |||
| int Bp = B1.c == 1 ? 0 : p; | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A1.channel(Ap); | |||
| _bottom_blobs[1] = B1.channel(Bp); | |||
| std::vector<Mat> _top_blobs(1); | |||
| _top_blobs[0] = top_blob.channel(p); | |||
| gemm->forward(_bottom_blobs, _top_blobs, opt); | |||
| } | |||
| } | |||
| else if (max_ABdims == 4) | |||
| { | |||
| Mat A1 = Adims == 3 ? A.reshape(A.w, A.h, A.c, 1) : A; | |||
| Mat B1 = Bdims == 3 ? B.reshape(B.w, B.h, B.c, 1) : B; | |||
| const int M = A1.h; | |||
| const int N = transB == 0 ? B1.w : B1.h; | |||
| const int batch_size_d = std::max(A1.d, B1.d); | |||
| const int batch_size_c = std::max(A1.c, B1.c); | |||
| top_blob.create(N, M, batch_size_d, batch_size_c, elemsize, opt.blob_allocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| for (int p = 0; p < batch_size_c; p++) | |||
| { | |||
| int Ap = A1.c == 1 ? 0 : p; | |||
| int Bp = B1.c == 1 ? 0 : p; | |||
| for (int q = 0; q < batch_size_d; q++) | |||
| { | |||
| int Ad = A1.d == 1 ? 0 : q; | |||
| int Bd = B1.d == 1 ? 0 : q; | |||
| std::vector<Mat> _bottom_blobs(2); | |||
| _bottom_blobs[0] = A1.channel(Ap).depth(Ad); | |||
| _bottom_blobs[1] = B1.channel(Bp).depth(Bd); | |||
| std::vector<Mat> _top_blobs(1); | |||
| _top_blobs[0] = top_blob.channel(p).depth(q); | |||
| gemm->forward(_bottom_blobs, _top_blobs, opt); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| NCNN_LOGE("impossible matmul %d %d", Adims, Bdims); | |||
| return -1; | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace ncnn | |||
| @@ -0,0 +1,38 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #ifndef LAYER_MATMUL_ARM_H | |||
| #define LAYER_MATMUL_ARM_H | |||
| #include "matmul.h" | |||
| namespace ncnn { | |||
| class MatMul_arm : virtual public MatMul | |||
| { | |||
| public: | |||
| MatMul_arm(); | |||
| virtual int create_pipeline(const Option& opt); | |||
| virtual int destroy_pipeline(const Option& opt); | |||
| virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const; | |||
| public: | |||
| Layer* gemm; | |||
| }; | |||
| } // namespace ncnn | |||
| #endif // LAYER_MATMUL_ARM_H | |||