You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

concat_vulkan.cpp 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "concat_vulkan.h"
  15. #include <algorithm>
  16. #include "layer_type.h"
  17. #include "layer_shader_type.h"
  18. namespace ncnn {
  19. DEFINE_LAYER_CREATOR(Concat_vulkan)
  20. Concat_vulkan::Concat_vulkan()
  21. {
  22. support_vulkan = true;
  23. packing_pack4 = 0;
  24. packing_pack8 = 0;
  25. pipeline_concat[0] = 0;
  26. pipeline_concat[1] = 0;
  27. pipeline_concat_pack4[0] = 0;
  28. pipeline_concat_pack4[1] = 0;
  29. pipeline_concat_pack4to1[0] = 0;
  30. pipeline_concat_pack4to1[1] = 0;
  31. pipeline_concat_pack8[0] = 0;
  32. pipeline_concat_pack8[1] = 0;
  33. pipeline_concat_pack8to4[0] = 0;
  34. pipeline_concat_pack8to4[1] = 0;
  35. pipeline_concat_pack8to1[0] = 0;
  36. pipeline_concat_pack8to1[1] = 0;
  37. }
  38. int Concat_vulkan::create_pipeline(const Option& opt)
  39. {
  40. const Mat& shape = bottom_shapes.empty() ? Mat() : bottom_shapes[0];
  41. const Mat& out_shape = top_shapes.empty() ? Mat() : top_shapes[0];
  42. int out_elempack = 1;
  43. if (out_shape.dims == 1) out_elempack = opt.use_shader_pack8 && out_shape.w % 8 == 0 ? 8 : out_shape.w % 4 == 0 ? 4 : 1;
  44. if (out_shape.dims == 2) out_elempack = opt.use_shader_pack8 && out_shape.h % 8 == 0 ? 8 : out_shape.h % 4 == 0 ? 4 : 1;
  45. if (out_shape.dims == 3) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 4 : 1;
  46. int elempack = 1;
  47. if (axis == 0)
  48. {
  49. if (shape.dims == 1) elempack = opt.use_shader_pack8 && shape.w % 8 == 0 ? 8 : shape.w % 4 == 0 ? 4 : 1;
  50. if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
  51. if (shape.dims == 3) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1;
  52. for (size_t b=1; b<bottom_shapes.size(); b++)
  53. {
  54. const Mat& shape1 = bottom_shapes[b];
  55. int elempack1 = 1;
  56. if (shape1.dims == 1) elempack1 = opt.use_shader_pack8 && shape1.w % 8 == 0 ? 8 : shape1.w % 4 == 0 ? 4 : 1;
  57. if (shape1.dims == 2) elempack1 = opt.use_shader_pack8 && shape1.h % 8 == 0 ? 8 : shape1.h % 4 == 0 ? 4 : 1;
  58. if (shape1.dims == 3) elempack1 = opt.use_shader_pack8 && shape1.c % 8 == 0 ? 8 : shape1.c % 4 == 0 ? 4 : 1;
  59. elempack = std::min(elempack, elempack1);
  60. }
  61. }
  62. else
  63. {
  64. elempack = out_elempack;
  65. }
  66. size_t elemsize;
  67. if (opt.use_fp16_storage)
  68. {
  69. elemsize = elempack * 2u;
  70. }
  71. else if (opt.use_fp16_packed)
  72. {
  73. elemsize = elempack == 1 ? 4u : elempack * 2u;
  74. }
  75. else
  76. {
  77. elemsize = elempack * 4u;
  78. }
  79. Mat out_shape_unpacked;
  80. if (out_shape.dims == 1) out_shape_unpacked = Mat(out_shape.w / elempack, (void*)0, elemsize, elempack);
  81. if (out_shape.dims == 2) out_shape_unpacked = Mat(out_shape.w, out_shape.h / elempack, (void*)0, elemsize, elempack);
  82. if (out_shape.dims == 3) out_shape_unpacked = Mat(out_shape.w, out_shape.h, out_shape.c / elempack, (void*)0, elemsize, elempack);
  83. std::vector<vk_specialization_type> specializations(1 + 10);
  84. specializations[0].i = axis;
  85. specializations[1 + 0].i = 0;// TODO handle shape_packed for concat2
  86. specializations[1 + 1].i = 0;
  87. specializations[1 + 2].i = 0;
  88. specializations[1 + 3].i = 0;
  89. specializations[1 + 4].i = 0;
  90. specializations[1 + 5].i = out_shape_unpacked.dims;
  91. specializations[1 + 6].i = out_shape_unpacked.w;
  92. specializations[1 + 7].i = out_shape_unpacked.h;
  93. specializations[1 + 8].i = out_shape_unpacked.c;
  94. specializations[1 + 9].i = out_shape_unpacked.cstep;
  95. Mat local_size_xyz;// TODO more precise group size guessed from out_shape_unpacked
  96. if (out_shape_unpacked.dims == 1)
  97. {
  98. local_size_xyz.w = 64;
  99. local_size_xyz.h = 1;
  100. local_size_xyz.c = 1;
  101. }
  102. if (out_shape_unpacked.dims == 2)
  103. {
  104. local_size_xyz.w = 8;
  105. local_size_xyz.h = 8;
  106. local_size_xyz.c = 1;
  107. }
  108. if (out_shape_unpacked.dims == 3)
  109. {
  110. local_size_xyz.w = 4;
  111. local_size_xyz.h = 4;
  112. local_size_xyz.c = 4;
  113. }
  114. // pack1
  115. if (shape.dims == 0 || elempack == 1)
  116. {
  117. pipeline_concat[0] = new Pipeline(vkdev);
  118. pipeline_concat[0]->set_optimal_local_size_xyz(local_size_xyz);
  119. pipeline_concat[0]->create(LayerShaderType::concat, opt, specializations);
  120. pipeline_concat[1] = new Pipeline(vkdev);
  121. pipeline_concat[1]->set_optimal_local_size_xyz(local_size_xyz);
  122. pipeline_concat[1]->create(LayerShaderType::concat, opt, specializations);
  123. }
  124. // pack4
  125. if (shape.dims == 0 || elempack == 4)
  126. {
  127. pipeline_concat_pack4[0] = new Pipeline(vkdev);
  128. pipeline_concat_pack4[0]->set_optimal_local_size_xyz(local_size_xyz);
  129. pipeline_concat_pack4[0]->create(LayerShaderType::concat_pack4, opt, specializations);
  130. pipeline_concat_pack4[1] = new Pipeline(vkdev);
  131. pipeline_concat_pack4[1]->set_optimal_local_size_xyz(local_size_xyz);
  132. pipeline_concat_pack4[1]->create(LayerShaderType::concat_pack4, opt, specializations);
  133. }
  134. // pack4to1
  135. if ((axis == 0 && shape.dims == 0) || elempack == 1)
  136. {
  137. pipeline_concat_pack4to1[0] = new Pipeline(vkdev);
  138. pipeline_concat_pack4to1[0]->set_optimal_local_size_xyz(local_size_xyz);
  139. pipeline_concat_pack4to1[0]->create(LayerShaderType::concat_pack4to1, opt, specializations);
  140. pipeline_concat_pack4to1[1] = new Pipeline(vkdev);
  141. pipeline_concat_pack4to1[1]->set_optimal_local_size_xyz(local_size_xyz);
  142. pipeline_concat_pack4to1[1]->create(LayerShaderType::concat_pack4to1, opt, specializations);
  143. }
  144. // pack8
  145. if (opt.use_shader_pack8 && (shape.dims == 0 || elempack == 8))
  146. {
  147. pipeline_concat_pack8[0] = new Pipeline(vkdev);
  148. pipeline_concat_pack8[0]->set_optimal_local_size_xyz(local_size_xyz);
  149. pipeline_concat_pack8[0]->create(LayerShaderType::concat_pack8, opt, specializations);
  150. pipeline_concat_pack8[1] = new Pipeline(vkdev);
  151. pipeline_concat_pack8[1]->set_optimal_local_size_xyz(local_size_xyz);
  152. pipeline_concat_pack8[1]->create(LayerShaderType::concat_pack8, opt, specializations);
  153. }
  154. // pack8to4
  155. if (opt.use_shader_pack8 && ((axis == 0 && shape.dims == 0) || elempack == 4))
  156. {
  157. pipeline_concat_pack8to4[0] = new Pipeline(vkdev);
  158. pipeline_concat_pack8to4[0]->set_optimal_local_size_xyz(local_size_xyz);
  159. pipeline_concat_pack8to4[0]->create(LayerShaderType::concat_pack8to4, opt, specializations);
  160. pipeline_concat_pack8to4[1] = new Pipeline(vkdev);
  161. pipeline_concat_pack8to4[1]->set_optimal_local_size_xyz(local_size_xyz);
  162. pipeline_concat_pack8to4[1]->create(LayerShaderType::concat_pack8to4, opt, specializations);
  163. }
  164. // pack8to1
  165. if (opt.use_shader_pack8 && ((axis == 0 && shape.dims == 0) || elempack == 1))
  166. {
  167. pipeline_concat_pack8to1[0] = new Pipeline(vkdev);
  168. pipeline_concat_pack8to1[0]->set_optimal_local_size_xyz(local_size_xyz);
  169. pipeline_concat_pack8to1[0]->create(LayerShaderType::concat_pack8to1, opt, specializations);
  170. pipeline_concat_pack8to1[1] = new Pipeline(vkdev);
  171. pipeline_concat_pack8to1[1]->set_optimal_local_size_xyz(local_size_xyz);
  172. pipeline_concat_pack8to1[1]->create(LayerShaderType::concat_pack8to1, opt, specializations);
  173. }
  174. if ((axis == 0 && shape.dims == 0) || (elempack < out_elempack && out_elempack == 4))
  175. {
  176. packing_pack4 = ncnn::create_layer(ncnn::LayerType::Packing);
  177. packing_pack4->vkdev = vkdev;
  178. packing_pack4->bottom_shapes.resize(1);
  179. packing_pack4->bottom_shapes[0] = out_shape_unpacked;
  180. packing_pack4->top_shapes.resize(1);
  181. packing_pack4->top_shapes[0] = out_shape;
  182. ncnn::ParamDict pd;
  183. pd.set(0, 4);
  184. packing_pack4->load_param(pd);
  185. packing_pack4->create_pipeline(opt);
  186. }
  187. if ((opt.use_shader_pack8 && axis == 0 && shape.dims == 0) || (elempack < out_elempack && out_elempack == 8))
  188. {
  189. packing_pack8 = ncnn::create_layer(ncnn::LayerType::Packing);
  190. packing_pack8->vkdev = vkdev;
  191. packing_pack8->bottom_shapes.resize(1);
  192. packing_pack8->bottom_shapes[0] = out_shape_unpacked;
  193. packing_pack8->top_shapes.resize(1);
  194. packing_pack8->top_shapes[0] = out_shape;
  195. ncnn::ParamDict pd;
  196. pd.set(0, 8);
  197. packing_pack8->load_param(pd);
  198. packing_pack8->create_pipeline(opt);
  199. }
  200. return 0;
  201. }
  202. int Concat_vulkan::destroy_pipeline(const Option& opt)
  203. {
  204. if (packing_pack4)
  205. {
  206. packing_pack4->destroy_pipeline(opt);
  207. delete packing_pack4;
  208. packing_pack4 = 0;
  209. }
  210. if (packing_pack8)
  211. {
  212. packing_pack8->destroy_pipeline(opt);
  213. delete packing_pack8;
  214. packing_pack8 = 0;
  215. }
  216. delete pipeline_concat[0];
  217. delete pipeline_concat[1];
  218. pipeline_concat[0] = 0;
  219. pipeline_concat[1] = 0;
  220. delete pipeline_concat_pack4[0];
  221. delete pipeline_concat_pack4[1];
  222. pipeline_concat_pack4[0] = 0;
  223. pipeline_concat_pack4[1] = 0;
  224. delete pipeline_concat_pack4to1[0];
  225. delete pipeline_concat_pack4to1[1];
  226. pipeline_concat_pack4to1[0] = 0;
  227. pipeline_concat_pack4to1[1] = 0;
  228. delete pipeline_concat_pack8[0];
  229. delete pipeline_concat_pack8[1];
  230. pipeline_concat_pack8[0] = 0;
  231. pipeline_concat_pack8[1] = 0;
  232. delete pipeline_concat_pack8to4[0];
  233. delete pipeline_concat_pack8to4[1];
  234. pipeline_concat_pack8to4[0] = 0;
  235. pipeline_concat_pack8to4[1] = 0;
  236. delete pipeline_concat_pack8to1[0];
  237. delete pipeline_concat_pack8to1[1];
  238. pipeline_concat_pack8to1[0] = 0;
  239. pipeline_concat_pack8to1[1] = 0;
  240. return 0;
  241. }
  242. int Concat_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
  243. {
  244. int dims = bottom_blobs[0].dims;
  245. if (dims == 1) // axis == 0
  246. {
  247. // concat vector
  248. // total length
  249. size_t elemsize = bottom_blobs[0].elemsize;
  250. int elempack = bottom_blobs[0].elempack;
  251. int top_w = 0;
  252. for (size_t b=0; b<bottom_blobs.size(); b++)
  253. {
  254. const VkMat& bottom_blob = bottom_blobs[b];
  255. elemsize = std::min(elemsize, bottom_blob.elemsize);
  256. elempack = std::min(elempack, bottom_blob.elempack);
  257. top_w += bottom_blob.w * bottom_blob.elempack;
  258. }
  259. int out_elempack = opt.use_shader_pack8 && top_w % 8 == 0 ? 8 : top_w % 4 == 0 ? 4 : 1;
  260. size_t out_elemsize = elemsize / elempack * out_elempack;
  261. if (opt.use_fp16_packed && !opt.use_fp16_storage)
  262. {
  263. if (out_elempack == 8) out_elemsize = 8*2u;
  264. if (out_elempack == 4) out_elemsize = 4*2u;
  265. if (out_elempack == 1) out_elemsize = 4u;
  266. }
  267. VkMat& top_blob = top_blobs[0];
  268. top_blob.create(top_w / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
  269. if (top_blob.empty())
  270. return -100;
  271. VkMat top_blob_unpacked = top_blob;
  272. if (elempack < out_elempack)
  273. {
  274. top_blob_unpacked.create(top_w / elempack, elemsize, elempack, opt.workspace_vkallocator);
  275. if (top_blob_unpacked.empty())
  276. return -100;
  277. }
  278. int woffset = 0;
  279. for (size_t b=0; b<bottom_blobs.size(); b++)
  280. {
  281. const VkMat& bottom_blob = bottom_blobs[b];
  282. std::vector<VkMat> bindings(2);
  283. bindings[0] = bottom_blob;
  284. bindings[1] = top_blob_unpacked;
  285. std::vector<vk_constant_type> constants(11);
  286. constants[0].i = bottom_blob.dims;
  287. constants[1].i = bottom_blob.w;
  288. constants[2].i = bottom_blob.h;
  289. constants[3].i = bottom_blob.c;
  290. constants[4].i = bottom_blob.cstep;
  291. constants[5].i = top_blob_unpacked.dims;
  292. constants[6].i = top_blob_unpacked.w;
  293. constants[7].i = top_blob_unpacked.h;
  294. constants[8].i = top_blob_unpacked.c;
  295. constants[9].i = top_blob_unpacked.cstep;
  296. constants[10].i = woffset;
  297. const Pipeline* pipeline = 0;
  298. if (bottom_blob.elempack == 1 && elempack == 1)
  299. {
  300. pipeline = pipeline_concat[b%2];
  301. }
  302. else if (bottom_blob.elempack == 4 && elempack == 4)
  303. {
  304. pipeline = pipeline_concat_pack4[b%2];
  305. }
  306. else if (bottom_blob.elempack == 4 && elempack == 1)
  307. {
  308. pipeline = pipeline_concat_pack4to1[b%2];
  309. }
  310. else if (bottom_blob.elempack == 8 && elempack == 8)
  311. {
  312. pipeline = pipeline_concat_pack8[b%2];
  313. }
  314. else if (bottom_blob.elempack == 8 && elempack == 4)
  315. {
  316. pipeline = pipeline_concat_pack8to4[b%2];
  317. }
  318. else if (bottom_blob.elempack == 8 && elempack == 1)
  319. {
  320. pipeline = pipeline_concat_pack8to1[b%2];
  321. }
  322. cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
  323. woffset += bottom_blob.w * bottom_blob.elempack / elempack;
  324. }
  325. // packing
  326. if (elempack < out_elempack)
  327. {
  328. const Layer* packing = out_elempack == 8 ? packing_pack8 : packing_pack4;
  329. packing->forward(top_blob_unpacked, top_blob, cmd, opt);
  330. }
  331. return 0;
  332. }
  333. if (dims == 2 && axis == 0)
  334. {
  335. // concat image
  336. int w = bottom_blobs[0].w;
  337. // total height
  338. size_t elemsize = bottom_blobs[0].elemsize;
  339. int elempack = bottom_blobs[0].elempack;
  340. int top_h = 0;
  341. for (size_t b=0; b<bottom_blobs.size(); b++)
  342. {
  343. const VkMat& bottom_blob = bottom_blobs[b];
  344. elemsize = std::min(elemsize, bottom_blob.elemsize);
  345. elempack = std::min(elempack, bottom_blob.elempack);
  346. top_h += bottom_blob.h * bottom_blob.elempack;
  347. }
  348. int out_elempack = opt.use_shader_pack8 && top_h % 8 == 0 ? 8 : top_h % 4 == 0 ? 4 : 1;
  349. size_t out_elemsize = elemsize / elempack * out_elempack;
  350. if (opt.use_fp16_packed && !opt.use_fp16_storage)
  351. {
  352. if (out_elempack == 8) out_elemsize = 8*2u;
  353. if (out_elempack == 4) out_elemsize = 4*2u;
  354. if (out_elempack == 1) out_elemsize = 4u;
  355. }
  356. VkMat& top_blob = top_blobs[0];
  357. top_blob.create(w, top_h / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
  358. if (top_blob.empty())
  359. return -100;
  360. VkMat top_blob_unpacked = top_blob;
  361. if (elempack < out_elempack)
  362. {
  363. top_blob_unpacked.create(w, top_h / elempack, elemsize, elempack, opt.workspace_vkallocator);
  364. if (top_blob_unpacked.empty())
  365. return -100;
  366. }
  367. int hoffset = 0;
  368. for (size_t b=0; b<bottom_blobs.size(); b++)
  369. {
  370. const VkMat& bottom_blob = bottom_blobs[b];
  371. std::vector<VkMat> bindings(2);
  372. bindings[0] = bottom_blob;
  373. bindings[1] = top_blob_unpacked;
  374. std::vector<vk_constant_type> constants(11);
  375. constants[0].i = bottom_blob.dims;
  376. constants[1].i = bottom_blob.w;
  377. constants[2].i = bottom_blob.h;
  378. constants[3].i = bottom_blob.c;
  379. constants[4].i = bottom_blob.cstep;
  380. constants[5].i = top_blob_unpacked.dims;
  381. constants[6].i = top_blob_unpacked.w;
  382. constants[7].i = top_blob_unpacked.h;
  383. constants[8].i = top_blob_unpacked.c;
  384. constants[9].i = top_blob_unpacked.cstep;
  385. constants[10].i = hoffset;
  386. const Pipeline* pipeline = 0;
  387. if (bottom_blob.elempack == 1 && elempack == 1)
  388. {
  389. pipeline = pipeline_concat[b%2];
  390. }
  391. else if (bottom_blob.elempack == 4 && elempack == 4)
  392. {
  393. pipeline = pipeline_concat_pack4[b%2];
  394. }
  395. else if (bottom_blob.elempack == 4 && elempack == 1)
  396. {
  397. pipeline = pipeline_concat_pack4to1[b%2];
  398. }
  399. else if (bottom_blob.elempack == 8 && elempack == 8)
  400. {
  401. pipeline = pipeline_concat_pack8[b%2];
  402. }
  403. else if (bottom_blob.elempack == 8 && elempack == 4)
  404. {
  405. pipeline = pipeline_concat_pack8to4[b%2];
  406. }
  407. else if (bottom_blob.elempack == 8 && elempack == 1)
  408. {
  409. pipeline = pipeline_concat_pack8to1[b%2];
  410. }
  411. cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
  412. hoffset += bottom_blob.h * bottom_blob.elempack / elempack;
  413. }
  414. // packing
  415. if (elempack < out_elempack)
  416. {
  417. const Layer* packing = out_elempack == 8 ? packing_pack8 : packing_pack4;
  418. packing->forward(top_blob_unpacked, top_blob, cmd, opt);
  419. }
  420. return 0;
  421. }
  422. if (dims == 2 && axis == 1)
  423. {
  424. // interleave image row
  425. int h = bottom_blobs[0].h;
  426. size_t elemsize = bottom_blobs[0].elemsize;
  427. int elempack = bottom_blobs[0].elempack;
  428. // total width
  429. int top_w = 0;
  430. for (size_t b=0; b<bottom_blobs.size(); b++)
  431. {
  432. const VkMat& bottom_blob = bottom_blobs[b];
  433. top_w += bottom_blob.w;
  434. }
  435. VkMat& top_blob = top_blobs[0];
  436. top_blob.create(top_w, h, elemsize, elempack, opt.blob_vkallocator);
  437. if (top_blob.empty())
  438. return -100;
  439. int woffset = 0;
  440. for (size_t b=0; b<bottom_blobs.size(); b++)
  441. {
  442. const VkMat& bottom_blob = bottom_blobs[b];
  443. std::vector<VkMat> bindings(2);
  444. bindings[0] = bottom_blob;
  445. bindings[1] = top_blob;
  446. std::vector<vk_constant_type> constants(11);
  447. constants[0].i = bottom_blob.dims;
  448. constants[1].i = bottom_blob.w;
  449. constants[2].i = bottom_blob.h;
  450. constants[3].i = bottom_blob.c;
  451. constants[4].i = bottom_blob.cstep;
  452. constants[5].i = top_blob.dims;
  453. constants[6].i = top_blob.w;
  454. constants[7].i = top_blob.h;
  455. constants[8].i = top_blob.c;
  456. constants[9].i = top_blob.cstep;
  457. constants[10].i = woffset;
  458. const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b%2]
  459. : elempack == 4 ? pipeline_concat_pack4[b%2]
  460. : pipeline_concat[b%2];
  461. cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
  462. woffset += bottom_blob.w;
  463. }
  464. return 0;
  465. }
  466. if (dims == 3 && axis == 0)
  467. {
  468. // concat dim
  469. int w = bottom_blobs[0].w;
  470. int h = bottom_blobs[0].h;
  471. // total channels
  472. size_t elemsize = bottom_blobs[0].elemsize;
  473. int elempack = bottom_blobs[0].elempack;
  474. int top_channels = 0;
  475. for (size_t b=0; b<bottom_blobs.size(); b++)
  476. {
  477. const VkMat& bottom_blob = bottom_blobs[b];
  478. elemsize = std::min(elemsize, bottom_blob.elemsize);
  479. elempack = std::min(elempack, bottom_blob.elempack);
  480. top_channels += bottom_blob.c * bottom_blob.elempack;
  481. }
  482. int out_elempack = opt.use_shader_pack8 && top_channels % 8 == 0 ? 8 : top_channels % 4 == 0 ? 4 : 1;
  483. size_t out_elemsize = elemsize / elempack * out_elempack;
  484. if (opt.use_fp16_packed && !opt.use_fp16_storage)
  485. {
  486. if (out_elempack == 8) out_elemsize = 8*2u;
  487. if (out_elempack == 4) out_elemsize = 4*2u;
  488. if (out_elempack == 1) out_elemsize = 4u;
  489. }
  490. VkMat& top_blob = top_blobs[0];
  491. top_blob.create(w, h, top_channels / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
  492. if (top_blob.empty())
  493. return -100;
  494. VkMat top_blob_unpacked = top_blob;
  495. if (elempack < out_elempack)
  496. {
  497. top_blob_unpacked.create(w, h, top_channels / elempack, elemsize, elempack, opt.workspace_vkallocator);
  498. if (top_blob_unpacked.empty())
  499. return -100;
  500. }
  501. int coffset = 0;
  502. for (size_t b=0; b<bottom_blobs.size(); b++)
  503. {
  504. const VkMat& bottom_blob = bottom_blobs[b];
  505. std::vector<VkMat> bindings(2);
  506. bindings[0] = bottom_blob;
  507. bindings[1] = top_blob_unpacked;
  508. std::vector<vk_constant_type> constants(11);
  509. constants[0].i = bottom_blob.dims;
  510. constants[1].i = bottom_blob.w;
  511. constants[2].i = bottom_blob.h;
  512. constants[3].i = bottom_blob.c;
  513. constants[4].i = bottom_blob.cstep;
  514. constants[5].i = top_blob_unpacked.dims;
  515. constants[6].i = top_blob_unpacked.w;
  516. constants[7].i = top_blob_unpacked.h;
  517. constants[8].i = top_blob_unpacked.c;
  518. constants[9].i = top_blob_unpacked.cstep;
  519. constants[10].i = coffset;
  520. const Pipeline* pipeline = 0;
  521. if (bottom_blob.elempack == 1 && elempack == 1)
  522. {
  523. pipeline = pipeline_concat[b%2];
  524. }
  525. else if (bottom_blob.elempack == 4 && elempack == 4)
  526. {
  527. pipeline = pipeline_concat_pack4[b%2];
  528. }
  529. else if (bottom_blob.elempack == 4 && elempack == 1)
  530. {
  531. pipeline = pipeline_concat_pack4to1[b%2];
  532. }
  533. else if (bottom_blob.elempack == 8 && elempack == 8)
  534. {
  535. pipeline = pipeline_concat_pack8[b%2];
  536. }
  537. else if (bottom_blob.elempack == 8 && elempack == 4)
  538. {
  539. pipeline = pipeline_concat_pack8to4[b%2];
  540. }
  541. else if (bottom_blob.elempack == 8 && elempack == 1)
  542. {
  543. pipeline = pipeline_concat_pack8to1[b%2];
  544. }
  545. cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
  546. coffset += bottom_blob.c * bottom_blob.elempack / elempack;
  547. }
  548. // packing
  549. if (elempack < out_elempack)
  550. {
  551. const Layer* packing = out_elempack == 8 ? packing_pack8 : packing_pack4;
  552. packing->forward(top_blob_unpacked, top_blob, cmd, opt);
  553. }
  554. return 0;
  555. }
  556. if (dims == 3 && axis == 1)
  557. {
  558. // interleave dim height
  559. int w = bottom_blobs[0].w;
  560. int channels = bottom_blobs[0].c;
  561. size_t elemsize = bottom_blobs[0].elemsize;
  562. int elempack = bottom_blobs[0].elempack;
  563. // total height
  564. int top_h = 0;
  565. for (size_t b=0; b<bottom_blobs.size(); b++)
  566. {
  567. const VkMat& bottom_blob = bottom_blobs[b];
  568. top_h += bottom_blob.h;
  569. }
  570. VkMat& top_blob = top_blobs[0];
  571. top_blob.create(w, top_h, channels, elemsize, elempack, opt.blob_vkallocator);
  572. if (top_blob.empty())
  573. return -100;
  574. int hoffset = 0;
  575. for (size_t b=0; b<bottom_blobs.size(); b++)
  576. {
  577. const VkMat& bottom_blob = bottom_blobs[b];
  578. std::vector<VkMat> bindings(2);
  579. bindings[0] = bottom_blob;
  580. bindings[1] = top_blob;
  581. std::vector<vk_constant_type> constants(11);
  582. constants[0].i = bottom_blob.dims;
  583. constants[1].i = bottom_blob.w;
  584. constants[2].i = bottom_blob.h;
  585. constants[3].i = bottom_blob.c;
  586. constants[4].i = bottom_blob.cstep;
  587. constants[5].i = top_blob.dims;
  588. constants[6].i = top_blob.w;
  589. constants[7].i = top_blob.h;
  590. constants[8].i = top_blob.c;
  591. constants[9].i = top_blob.cstep;
  592. constants[10].i = hoffset;
  593. const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b%2]
  594. : elempack == 4 ? pipeline_concat_pack4[b%2]
  595. : pipeline_concat[b%2];
  596. cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
  597. hoffset += bottom_blob.h;
  598. }
  599. return 0;
  600. }
  601. if (dims == 3 && axis == 2)
  602. {
  603. // interleave dim width
  604. int h = bottom_blobs[0].h;
  605. int channels = bottom_blobs[0].c;
  606. size_t elemsize = bottom_blobs[0].elemsize;
  607. int elempack = bottom_blobs[0].elempack;
  608. // total height
  609. int top_w = 0;
  610. for (size_t b=0; b<bottom_blobs.size(); b++)
  611. {
  612. const VkMat& bottom_blob = bottom_blobs[b];
  613. top_w += bottom_blob.w;
  614. }
  615. VkMat& top_blob = top_blobs[0];
  616. top_blob.create(top_w, h, channels, elemsize, elempack, opt.blob_vkallocator);
  617. if (top_blob.empty())
  618. return -100;
  619. int woffset = 0;
  620. for (size_t b=0; b<bottom_blobs.size(); b++)
  621. {
  622. const VkMat& bottom_blob = bottom_blobs[b];
  623. std::vector<VkMat> bindings(2);
  624. bindings[0] = bottom_blob;
  625. bindings[1] = top_blob;
  626. std::vector<vk_constant_type> constants(11);
  627. constants[0].i = bottom_blob.dims;
  628. constants[1].i = bottom_blob.w;
  629. constants[2].i = bottom_blob.h;
  630. constants[3].i = bottom_blob.c;
  631. constants[4].i = bottom_blob.cstep;
  632. constants[5].i = top_blob.dims;
  633. constants[6].i = top_blob.w;
  634. constants[7].i = top_blob.h;
  635. constants[8].i = top_blob.c;
  636. constants[9].i = top_blob.cstep;
  637. constants[10].i = woffset;
  638. const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b%2]
  639. : elempack == 4 ? pipeline_concat_pack4[b%2]
  640. : pipeline_concat[b%2];
  641. cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
  642. woffset += bottom_blob.w;
  643. }
  644. return 0;
  645. }
  646. return 0;
  647. }
  648. } // namespace ncnn