You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

concat_arm.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "concat_arm.h"

#include <string.h>

#include <algorithm>

#include "layer_type.h"
  17. namespace ncnn {
  18. DEFINE_LAYER_CREATOR(Concat_arm)
  19. Concat_arm::Concat_arm()
  20. {
  21. #if __ARM_NEON
  22. support_packing = true;
  23. packing_pack4 = 0;
  24. #endif // __ARM_NEON
  25. }
  26. int Concat_arm::create_pipeline(const Option& opt)
  27. {
  28. #if __ARM_NEON
  29. if (opt.use_packing_layout)
  30. {
  31. {
  32. packing_pack4 = ncnn::create_layer(ncnn::LayerType::Packing);
  33. ncnn::ParamDict pd;
  34. pd.set(0, 4);
  35. packing_pack4->load_param(pd);
  36. packing_pack4->create_pipeline(opt);
  37. }
  38. }
  39. #endif // __ARM_NEON
  40. return 0;
  41. }
  42. int Concat_arm::destroy_pipeline(const Option& opt)
  43. {
  44. #if __ARM_NEON
  45. if (opt.use_packing_layout)
  46. {
  47. if (packing_pack4)
  48. {
  49. packing_pack4->destroy_pipeline(opt);
  50. delete packing_pack4;
  51. packing_pack4 = 0;
  52. }
  53. }
  54. #endif // __ARM_NEON
  55. return 0;
  56. }
  57. int Concat_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
  58. {
  59. int dims = bottom_blobs[0].dims;
  60. #if __ARM_NEON
  61. if (opt.use_packing_layout)
  62. {
  63. if (dims == 1) // axis == 0
  64. {
  65. // concat vector
  66. // total length
  67. size_t elemsize = bottom_blobs[0].elemsize;
  68. int elempack = bottom_blobs[0].elempack;
  69. int top_w = 0;
  70. for (size_t b=0; b<bottom_blobs.size(); b++)
  71. {
  72. const Mat& bottom_blob = bottom_blobs[b];
  73. top_w += bottom_blob.w * bottom_blob.elempack;
  74. }
  75. int out_elempack = top_w % 4 == 0 ? 4 : 1;
  76. size_t out_elemsize = elemsize / elempack * out_elempack;
  77. Mat& top_blob = top_blobs[0];
  78. top_blob.create(top_w / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
  79. if (top_blob.empty())
  80. return -100;
  81. float* outptr = top_blob;
  82. for (size_t b=0; b<bottom_blobs.size(); b++)
  83. {
  84. const Mat& bottom_blob = bottom_blobs[b];
  85. const float* ptr = bottom_blob;
  86. memcpy(outptr, ptr, bottom_blob.w * bottom_blob.elemsize);
  87. outptr += bottom_blob.w * bottom_blob.elempack;
  88. }
  89. return 0;
  90. }
  91. if (dims == 2 && axis == 0)
  92. {
  93. // concat image
  94. int w = bottom_blobs[0].w;
  95. // total height
  96. size_t elemsize = bottom_blobs[0].elemsize;
  97. int elempack = bottom_blobs[0].elempack;
  98. int top_h = 0;
  99. for (size_t b=0; b<bottom_blobs.size(); b++)
  100. {
  101. const Mat& bottom_blob = bottom_blobs[b];
  102. elemsize = std::min(elemsize, bottom_blob.elemsize);
  103. elempack = std::min(elempack, bottom_blob.elempack);
  104. top_h += bottom_blob.h * bottom_blob.elempack;
  105. }
  106. int out_elempack = top_h % 4 == 0 ? 4 : 1;
  107. size_t out_elemsize = elemsize / elempack * out_elempack;
  108. Mat& top_blob = top_blobs[0];
  109. top_blob.create(w, top_h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
  110. if (top_blob.empty())
  111. return -100;
  112. Mat top_blob_unpacked = top_blob;
  113. if (elempack == 1 && out_elempack == 4)
  114. {
  115. top_blob_unpacked.create(w, top_h / elempack, elemsize, elempack, opt.workspace_allocator);
  116. if (top_blob_unpacked.empty())
  117. return -100;
  118. }
  119. float* outptr = top_blob_unpacked;
  120. for (size_t b=0; b<bottom_blobs.size(); b++)
  121. {
  122. const Mat& bottom_blob = bottom_blobs[b];
  123. if (bottom_blob.elempack == 4 && elempack == 1)
  124. {
  125. for (int i=0; i<bottom_blob.h; i++)
  126. {
  127. const float* r0 = bottom_blob.row(i);
  128. float* outptr0 = outptr;
  129. float* outptr1 = outptr + w;
  130. float* outptr2 = outptr + w*2;
  131. float* outptr3 = outptr + w*3;
  132. for (int j=0; j<w; j++)
  133. {
  134. *outptr0++ = r0[0];
  135. *outptr1++ = r0[1];
  136. *outptr2++ = r0[2];
  137. *outptr3++ = r0[3];
  138. r0 += 4;
  139. }
  140. outptr += w * 4;
  141. }
  142. }
  143. else // if (bottom_blob.elempack == 1 && elempack == 1) if (bottom_blob.elempack == 4 && elempack == 4)
  144. {
  145. int size = w * bottom_blob.h;
  146. const float* ptr = bottom_blob;
  147. memcpy(outptr, ptr, size * bottom_blob.elemsize);
  148. outptr += size * bottom_blob.elempack;
  149. }
  150. }
  151. // packing
  152. if (elempack == 1 && out_elempack == 4)
  153. {
  154. packing_pack4->forward(top_blob_unpacked, top_blob, opt);
  155. }
  156. return 0;
  157. }
  158. if (dims == 2 && axis == 1)
  159. {
  160. // interleave image row
  161. int h = bottom_blobs[0].h;
  162. size_t elemsize = bottom_blobs[0].elemsize;
  163. int elempack = bottom_blobs[0].elempack;
  164. // total width
  165. int top_w = 0;
  166. for (size_t b=0; b<bottom_blobs.size(); b++)
  167. {
  168. const Mat& bottom_blob = bottom_blobs[b];
  169. top_w += bottom_blob.w;
  170. }
  171. Mat& top_blob = top_blobs[0];
  172. top_blob.create(top_w, h, elemsize, elempack, opt.blob_allocator);
  173. if (top_blob.empty())
  174. return -100;
  175. #pragma omp parallel for num_threads(opt.num_threads)
  176. for (int i=0; i<h; i++)
  177. {
  178. float* outptr = top_blob.row(i);
  179. for (size_t b=0; b<bottom_blobs.size(); b++)
  180. {
  181. const Mat& bottom_blob = bottom_blobs[b];
  182. const float* ptr = bottom_blob.row(i);
  183. memcpy(outptr, ptr, bottom_blob.w * elemsize);
  184. outptr += bottom_blob.w * elempack;
  185. }
  186. }
  187. return 0;
  188. }
  189. if (dims == 3 && axis == 0)
  190. {
  191. // concat dim
  192. int w = bottom_blobs[0].w;
  193. int h = bottom_blobs[0].h;
  194. // total channels
  195. size_t elemsize = bottom_blobs[0].elemsize;
  196. int elempack = bottom_blobs[0].elempack;
  197. int top_channels = 0;
  198. for (size_t b=0; b<bottom_blobs.size(); b++)
  199. {
  200. const Mat& bottom_blob = bottom_blobs[b];
  201. elemsize = std::min(elemsize, bottom_blob.elemsize);
  202. elempack = std::min(elempack, bottom_blob.elempack);
  203. top_channels += bottom_blob.c * bottom_blob.elempack;
  204. }
  205. int out_elempack = top_channels % 4 == 0 ? 4 : 1;
  206. size_t out_elemsize = elemsize / elempack * out_elempack;
  207. Mat& top_blob = top_blobs[0];
  208. top_blob.create(w, h, top_channels / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
  209. if (top_blob.empty())
  210. return -100;
  211. Mat top_blob_unpacked = top_blob;
  212. if (elempack == 1 && out_elempack == 4)
  213. {
  214. top_blob_unpacked.create(w, h, top_channels / elempack, elemsize, elempack, opt.workspace_allocator);
  215. if (top_blob_unpacked.empty())
  216. return -100;
  217. }
  218. int p = 0;
  219. for (size_t b=0; b<bottom_blobs.size(); b++)
  220. {
  221. const Mat& bottom_blob = bottom_blobs[b];
  222. if (bottom_blob.elempack == 4 && elempack == 1)
  223. {
  224. int size = bottom_blob.w * bottom_blob.h;
  225. for (int q=0; q<bottom_blob.c; q++)
  226. {
  227. const float* r0 = bottom_blob.channel(q);
  228. float* outptr0 = top_blob_unpacked.channel(p);
  229. float* outptr1 = top_blob_unpacked.channel(p+1);
  230. float* outptr2 = top_blob_unpacked.channel(p+2);
  231. float* outptr3 = top_blob_unpacked.channel(p+3);
  232. for (int i=0; i<size; i++)
  233. {
  234. *outptr0++ = r0[0];
  235. *outptr1++ = r0[1];
  236. *outptr2++ = r0[2];
  237. *outptr3++ = r0[3];
  238. r0 += 4;
  239. }
  240. p += 4;
  241. }
  242. }
  243. else // if (bottom_blob.elempack == 1 && elempack == 1) if (bottom_blob.elempack == 4 && elempack == 4)
  244. {
  245. int size = bottom_blob.total();
  246. const float* ptr = bottom_blob;
  247. float* outptr = top_blob_unpacked.channel(p);
  248. memcpy(outptr, ptr, size * bottom_blob.elemsize);
  249. p += bottom_blob.c;
  250. }
  251. }
  252. // packing
  253. if (elempack == 1 && out_elempack == 4)
  254. {
  255. packing_pack4->forward(top_blob_unpacked, top_blob, opt);
  256. }
  257. return 0;
  258. }
  259. if (dims == 3 && axis == 1)
  260. {
  261. // interleave dim height
  262. int w = bottom_blobs[0].w;
  263. int channels = bottom_blobs[0].c;
  264. size_t elemsize = bottom_blobs[0].elemsize;
  265. int elempack = bottom_blobs[0].elempack;
  266. // total height
  267. int top_h = 0;
  268. for (size_t b=0; b<bottom_blobs.size(); b++)
  269. {
  270. const Mat& bottom_blob = bottom_blobs[b];
  271. top_h += bottom_blob.h;
  272. }
  273. Mat& top_blob = top_blobs[0];
  274. top_blob.create(w, top_h, channels, elemsize, elempack, opt.blob_allocator);
  275. if (top_blob.empty())
  276. return -100;
  277. #pragma omp parallel for num_threads(opt.num_threads)
  278. for (int q=0; q<channels; q++)
  279. {
  280. float* outptr = top_blob.channel(q);
  281. for (size_t b=0; b<bottom_blobs.size(); b++)
  282. {
  283. const Mat& bottom_blob = bottom_blobs[b];
  284. int size = bottom_blob.w * bottom_blob.h;
  285. const float* ptr = bottom_blob.channel(q);
  286. memcpy(outptr, ptr, size * elemsize);
  287. outptr += size * elempack;
  288. }
  289. }
  290. return 0;
  291. }
  292. if (dims == 3 && axis == 2)
  293. {
  294. // interleave dim width
  295. int h = bottom_blobs[0].h;
  296. int channels = bottom_blobs[0].c;
  297. size_t elemsize = bottom_blobs[0].elemsize;
  298. int elempack = bottom_blobs[0].elempack;
  299. // total height
  300. int top_w = 0;
  301. for (size_t b=0; b<bottom_blobs.size(); b++)
  302. {
  303. const Mat& bottom_blob = bottom_blobs[b];
  304. top_w += bottom_blob.w;
  305. }
  306. Mat& top_blob = top_blobs[0];
  307. top_blob.create(top_w, h, channels, elemsize, elempack, opt.blob_allocator);
  308. if (top_blob.empty())
  309. return -100;
  310. #pragma omp parallel for num_threads(opt.num_threads)
  311. for (int q=0; q<channels; q++)
  312. {
  313. float* outptr = top_blob.channel(q);
  314. for (int i=0; i<h; i++)
  315. {
  316. for (size_t b=0; b<bottom_blobs.size(); b++)
  317. {
  318. const Mat& bottom_blob = bottom_blobs[b];
  319. const float* ptr = bottom_blob.channel(q).row(i);
  320. memcpy(outptr, ptr, bottom_blob.w * elemsize);
  321. outptr += bottom_blob.w * elempack;
  322. }
  323. }
  324. }
  325. return 0;
  326. }
  327. } // opt.use_packing_layout
  328. #endif // __ARM_NEON
  329. return Concat::forward(bottom_blobs, top_blobs, opt);
  330. }
  331. } // namespace ncnn