You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convolutiondepthwise_arm.cpp 40 kB

7 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "convolutiondepthwise_arm.h"
  15. #include "layer_type.h"
  16. #if __ARM_NEON
  17. #include <arm_neon.h>
  18. #include "neon_mathfun.h"
  19. #endif // __ARM_NEON
  20. namespace ncnn {
  21. #include "convolutiondepthwise_3x3.h"
  22. #include "convolutiondepthwise_5x5.h"
  23. #include "convolutiondepthwise_3x3_int8.h"
  24. DEFINE_LAYER_CREATOR(ConvolutionDepthWise_arm)
  25. ConvolutionDepthWise_arm::ConvolutionDepthWise_arm()
  26. {
  27. activation = 0;
  28. }
  29. int ConvolutionDepthWise_arm::create_pipeline(const Option& opt)
  30. {
  31. Option opt_cpu = opt;
  32. opt_cpu.use_vulkan_compute = false;
  33. if (activation_type == 1)
  34. {
  35. activation = ncnn::create_layer(ncnn::LayerType::ReLU);
  36. ncnn::ParamDict pd;
  37. activation->load_param(pd);
  38. }
  39. else if (activation_type == 2)
  40. {
  41. activation = ncnn::create_layer(ncnn::LayerType::ReLU);
  42. ncnn::ParamDict pd;
  43. pd.set(0, activation_params[0]);// slope
  44. activation->load_param(pd);
  45. }
  46. else if (activation_type == 3)
  47. {
  48. activation = ncnn::create_layer(ncnn::LayerType::Clip);
  49. ncnn::ParamDict pd;
  50. pd.set(0, activation_params[0]);// min
  51. pd.set(1, activation_params[1]);// max
  52. activation->load_param(pd);
  53. }
  54. else if (activation_type == 4)
  55. {
  56. activation = ncnn::create_layer(ncnn::LayerType::Sigmoid);
  57. ncnn::ParamDict pd;
  58. activation->load_param(pd);
  59. }
  60. if (activation)
  61. {
  62. activation->create_pipeline(opt_cpu);
  63. }
  64. // create Convolution op for each group
  65. const int maxk = kernel_w * kernel_h;
  66. int channels = (weight_data_size / group) / maxk / (num_output / group) * group;
  67. if (opt.use_packing_layout)
  68. {
  69. // depth-wise
  70. if (channels == group && group == num_output)
  71. {
  72. // pack4
  73. if (num_output % 4 == 0)
  74. {
  75. Mat weight_data_r2 = weight_data.reshape(maxk, group);
  76. convert_packing(weight_data_r2, weight_data_pack4, 4);
  77. }
  78. }
  79. // group convolution
  80. const int channels_g = channels / group;
  81. const int num_output_g = num_output / group;
  82. // pack4
  83. if (channels_g % 4 == 0 && num_output_g % 4 == 0)
  84. {
  85. // src = kw-kh-inch-outch
  86. // dst = 4a-4b-kw-kh-inch/4a-outch/4b
  87. {
  88. Mat weight_data_r2_groups = weight_data.reshape(maxk, channels_g, num_output_g * group);
  89. weight_data_pack4_groups.create(maxk, channels_g/4, num_output_g/4 * group, (size_t)4*16, 16);
  90. for (int g=0; g<group; g++)
  91. {
  92. const Mat weight_data_r2 = weight_data_r2_groups.channel_range(num_output_g * g, num_output_g);
  93. Mat weight_data_pack4_g = weight_data_pack4_groups.channel_range(num_output_g/4 * g, num_output_g/4);
  94. for (int q=0; q+3<num_output_g; q+=4)
  95. {
  96. const Mat k0 = weight_data_r2.channel(q);
  97. const Mat k1 = weight_data_r2.channel(q+1);
  98. const Mat k2 = weight_data_r2.channel(q+2);
  99. const Mat k3 = weight_data_r2.channel(q+3);
  100. Mat g0 = weight_data_pack4_g.channel(q/4);
  101. for (int p=0; p+3<channels_g; p+=4)
  102. {
  103. const float* k00 = k0.row(p);
  104. const float* k01 = k0.row(p+1);
  105. const float* k02 = k0.row(p+2);
  106. const float* k03 = k0.row(p+3);
  107. const float* k10 = k1.row(p);
  108. const float* k11 = k1.row(p+1);
  109. const float* k12 = k1.row(p+2);
  110. const float* k13 = k1.row(p+3);
  111. const float* k20 = k2.row(p);
  112. const float* k21 = k2.row(p+1);
  113. const float* k22 = k2.row(p+2);
  114. const float* k23 = k2.row(p+3);
  115. const float* k30 = k3.row(p);
  116. const float* k31 = k3.row(p+1);
  117. const float* k32 = k3.row(p+2);
  118. const float* k33 = k3.row(p+3);
  119. float* g00 = g0.row(p/4);
  120. for (int k=0; k<maxk; k++)
  121. {
  122. g00[0] = k00[k];
  123. g00[1] = k10[k];
  124. g00[2] = k20[k];
  125. g00[3] = k30[k];
  126. g00[4] = k01[k];
  127. g00[5] = k11[k];
  128. g00[6] = k21[k];
  129. g00[7] = k31[k];
  130. g00[8] = k02[k];
  131. g00[9] = k12[k];
  132. g00[10] = k22[k];
  133. g00[11] = k32[k];
  134. g00[12] = k03[k];
  135. g00[13] = k13[k];
  136. g00[14] = k23[k];
  137. g00[15] = k33[k];
  138. g00 += 16;
  139. }
  140. }
  141. }
  142. }
  143. }
  144. }
  145. // pack1to4
  146. if (channels_g % 4 != 0 && num_output_g % 4 == 0)
  147. {
  148. // src = kw-kh-inch-outch
  149. // dst = 4b-kw-kh-inch-outch/4b
  150. {
  151. Mat weight_data_r2_groups = weight_data.reshape(maxk, channels_g, num_output_g * group);
  152. weight_data_pack1to4_groups.create(maxk, channels_g, num_output_g/4 * group, (size_t)4*4, 4);
  153. for (int g=0; g<group; g++)
  154. {
  155. const Mat weight_data_r2 = weight_data_r2_groups.channel_range(num_output_g * g, num_output_g);
  156. Mat weight_data_pack1to4_g = weight_data_pack1to4_groups.channel_range(num_output_g/4 * g, num_output_g/4);
  157. for (int q=0; q+3<num_output_g; q+=4)
  158. {
  159. const Mat k0 = weight_data_r2.channel(q);
  160. const Mat k1 = weight_data_r2.channel(q+1);
  161. const Mat k2 = weight_data_r2.channel(q+2);
  162. const Mat k3 = weight_data_r2.channel(q+3);
  163. Mat g0 = weight_data_pack1to4_g.channel(q/4);
  164. for (int p=0; p<channels_g; p++)
  165. {
  166. const float* k00 = k0.row(p);
  167. const float* k10 = k1.row(p);
  168. const float* k20 = k2.row(p);
  169. const float* k30 = k3.row(p);
  170. float* g00 = g0.row(p);
  171. for (int k=0; k<maxk; k++)
  172. {
  173. g00[0] = k00[k];
  174. g00[1] = k10[k];
  175. g00[2] = k20[k];
  176. g00[3] = k30[k];
  177. g00 += 4;
  178. }
  179. }
  180. }
  181. }
  182. }
  183. }
  184. // pack4to1
  185. if (channels_g % 4 == 0 && num_output_g % 4 != 0)
  186. {
  187. // src = kw-kh-inch-outch
  188. // dst = 4a-kw-kh-inch/4a-outch
  189. {
  190. Mat weight_data_r2_groups = weight_data.reshape(maxk, channels_g, num_output_g * group);
  191. weight_data_pack4to1_groups.create(maxk, channels_g/4, num_output_g * group, (size_t)4*4, 4);
  192. for (int g=0; g<group; g++)
  193. {
  194. const Mat weight_data_r2 = weight_data_r2_groups.channel_range(num_output_g * g, num_output_g);
  195. Mat weight_data_pack4to1_g = weight_data_pack4to1_groups.channel_range(num_output_g * g, num_output_g);
  196. for (int q=0; q<num_output_g; q++)
  197. {
  198. const Mat k0 = weight_data_r2.channel(q);
  199. Mat g0 = weight_data_pack4to1_g.channel(q);
  200. for (int p=0; p+3<channels_g; p+=4)
  201. {
  202. const float* k00 = k0.row(p);
  203. const float* k01 = k0.row(p+1);
  204. const float* k02 = k0.row(p+2);
  205. const float* k03 = k0.row(p+3);
  206. float* g00 = g0.row(p/4);
  207. for (int k=0; k<maxk; k++)
  208. {
  209. g00[0] = k00[k];
  210. g00[1] = k01[k];
  211. g00[2] = k02[k];
  212. g00[3] = k03[k];
  213. g00 += 4;
  214. }
  215. }
  216. }
  217. }
  218. }
  219. }
  220. } // opt.use_packing_layout
  221. for (int i=0; i<(int)group_ops.size(); i++)
  222. delete group_ops[i];
  223. group_ops.clear();
  224. if (channels == group && group == num_output)
  225. {
  226. // depth-wise specific
  227. if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1)
  228. {
  229. if ((stride_w == 1 && stride_h == 1) || (stride_w == 2 && stride_h == 2))
  230. {
  231. return 0;
  232. }
  233. }
  234. if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && use_int8_inference == false)
  235. {
  236. if ((stride_w == 1 && stride_h == 1) || (stride_w == 2 && stride_h == 2))
  237. {
  238. return 0;
  239. }
  240. }
  241. }
  242. const int channels_g = channels / group;
  243. const int num_output_g = num_output / group;
  244. group_ops.resize(group);
  245. for (int g=0; g<group; g++)
  246. {
  247. Mat weight_data_g = weight_data.range(maxk * channels_g * num_output_g * g, maxk * channels_g * num_output_g);
  248. Mat bias_data_g;
  249. if (bias_term)
  250. bias_data_g = bias_data.range(num_output_g * g, num_output_g);
  251. ncnn::Layer* op = ncnn::create_layer(ncnn::LayerType::Convolution);
  252. // set param
  253. ncnn::ParamDict pd;
  254. pd.set(0, num_output_g);// num_output
  255. pd.set(1, kernel_w);
  256. pd.set(11, kernel_h);
  257. pd.set(2, dilation_w);
  258. pd.set(12, dilation_h);
  259. pd.set(3, stride_w);
  260. pd.set(13, stride_h);
  261. pd.set(4, 0);// pad_w
  262. pd.set(14, 0);// pad_h
  263. pd.set(5, bias_term);
  264. pd.set(6, maxk * channels_g * num_output_g);// weight_data_size
  265. pd.set(8, int8_scale_term);
  266. op->load_param(pd);
  267. // set weights
  268. if (bias_term)
  269. {
  270. ncnn::Mat weights[4];
  271. weights[0] = weight_data_g;
  272. weights[1] = bias_data_g;
  273. if (int8_scale_term)
  274. {
  275. weights[2] = weight_data_int8_scales.range(g, 1);
  276. weights[3] = bottom_blob_int8_scales.range(g, 1);
  277. }
  278. op->load_model(ModelBinFromMatArray(weights));
  279. }
  280. else
  281. {
  282. ncnn::Mat weights[3];
  283. weights[0] = weight_data_g;
  284. if (int8_scale_term)
  285. {
  286. weights[1] = weight_data_int8_scales.range(g, 1);
  287. weights[2] = bottom_blob_int8_scales.range(g, 1);
  288. }
  289. op->load_model(ModelBinFromMatArray(weights));
  290. }
  291. op->create_pipeline(opt_cpu);
  292. group_ops[g] = op;
  293. }
  294. return 0;
  295. }
  296. int ConvolutionDepthWise_arm::destroy_pipeline(const Option& opt)
  297. {
  298. Option opt_cpu = opt;
  299. opt_cpu.use_vulkan_compute = false;
  300. if (activation)
  301. {
  302. activation->destroy_pipeline(opt_cpu);
  303. delete activation;
  304. activation = 0;
  305. }
  306. for (int i=0; i<(int)group_ops.size(); i++)
  307. {
  308. group_ops[i]->destroy_pipeline(opt_cpu);
  309. delete group_ops[i];
  310. }
  311. group_ops.clear();
  312. return 0;
  313. }
  314. int ConvolutionDepthWise_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
  315. {
  316. // convolv with NxN kernel
  317. // value = value + bias
  318. int w = bottom_blob.w;
  319. int h = bottom_blob.h;
  320. int channels = bottom_blob.c;
  321. size_t elemsize = bottom_blob.elemsize;
  322. int packing = bottom_blob.packing;
  323. const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
  324. const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
  325. Mat bottom_blob_unbordered = bottom_blob;
  326. if (use_int8_inference && elemsize != 1)
  327. {
  328. Mat bottom_blob_int8;
  329. bottom_blob_int8.create(w, h, channels, (size_t)1u, opt.workspace_allocator);
  330. if (bottom_blob_int8.empty())
  331. return -100;
  332. const int channels_g = channels / group;
  333. // quantize, scale and round to nearest
  334. #pragma omp parallel for num_threads(opt.num_threads)
  335. for (int g=0; g<group; g++)
  336. {
  337. ncnn::Option opt_g = opt;
  338. opt_g.num_threads = 1;
  339. opt_g.blob_allocator = bottom_blob_int8.allocator;
  340. const Mat bottom_blob_g = bottom_blob.channel_range(channels_g * g, channels_g);
  341. Mat bottom_blob_int8_g = bottom_blob_int8.channel_range(channels_g * g, channels_g);
  342. quantize_ops[g]->forward(bottom_blob_g, bottom_blob_int8_g, opt_g);
  343. }
  344. bottom_blob_unbordered = bottom_blob_int8;
  345. }
  346. Mat bottom_blob_bordered = bottom_blob_unbordered;
  347. if (pad_w > 0 || pad_h > 0)
  348. {
  349. copy_make_border(bottom_blob_unbordered, bottom_blob_bordered, pad_h, pad_h, pad_w, pad_w, BORDER_CONSTANT, 0.f, opt.workspace_allocator, opt.num_threads);
  350. if (bottom_blob_bordered.empty())
  351. return -100;
  352. w = bottom_blob_bordered.w;
  353. h = bottom_blob_bordered.h;
  354. }
  355. else if (pad_w == -233 && pad_h == -233)
  356. {
  357. int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w;
  358. int hpad = kernel_extent_h + (h - 1) / stride_h * stride_h - h;
  359. if (wpad > 0 || hpad > 0)
  360. {
  361. copy_make_border(bottom_blob_unbordered, bottom_blob_bordered, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, BORDER_CONSTANT, 0.f, opt.workspace_allocator, opt.num_threads);
  362. if (bottom_blob_bordered.empty())
  363. return -100;
  364. }
  365. w = bottom_blob_bordered.w;
  366. h = bottom_blob_bordered.h;
  367. }
  368. int outw = (w - kernel_extent_w) / stride_w + 1;
  369. int outh = (h - kernel_extent_h) / stride_h + 1;
  370. int out_packing = num_output % 4 == 0 ? 4 : 1;
  371. size_t out_elemsize = elemsize / packing * out_packing;
  372. if (opt.use_packing_layout)
  373. {
  374. const int maxk = kernel_w * kernel_h;
  375. // kernel offsets
  376. std::vector<int> _space_ofs(maxk);
  377. int* space_ofs = &_space_ofs[0];
  378. {
  379. int p1 = 0;
  380. int p2 = 0;
  381. int gap = w * dilation_h - kernel_w * dilation_w;
  382. for (int i = 0; i < kernel_h; i++)
  383. {
  384. for (int j = 0; j < kernel_w; j++)
  385. {
  386. space_ofs[p1] = p2;
  387. p1++;
  388. p2 += dilation_w;
  389. }
  390. p2 += gap;
  391. }
  392. }
  393. top_blob.create(outw, outh, num_output / out_packing, out_elemsize, out_packing, opt.blob_allocator);
  394. if (top_blob.empty())
  395. return -100;
  396. // depth-wise
  397. if (channels == group / packing && group / packing == num_output / packing)
  398. {
  399. if (packing == 4)
  400. {
  401. #pragma omp parallel for num_threads(opt.num_threads)
  402. for (int g=0; g<group / packing; g++)
  403. {
  404. float* outptr = top_blob.channel(g);
  405. const float* kptr = (const float*)weight_data_pack4 + maxk * g * 4;
  406. const Mat m = bottom_blob_bordered.channel(g);
  407. for (int i = 0; i < outh; i++)
  408. {
  409. for (int j = 0; j < outw; j++)
  410. {
  411. float32x4_t _sum = vdupq_n_f32(0.f);
  412. if (bias_term)
  413. {
  414. _sum = vld1q_f32(((const float*)bias_data) + g * 4);
  415. }
  416. const float* sptr = m.row(i*stride_h) + j*stride_w * 4;
  417. for (int k = 0; k < maxk; k++)
  418. {
  419. float32x4_t _val = vld1q_f32( sptr + space_ofs[k] * 4 );
  420. float32x4_t _w = vld1q_f32( kptr + k * 4 );
  421. _sum = vmlaq_f32(_sum, _val, _w);
  422. }
  423. if (activation_type == 1)
  424. {
  425. float32x4_t _zero = vdupq_n_f32(0.f);
  426. _sum = vmaxq_f32(_sum, _zero);
  427. }
  428. else if (activation_type == 2)
  429. {
  430. float32x4_t _zero = vdupq_n_f32(0.f);
  431. float32x4_t _slope = vdupq_n_f32(activation_params[0]);
  432. uint32x4_t _lemask = vcleq_f32(_sum, _zero);
  433. float32x4_t _ps = vmulq_f32(_sum, _slope);
  434. _sum = vbslq_f32(_lemask, _ps, _sum);
  435. }
  436. else if (activation_type == 3)
  437. {
  438. float32x4_t _min = vdupq_n_f32(activation_params[0]);
  439. float32x4_t _max = vdupq_n_f32(activation_params[1]);
  440. _sum = vmaxq_f32(_sum, _min);
  441. _sum = vminq_f32(_sum, _max);
  442. }
  443. else if (activation_type == 4)
  444. {
  445. float32x4_t _one = vdupq_n_f32(1.f);
  446. _sum = vnegq_f32(_sum);
  447. _sum = exp_ps(_sum);
  448. _sum = vaddq_f32(_sum, _one);
  449. float32x4_t _outp = vrecpeq_f32(_sum);
  450. _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
  451. // _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
  452. _sum = _outp;
  453. }
  454. vst1q_f32(outptr + j * 4, _sum);
  455. }
  456. outptr += outw * 4;
  457. }
  458. }
  459. return 0;
  460. }
  461. }
  462. const int channels_g = channels * packing / group;
  463. const int num_output_g = num_output / group;
  464. // unpacking
  465. Mat bottom_blob_bordered_unpacked = bottom_blob_bordered;
  466. if (packing == 4 && channels_g % 4 != 0)
  467. {
  468. convert_packing(bottom_blob_bordered, bottom_blob_bordered_unpacked, 1, opt.workspace_allocator, opt.num_threads);
  469. }
  470. Mat top_blob_unpacked = top_blob;
  471. if (num_output_g % 4 != 0 && out_packing == 4)
  472. {
  473. top_blob_unpacked.create(outw, outh, num_output, elemsize / packing, 1, opt.workspace_allocator);
  474. if (top_blob_unpacked.empty())
  475. return -100;
  476. }
  477. if (channels_g % 4 == 0 && num_output_g % 4 == 0)
  478. {
  479. #ifdef _WIN32
  480. #pragma omp parallel for num_threads(opt.num_threads)
  481. #else // _WIN32
  482. #pragma omp parallel for collapse(2) num_threads(opt.num_threads)
  483. #endif // _WIN32
  484. for (int g=0; g<group; g++)
  485. {
  486. for (int p=0; p<num_output_g / 4; p++)
  487. {
  488. float* outptr = top_blob_unpacked.channel(g * num_output_g / 4 + p);
  489. const float* weight_data_ptr = (const float*)weight_data_pack4_groups + maxk * channels_g / 4 * num_output_g / 4 * g * 16;
  490. for (int i = 0; i < outh; i++)
  491. {
  492. for (int j = 0; j < outw; j++)
  493. {
  494. float32x4_t _sum = vdupq_n_f32(0.f);
  495. if (bias_term)
  496. {
  497. _sum = vld1q_f32(((const float*)bias_data) + num_output_g * g + p * 4);
  498. }
  499. const float* kptr = weight_data_ptr + maxk * channels_g / 4 * p * 16;
  500. // channels_g
  501. for (int q=0; q<channels_g / 4; q++)
  502. {
  503. const Mat m = bottom_blob_bordered.channel(channels_g / 4 * g + q);
  504. const float* sptr = m.row(i*stride_h) + j*stride_w * 4;
  505. for (int k = 0; k < maxk; k++)
  506. {
  507. float32x4_t _val = vld1q_f32( sptr + space_ofs[k] * 4 );
  508. float32x4_t _w0 = vld1q_f32( kptr );
  509. float32x4_t _w1 = vld1q_f32( kptr + 4 );
  510. float32x4_t _w2 = vld1q_f32( kptr + 8 );
  511. float32x4_t _w3 = vld1q_f32( kptr + 12 );
  512. #if __aarch64__
  513. _sum = vmlaq_laneq_f32(_sum, _w0, _val, 0);
  514. _sum = vmlaq_laneq_f32(_sum, _w1, _val, 1);
  515. _sum = vmlaq_laneq_f32(_sum, _w2, _val, 2);
  516. _sum = vmlaq_laneq_f32(_sum, _w3, _val, 3);
  517. #else
  518. _sum = vmlaq_lane_f32(_sum, _w0, vget_low_f32(_val), 0);
  519. _sum = vmlaq_lane_f32(_sum, _w1, vget_low_f32(_val), 1);
  520. _sum = vmlaq_lane_f32(_sum, _w2, vget_high_f32(_val), 0);
  521. _sum = vmlaq_lane_f32(_sum, _w3, vget_high_f32(_val), 1);
  522. #endif
  523. kptr += 16;
  524. }
  525. }
  526. if (activation_type == 1)
  527. {
  528. float32x4_t _zero = vdupq_n_f32(0.f);
  529. _sum = vmaxq_f32(_sum, _zero);
  530. }
  531. else if (activation_type == 2)
  532. {
  533. float32x4_t _zero = vdupq_n_f32(0.f);
  534. float32x4_t _slope = vdupq_n_f32(activation_params[0]);
  535. uint32x4_t _lemask = vcleq_f32(_sum, _zero);
  536. float32x4_t _ps = vmulq_f32(_sum, _slope);
  537. _sum = vbslq_f32(_lemask, _ps, _sum);
  538. }
  539. else if (activation_type == 3)
  540. {
  541. float32x4_t _min = vdupq_n_f32(activation_params[0]);
  542. float32x4_t _max = vdupq_n_f32(activation_params[1]);
  543. _sum = vmaxq_f32(_sum, _min);
  544. _sum = vminq_f32(_sum, _max);
  545. }
  546. else if (activation_type == 4)
  547. {
  548. float32x4_t _one = vdupq_n_f32(1.f);
  549. _sum = vnegq_f32(_sum);
  550. _sum = exp_ps(_sum);
  551. _sum = vaddq_f32(_sum, _one);
  552. float32x4_t _outp = vrecpeq_f32(_sum);
  553. _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
  554. // _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
  555. _sum = _outp;
  556. }
  557. vst1q_f32(outptr + j * 4, _sum);
  558. }
  559. outptr += outw * 4;
  560. }
  561. }
  562. }
  563. }
  564. if (channels_g % 4 != 0 && num_output_g % 4 == 0)
  565. {
  566. #ifdef _WIN32
  567. #pragma omp parallel for num_threads(opt.num_threads)
  568. #else // _WIN32
  569. #pragma omp parallel for collapse(2) num_threads(opt.num_threads)
  570. #endif // _WIN32
  571. for (int g=0; g<group; g++)
  572. {
  573. for (int p=0; p<num_output_g / 4; p++)
  574. {
  575. float* outptr = top_blob_unpacked.channel(g * num_output_g / 4 + p);
  576. const float* weight_data_ptr = (const float*)weight_data_pack1to4_groups + maxk * channels_g * num_output_g / 4 * g * 4;
  577. for (int i = 0; i < outh; i++)
  578. {
  579. for (int j = 0; j < outw; j++)
  580. {
  581. float32x4_t _sum = vdupq_n_f32(0.f);
  582. if (bias_term)
  583. {
  584. _sum = vld1q_f32(((const float*)bias_data) + (num_output_g / 4 * g + p) * 4);
  585. }
  586. const float* kptr = weight_data_ptr + maxk * channels_g * p * 4;
  587. // channels_g
  588. for (int q=0; q<channels_g; q++)
  589. {
  590. const Mat m = bottom_blob_bordered.channel(channels_g * g + q);
  591. const float* sptr = m.row(i*stride_h) + j*stride_w;
  592. for (int k = 0; k < maxk; k++)
  593. {
  594. float32x4_t _val = vdupq_n_f32( sptr[ space_ofs[k] ] );
  595. float32x4_t _w = vld1q_f32( kptr );
  596. _sum = vmlaq_f32(_sum, _val, _w);
  597. kptr += 4;
  598. }
  599. }
  600. if (activation_type == 1)
  601. {
  602. float32x4_t _zero = vdupq_n_f32(0.f);
  603. _sum = vmaxq_f32(_sum, _zero);
  604. }
  605. else if (activation_type == 2)
  606. {
  607. float32x4_t _zero = vdupq_n_f32(0.f);
  608. float32x4_t _slope = vdupq_n_f32(activation_params[0]);
  609. uint32x4_t _lemask = vcleq_f32(_sum, _zero);
  610. float32x4_t _ps = vmulq_f32(_sum, _slope);
  611. _sum = vbslq_f32(_lemask, _ps, _sum);
  612. }
  613. else if (activation_type == 3)
  614. {
  615. float32x4_t _min = vdupq_n_f32(activation_params[0]);
  616. float32x4_t _max = vdupq_n_f32(activation_params[1]);
  617. _sum = vmaxq_f32(_sum, _min);
  618. _sum = vminq_f32(_sum, _max);
  619. }
  620. else if (activation_type == 4)
  621. {
  622. float32x4_t _one = vdupq_n_f32(1.f);
  623. _sum = vnegq_f32(_sum);
  624. _sum = exp_ps(_sum);
  625. _sum = vaddq_f32(_sum, _one);
  626. float32x4_t _outp = vrecpeq_f32(_sum);
  627. _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
  628. // _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
  629. _sum = _outp;
  630. }
  631. vst1q_f32(outptr + j * 4, _sum);
  632. }
  633. outptr += outw * 4;
  634. }
  635. }
  636. }
  637. }
  638. if (channels_g % 4 == 0 && num_output_g % 4 != 0)
  639. {
  640. #ifdef _WIN32
  641. #pragma omp parallel for num_threads(opt.num_threads)
  642. #else // _WIN32
  643. #pragma omp parallel for collapse(2) num_threads(opt.num_threads)
  644. #endif // _WIN32
  645. for (int g=0; g<group; g++)
  646. {
  647. for (int p=0; p<num_output_g; p++)
  648. {
  649. float* outptr = top_blob_unpacked.channel(g * num_output_g + p);
  650. const float* weight_data_ptr = (const float*)weight_data_pack4to1_groups + maxk * channels_g / 4 * num_output_g * g * 4;
  651. for (int i = 0; i < outh; i++)
  652. {
  653. for (int j = 0; j < outw; j++)
  654. {
  655. float sum = 0.f;
  656. if (bias_term)
  657. sum = bias_data[num_output_g * g + p];
  658. const float* kptr = weight_data_ptr + maxk * channels_g / 4 * p * 4;
  659. // channels_g
  660. for (int q=0; q<channels_g / 4; q++)
  661. {
  662. const Mat m = bottom_blob_bordered.channel(channels_g / 4 * g + q);
  663. const float* sptr = m.row(i*stride_h) + j*stride_w * 4;
  664. for (int k = 0; k < maxk; k++)
  665. {
  666. float32x4_t _val = vld1q_f32( sptr + space_ofs[k] * 4 );
  667. float32x4_t _w = vld1q_f32( kptr );
  668. float32x4_t _s4 = vmulq_f32(_val, _w);
  669. #if __aarch64__
  670. sum += vaddvq_f32(_s4); // dot
  671. #else
  672. float32x2_t _ss = vadd_f32(vget_low_f32(_s4), vget_high_f32(_s4));
  673. _ss = vpadd_f32(_ss, _ss);
  674. sum += vget_lane_f32(_ss, 0);
  675. #endif
  676. kptr += 4;
  677. }
  678. }
  679. if (activation_type == 1)
  680. {
  681. sum = std::max(sum, 0.f);
  682. }
  683. else if (activation_type == 2)
  684. {
  685. float slope = activation_params[0];
  686. sum = sum > 0.f ? sum : sum * slope;
  687. }
  688. else if (activation_type == 3)
  689. {
  690. float min = activation_params[0];
  691. float max = activation_params[1];
  692. if (sum < min)
  693. sum = min;
  694. if (sum > max)
  695. sum = max;
  696. }
  697. else if (activation_type == 4)
  698. {
  699. sum = 1.f / (1.f + exp(-sum));
  700. }
  701. outptr[j] = sum;
  702. }
  703. outptr += outw;
  704. }
  705. }
  706. }
  707. }
  708. // packing
  709. if (num_output_g % 4 != 0 && out_packing == 4)
  710. {
  711. convert_packing(top_blob_unpacked, top_blob, 4, opt.blob_allocator, opt.num_threads);
  712. }
  713. else
  714. {
  715. top_blob = top_blob_unpacked;
  716. }
  717. return 0;
  718. } // opt.use_packing_layout
  719. // int8
  720. if (use_int8_inference)
  721. {
  722. if (use_int8_requantize)
  723. {
  724. Mat top_blob_tm;
  725. top_blob_tm.create(outw, outh, num_output, (size_t)4u, opt.workspace_allocator);
  726. if (top_blob_tm.empty())
  727. return -100;
  728. top_blob.create(outw, outh, num_output, (size_t)1u, opt.blob_allocator);
  729. if (top_blob.empty())
  730. return -100;
  731. // depth-wise
  732. if (channels == group && group == num_output)
  733. {
  734. if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1)
  735. {
  736. if ((stride_w == 1 && stride_h == 1) || (stride_w == 2 && stride_h == 2))
  737. {
  738. if (stride_w == 1 && stride_h == 1)
  739. {
  740. convdw3x3s1_int8_requant_neon(bottom_blob_bordered, top_blob, weight_data, bias_data, requantize_scales, opt);
  741. }
  742. else if (stride_w == 2 && stride_h == 2)
  743. {
  744. convdw3x3s2_int8_requant_neon(bottom_blob_bordered, top_blob, weight_data, bias_data, requantize_scales, opt);
  745. }
  746. if (activation)
  747. {
  748. activation->forward_inplace(top_blob, opt);
  749. }
  750. return 0;
  751. }
  752. }
  753. #pragma omp parallel for num_threads(opt.num_threads)
  754. for (int g=0; g<group; g++)
  755. {
  756. const Mat bottom_blob_bordered_g = bottom_blob_bordered.channel_range(g, 1);
  757. Mat top_blob_tm_g = top_blob_tm.channel_range(g, 1);
  758. const ncnn::Layer* op = group_ops[g];
  759. ncnn::Option opt_g = opt;
  760. opt_g.num_threads = 1;
  761. opt_g.blob_allocator = top_blob.allocator;
  762. // forward
  763. op->forward(bottom_blob_bordered_g, top_blob_tm_g, opt_g);
  764. }
  765. if (activation)
  766. {
  767. activation->forward_inplace(top_blob, opt);
  768. }
  769. return 0;
  770. }
  771. const int channels_g = channels / group;
  772. const int num_output_g = num_output / group;
  773. #pragma omp parallel for num_threads(opt.num_threads)
  774. for (int g=0; g<group; g++)
  775. {
  776. const Mat bottom_blob_bordered_g = bottom_blob_bordered.channel_range(channels_g * g, channels_g);
  777. Mat top_blob_tm_g = top_blob_tm.channel_range(num_output_g * g, num_output_g);
  778. const ncnn::Layer* op = group_ops[g];
  779. ncnn::Option opt_g = opt;
  780. opt_g.blob_allocator = top_blob.allocator;
  781. // forward
  782. op->forward(bottom_blob_bordered_g, top_blob_tm_g, opt_g);
  783. }
  784. }
  785. else
  786. {
  787. top_blob.create(outw, outh, num_output, (size_t)4u, opt.blob_allocator);
  788. if (top_blob.empty())
  789. return -100;
  790. // depth-wise
  791. if (channels == group && group == num_output)
  792. {
  793. if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1)
  794. {
  795. if ((stride_w == 1 && stride_h == 1) || (stride_w == 2 && stride_h == 2))
  796. {
  797. if (stride_w == 1 && stride_h == 1)
  798. {
  799. convdw3x3s1_int8_neon(bottom_blob_bordered, top_blob, weight_data, opt);
  800. }
  801. else if (stride_w == 2 && stride_h == 2)
  802. {
  803. convdw3x3s2_int8_neon(bottom_blob_bordered, top_blob, weight_data, opt);
  804. }
  805. // dequantize, reverse scale inplace
  806. #pragma omp parallel for num_threads(opt.num_threads)
  807. for (int g=0; g<group; g++)
  808. {
  809. ncnn::Option opt_g = opt;
  810. opt_g.num_threads = 1;
  811. opt_g.blob_allocator = top_blob.allocator;
  812. Mat top_blob_g = top_blob.channel(g);
  813. dequantize_ops[g]->forward_inplace(top_blob_g, opt_g);
  814. }
  815. if (activation)
  816. {
  817. activation->forward_inplace(top_blob, opt);
  818. }
  819. return 0;
  820. }
  821. }
  822. #pragma omp parallel for num_threads(opt.num_threads)
  823. for (int g=0; g<group; g++)
  824. {
  825. const Mat bottom_blob_bordered_g = bottom_blob_bordered.channel_range(g, 1);
  826. Mat top_blob_g = top_blob.channel_range(g, 1);
  827. const ncnn::Layer* op = group_ops[g];
  828. ncnn::Option opt_g = opt;
  829. opt_g.num_threads = 1;
  830. opt_g.blob_allocator = top_blob.allocator;
  831. // forward
  832. op->forward(bottom_blob_bordered_g, top_blob_g, opt_g);
  833. }
  834. if (activation)
  835. {
  836. activation->forward_inplace(top_blob, opt);
  837. }
  838. return 0;
  839. }
  840. const int channels_g = channels / group;
  841. const int num_output_g = num_output / group;
  842. #pragma omp parallel for num_threads(opt.num_threads)
  843. for (int g=0; g<group; g++)
  844. {
  845. const Mat bottom_blob_bordered_g = bottom_blob_bordered.channel_range(channels_g * g, channels_g);
  846. Mat top_blob_g = top_blob.channel_range(num_output_g * g, num_output_g);
  847. const ncnn::Layer* op = group_ops[g];
  848. ncnn::Option opt_g = opt;
  849. opt_g.blob_allocator = top_blob.allocator;
  850. // forward
  851. op->forward(bottom_blob_bordered_g, top_blob_g, opt_g);
  852. }
  853. }
  854. if (activation)
  855. {
  856. activation->forward_inplace(top_blob, opt);
  857. }
  858. return 0;
  859. }
  860. // float32
  861. top_blob.create(outw, outh, num_output, elemsize, opt.blob_allocator);
  862. if (top_blob.empty())
  863. return -100;
  864. // depth-wise
  865. if (channels == group && group == num_output)
  866. {
  867. if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1)
  868. {
  869. if (stride_w == 1 && stride_h == 1)
  870. {
  871. convdw3x3s1_neon(bottom_blob_bordered, top_blob, weight_data, bias_data, opt);
  872. }
  873. else if (stride_w == 2 && stride_h == 2)
  874. {
  875. convdw3x3s2_neon(bottom_blob_bordered, top_blob, weight_data, bias_data, opt);
  876. }
  877. if (activation)
  878. {
  879. activation->forward_inplace(top_blob, opt);
  880. }
  881. return 0;
  882. }
  883. if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1)
  884. {
  885. if (stride_w == 1 && stride_h == 1)
  886. {
  887. convdw5x5s1_neon(bottom_blob_bordered, top_blob, weight_data, bias_data, opt);
  888. }
  889. else if (stride_w == 2 && stride_h == 2)
  890. {
  891. convdw5x5s2_neon(bottom_blob_bordered, top_blob, weight_data, bias_data, opt);
  892. }
  893. if (activation)
  894. {
  895. activation->forward_inplace(top_blob, opt);
  896. }
  897. return 0;
  898. }
  899. #pragma omp parallel for num_threads(opt.num_threads)
  900. for (int g=0; g<group; g++)
  901. {
  902. const Mat bottom_blob_bordered_g = bottom_blob_bordered.channel_range(g, 1);
  903. Mat top_blob_g = top_blob.channel_range(g, 1);
  904. const ncnn::Layer* op = group_ops[g];
  905. ncnn::Option opt_g = opt;
  906. opt_g.num_threads = 1;
  907. opt_g.blob_allocator = top_blob.allocator;
  908. // forward
  909. op->forward(bottom_blob_bordered_g, top_blob_g, opt_g);
  910. }
  911. if (activation)
  912. {
  913. activation->forward_inplace(top_blob, opt);
  914. }
  915. return 0;
  916. }
  917. const int channels_g = channels / group;
  918. const int num_output_g = num_output / group;
  919. for (int g=0; g<group; g++)
  920. {
  921. const Mat bottom_blob_bordered_g = bottom_blob_bordered.channel_range(channels_g * g, channels_g);
  922. Mat top_blob_g = top_blob.channel_range(num_output_g * g, num_output_g);
  923. const ncnn::Layer* op = group_ops[g];
  924. ncnn::Option opt_g = opt;
  925. opt_g.blob_allocator = top_blob.allocator;
  926. // forward
  927. op->forward(bottom_blob_bordered_g, top_blob_g, opt_g);
  928. }
  929. if (activation)
  930. {
  931. activation->forward_inplace(top_blob, opt);
  932. }
  933. return 0;
  934. }
  935. } // namespace ncnn