You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

algo.cpp 32 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. /**
  2. * \file dnn/src/arm_common/elemwise/binary/algo.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/arm_common/elemwise/binary/algo.h"
  13. #include "src/arm_common/elemwise_op.h"
  14. #include "src/common/utils.h"
  15. #include "src/naive/handle.h"
  16. #include "midout.h"
  17. MIDOUT_DECL(megdnn_arm_common_elemwise_binary)
  18. using namespace megdnn;
  19. using namespace arm_common;
  20. namespace {
  21. static inline bool is_available_common(Elemwise::Mode mode) {
  22. /**
  23. * Fused sigmoid & tanh may be slower than the naive algo, because the
  24. * time used by neon function `exp_ps_f32` is decided by the input.
  25. */
  26. if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID ||
  27. mode == Elemwise::Mode::FUSE_ADD_TANH) {
  28. return false;
  29. }
  30. return true;
  31. }
  32. } // anonymous namespace
//! Availability-check dispatch used by the is_available() methods below
//! (expanded through DISPATCH_TYPE): returns true when kern_param.mode is one
//! of the modes with a vectorized binary kernel for the dtype class.
//! NOTE: the macro declares a local `mode`, so it must expand at block scope.
//! TRUE_DIV is only listed on aarch64 (no vectorized path on armv7 here).
#if MEGDNN_AARCH64
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)             \
    auto mode = kern_param.mode;                                       \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU ||       \
        mode == Mode::FUSE_ADD_H_SWISH)                                \
        return true;
#else
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)             \
    auto mode = kern_param.mode;                                       \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::FUSE_ADD_RELU || mode == Mode::FUSE_ADD_H_SWISH) \
        return true;
#endif
//! Integer variant: RMULH instead of POW/TRUE_DIV/FUSE_ADD_H_SWISH.
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id)                 \
    auto mode = kern_param.mode;                                         \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD ||   \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::RMULH || \
        mode == Mode::FUSE_ADD_RELU)                                     \
        return true;
  55. bool ElemwiseImpl::AlgoBinaryVecVec::is_available(const KernParam& kern_param) const {
  56. if (!is_available_common(kern_param.mode) ||
  57. (BcastType::VEC_VEC != kern_param.broad_cast_type))
  58. return false;
  59. auto& elparam = kern_param.binary_elparam;
  60. auto& src0 = elparam[0];
  61. //! exactly match [x, y] + [x, y]
  62. DISPATCH_TYPE("AlgoBinaryVecVec::is_available"_hash);
  63. return false;
  64. }
  65. bool ElemwiseImpl::AlgoBinaryVecScalar::is_available(
  66. const KernParam& kern_param) const {
  67. if (!is_available_common(kern_param.mode) ||
  68. ((BcastType::VEC_SCALAR != kern_param.broad_cast_type) &&
  69. (BcastType::SCALAR_VEC != kern_param.broad_cast_type)))
  70. return false;
  71. auto& elparam = kern_param.binary_elparam;
  72. auto& src0 = elparam[0];
  73. DISPATCH_TYPE("AlgoBinaryVecScalar::is_available"_hash);
  74. return false;
  75. }
  76. bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
  77. const KernParam& kern_param) const {
  78. if (!is_available_common(kern_param.mode) ||
  79. ((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) &&
  80. (BcastType::BCAST101_VEC != kern_param.broad_cast_type)))
  81. return false;
  82. auto& elparam = kern_param.binary_elparam;
  83. auto& src0 = elparam[0];
  84. DISPATCH_TYPE("AlgoBinaryVecBcast101::is_available"_hash);
  85. return false;
  86. }
  87. bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available(
  88. const KernParam& kern_param) const {
  89. if (!is_available_common(kern_param.mode) ||
  90. ((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) &&
  91. (BcastType::BCASTX0X_VEC != kern_param.broad_cast_type)))
  92. return false;
  93. auto& elparam = kern_param.binary_elparam;
  94. auto& src0 = elparam[0];
  95. DISPATCH_TYPE("AlgoBinaryVecBcastX0X::is_available"_hash);
  96. return false;
  97. }
  98. bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
  99. const KernParam& kern_param) const {
  100. if (!is_available_common(kern_param.mode) ||
  101. ((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) &&
  102. (BcastType::BCAST111C_VEC != kern_param.broad_cast_type)))
  103. return false;
  104. auto& elparam = kern_param.binary_elparam;
  105. auto& src0 = elparam[0];
  106. DISPATCH_TYPE("AlgoBinaryVecBcast111C::is_available"_hash);
  107. return false;
  108. }
  109. bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
  110. const KernParam& kern_param) const {
  111. if (!is_available_common(kern_param.mode) ||
  112. ((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) &&
  113. (BcastType::BCAST101xX_VEC != kern_param.broad_cast_type)))
  114. return false;
  115. auto& elparam = kern_param.binary_elparam;
  116. auto& src0 = elparam[0];
  117. DISPATCH_TYPE("AlgoBinaryVecBcast101xX::is_available"_hash);
  118. return false;
  119. }
  120. #undef DISPATCH_MODE_FLOAT
  121. #undef DISPATCH_MODE_INT
  122. #if MEGDNN_AARCH64
  123. #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
  124. switch (kern_param.mode) { \
  125. DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
  126. DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
  127. DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
  128. DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
  129. DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
  130. DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
  131. DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
  132. DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
  133. DISPATCH_BINARY( \
  134. FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \
  135. default: \
  136. megdnn_throw(ssprintf( \
  137. "No avaiable algo find for: %d", \
  138. static_cast<int>(kern_param.mode))); \
  139. }
  140. #else
  141. #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
  142. switch (kern_param.mode) { \
  143. DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
  144. DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
  145. DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
  146. DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
  147. DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
  148. DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
  149. DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
  150. DISPATCH_BINARY( \
  151. FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \
  152. default: \
  153. megdnn_throw(ssprintf( \
  154. "No avaiable algo find for: %d", \
  155. static_cast<int>(kern_param.mode))); \
  156. }
  157. #endif
  158. #define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
  159. switch (kern_param.mode) { \
  160. DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
  161. DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
  162. DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
  163. DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
  164. DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
  165. DISPATCH_BINARY(RMULH, _case, _type, _type_midout_id, RmulhOp); \
  166. DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
  167. default: \
  168. megdnn_throw(ssprintf( \
  169. "No avaiable algo find for: %d", \
  170. static_cast<int>(kern_param.mode))); \
  171. }
//! Runs the VEC_VEC binary kernel selected by kern_param's dtype and mode.
void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    //! exactly match [x, y] + [x, y]
//! One switch-case per mode: wraps the OpCallerBinary kernel in a
//! thin_function and queues it on the CPU handle. `src0`, `src1`, `dst` and
//! `kern_param` are captured from the enclosing scope at expansion time.
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t)>                                                       \
                    run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype,                       \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
    //! `dst` must be declared before DISPATCH_TYPE expands the macro above,
    //! since the expansion references it.
    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoBinaryVecVec::exec"_hash);
#undef DISPATCH_BINARY
    return;
}
//! Runs the vector-with-scalar binary kernel; handles both operand orders
//! (VEC_SCALAR and SCALAR_VEC) with separate locally-defined dispatch macros.
void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    // Case 2: vector + scalar
//! `src1` holds the scalar operand: only its element [0] is read.
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type, _type*, DType, DType, DType,        \
                    size_t)>                                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::VEC_SCALAR>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr())[0],              \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype,                       \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
    if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) {
        DISPATCH_TYPE("AlgoBinaryVecScalar::exec_vec_sca"_hash);
    }
#undef DISPATCH_BINARY
    // scalar + vector
//! Mirror of the macro above: `src0` is the scalar, element count comes
//! from `src1`.
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type, const _type*, _type*, DType, DType, DType,        \
                    size_t)>                                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::SCALAR_VEC>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr())[0],              \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype,                       \
                        src1.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
    if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) {
        DISPATCH_TYPE("AlgoBinaryVecScalar::exec_sca_vec"_hash);
    }
#undef DISPATCH_BINARY
    return;
}
//! Runs the channel-broadcast (101) binary kernel; handles both operand
//! orders. `binfo` (x, y, z) is filled by is_broadcasted_channel_like from
//! the broadcast operand's layout and forwarded to the kernel.
void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
    // Case 3: BcastType::VEC + BCAST_101
    if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t)>                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y,     \
                        binfo.z));                                                 \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }
    // BCAST_101 + BcastType::VEC
    if (BcastType::BCAST101_VEC == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t)>                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::BCAST101_VEC>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y,     \
                        binfo.z));                                                 \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}
//! Runs the X0X-broadcast binary kernel; handles both operand orders.
//! `binfo` is filled by is_broadcasted_3dim_like from the broadcast
//! operand's layout and forwarded to the kernel as (x, y, z).
void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
    // Case: BcastType::VEC + BCAST_X0X
    if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type &&
        is_broadcasted_3dim_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t)>                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCASTX0X>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y,     \
                        binfo.z));                                                 \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }
    // BCAST_X0X + BcastType::VEC
    if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type &&
        is_broadcasted_3dim_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t)>                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::BCASTX0X_VEC>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y,     \
                        binfo.z));                                                 \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}
//! Runs the NHWC channel-broadcast (111C) binary kernel; handles both
//! operand orders. `binfo` is filled by is_NHWC_broadcasted_channel_like
//! from the broadcast operand's layout.
void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
    // Case extra: BcastType::VEC + BCAST_111C
    if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type &&
        is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t)>                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST111C>::run;     \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y,     \
                        binfo.z));                                                 \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }
    // BCAST_111C + BcastType::VEC
    if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type &&
        is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t)>                                       \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::BCAST111C_VEC>::run;     \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y,     \
                        binfo.z));                                                 \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}
//! Runs the blocked channel-broadcast (101xX, i.e. NCHW44/NCHW88) binary
//! kernel; handles both operand orders. The kernel additionally receives a
//! batch_size derived from the vector operand's outermost dim.
void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
    // BcastType::VEC + BCAST_101X
    if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) {
        megdnn_assert(
                is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
                        is_broadcastedx_channel_like<8>(src1.layout, binfo),
                "only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t, size_t)>                               \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101xX>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x,  \
                        binfo.y, binfo.z));                                        \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        //! declared before DISPATCH_TYPE so the macro expansion can use it
        size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }
    // BCAST_101x + BcastType::VEC
    if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) {
        megdnn_assert(
                is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
                        is_broadcastedx_channel_like<8>(src0.layout, binfo),
                "only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_binary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, _type*, DType, DType, DType,       \
                    size_t, size_t, size_t, size_t)>                               \
                    run = OpCallerBinary<                                          \
                            _op<_type, _type>, BcastType::BCAST101xX_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
                        src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x,  \
                        binfo.y, binfo.z));                                        \
        }                                                                          \
        MIDOUT_END();                                                              \
        return
        //! here the vector operand is src1, hence batch from src1's layout
        size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}
  497. #undef DISPATCH_MODE_FLOAT
  498. #undef DISPATCH_MODE_INT
  499. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台