
opr_impl.cpp

#include "src/fallback/reduce/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"
#include "reducer.h"

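// Note: the MIDOUT_DECL tags below register trace points for midout, the
// instrumentation and binary-size-trimming facility used across MegDNN; each
// MIDOUT_BEGIN/END region later in this file is keyed on one of these tags.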
MIDOUT_DECL(megdnn_fb_reduce_op)
MIDOUT_DECL(megdnn_fb_reduce_c)
MIDOUT_DECL(megdnn_fb_reduce_dtype)
MIDOUT_DECL(megdnn_fallback_reduce_optimized)

namespace {

using namespace megdnn;

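// Reduce a contiguous (A, B) view along B when C == 1. The lambda keeps
// splitting [bl, br) in half while it is longer than 4096 elements and only
// accumulates serially on the leaves; this pairwise scheme keeps accumulation
// chains short, which (presumably the intent here) limits floating-point
// rounding error for SUM/MEAN-style reductions.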
template <typename Op>
void reduce_exec_C1(size_t A, size_t B, Op op) MEGDNN_NOEXCEPT {
    using wtype = typename Op::wtype;
    rep(a, A) {
        std::function<wtype(size_t, size_t)> func;
        func = [&func, B, &op, a](size_t bl, size_t br) -> wtype {
            if (bl + 4096 < br) {
                size_t mid = bl + (br - bl) / 2;
                return op.apply(func(bl, mid), func(mid, br));
            } else {
                wtype res = op.INIT;
                for (size_t b = bl; b < br; ++b) {
                    res = op.apply(res, op.read(a * B + b));
                }
                return res;
            }
        };
        wtype res = func(0, B);
        op.write(a, res);
    }
}

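// General case: the source is viewed as an (A, B, C) tensor where B is the
// reduced axis, A collapses the leading axes and C the trailing ones, so
// element (a, b, c) lives at offset a * B * C + b * C + c. The same recursive
// splitting as in reduce_exec_C1 is applied per (a, c) pair.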
template <typename Op>
void reduce_exec(size_t A, size_t B, size_t C, Op op) MEGDNN_NOEXCEPT {
    using wtype = typename Op::wtype;
    rep(a, A) {
        rep(c, C) {
            std::function<wtype(size_t, size_t)> func;
            func = [&func, B, C, &op, a, c](size_t bl, size_t br) -> wtype {
                if (bl + 4096 < br) {
                    size_t mid = bl + (br - bl) / 2;
                    return op.apply(func(bl, mid), func(mid, br));
                } else {
                    wtype res = op.INIT;
                    for (size_t b = bl; b < br; ++b) {
                        res = op.apply(res, op.read(a * B * C + b * C + c));
                    }
                    return res;
                }
            };
            wtype res = func(0, B);
            op.write(a * C + c, res);
        }
    }
}

}  // anonymous namespace

namespace megdnn {
namespace fallback {

size_t ReduceImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(dst);
    if (src.dtype.enumv() == DTypeEnum::Float32 &&
        (param().mode == Mode::MEAN || param().mode == Mode::SUM ||
         param().mode == Mode::SUM_SQR)) {
        size_t A, B, C;
        reduce::get_ABC(src, A, B, C, param().axis);
        if (C == 1) {
            // Take B = 247 as an example to see why these parameters exist.
            size_t _60xT_in_4 = (60 * 3) / 4;  // T = 3
            size_t _60xX_in_4 = 4;             // 0 < X < T, X = 1,2.
            size_t _XXxT_in_4 = 4;
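            // Worked example (a sketch of the intent): with B = 247,
            // _60xT_in_4 = (60 * 3) / 4 = 45, so B / 45 = 5, and 5 + 4 + 4 = 13
            // float slots (52 bytes) are reserved. The extra slots presumably
            // cover the partial sums of the tail blocks that the vectorized
            // C == 1 kernel writes out before the final accumulation.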
            return ((B / _60xT_in_4 + _60xX_in_4 + _XXxT_in_4) * sizeof(float));
        }
    }
    return naive::ReduceForwardImpl::get_workspace_in_bytes(src, dst);
}

void ReduceImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    // Try the hand-optimized kernels first; if none of them applies, fall back
    // to the generic (and ultimately naive) implementation.
    if (!exec_optimized(src, dst, workspace)) {
        return exec_fallback(src, dst, workspace);
    }
}

void ReduceImpl::exec_fallback(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    using namespace reduce;
    using Mode = Param::Mode;
    check_exec(src.layout, dst.layout, workspace.size);
    size_t A, B, C;
    get_ABC(src.layout, A, B, C, param().axis);

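// The macros below form a dispatch cascade: cb_all iterates over every
// computing dtype, cb_by_c picks the C == 1 or general kernel, cb_by_data_type
// selects the working/output types implied by param().data_type, cb_by_dtype
// matches the source dtype and enumerates the reduce modes, and cb_by_op
// finally instantiates the Op and posts the kernel to the CPU dispatcher. Each
// branch sits in its own MIDOUT_BEGIN/END region so unused instantiations can
// be trimmed from the binary.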
#define cb_by_op(src_type, dst_type, _wtype, mode_, Op_, kern_func)                  \
    if (param().mode == mode_) {                                                     \
        typedef DTypeTrait<src_type>::ctype src_ctype;                               \
        typedef DTypeTrait<dst_type>::ctype dst_ctype;                               \
        typedef DTypeTrait<_wtype>::ctype wtype;                                     \
        Op_<src_ctype, dst_ctype, wtype> op(src.get_ref_ptr(), dst.get_ref_ptr(), B); \
        MEGDNN_DISPATCH_CPU_KERN_OPR({ kern_func; });                                \
        return;                                                                      \
    }

#define cb_by_dtype(dtype_, kern_func, type_tuple)                    \
    if (dtype_() == src.layout.dtype) {                               \
        MIDOUT_BEGIN(megdnn_fb_reduce_op, midout_iv(0)) {             \
            cb_by_op(type_tuple, Mode::SUM, SumOp, kern_func);        \
        }                                                             \
        MIDOUT_END();                                                 \
        MIDOUT_BEGIN(megdnn_fb_reduce_op, midout_iv(1)) {             \
            cb_by_op(type_tuple, Mode::SUM_SQR, SumSqrOp, kern_func); \
        }                                                             \
        MIDOUT_END();                                                 \
        MIDOUT_BEGIN(megdnn_fb_reduce_op, midout_iv(2)) {             \
            cb_by_op(type_tuple, Mode::PRODUCT, ProdOp, kern_func);   \
        }                                                             \
        MIDOUT_END();                                                 \
        MIDOUT_BEGIN(megdnn_fb_reduce_op, midout_iv(3)) {             \
            cb_by_op(type_tuple, Mode::MIN, MinOp, kern_func);        \
        }                                                             \
        MIDOUT_END();                                                 \
        MIDOUT_BEGIN(megdnn_fb_reduce_op, midout_iv(4)) {             \
            cb_by_op(type_tuple, Mode::MAX, MaxOp, kern_func);        \
        }                                                             \
        MIDOUT_END();                                                 \
        MIDOUT_BEGIN(megdnn_fb_reduce_op, midout_iv(5)) {             \
            cb_by_op(type_tuple, Mode::MEAN, MeanOp, kern_func);      \
        }                                                             \
        MIDOUT_END();                                                 \
    }

#if !MEGDNN_DISABLE_FLOAT16
#define cb_by_data_type(dtype_, data_type, kern_func)                             \
    if (data_type == DataType::FLOAT_O16xC32) {                                   \
        MIDOUT_BEGIN(megdnn_fb_reduce_dtype, midout_iv(0)){cb_by_dtype(           \
                dtype_, kern_func,                                                \
                dtype_ MEGDNN_COMMA dt_float16 MEGDNN_COMMA float)} MIDOUT_END(); \
    }                                                                             \
    if (data_type == DataType::FLOAT_O32xC32) {                                   \
        MIDOUT_BEGIN(megdnn_fb_reduce_dtype, midout_iv(1)){cb_by_dtype(           \
                dtype_, kern_func,                                                \
                dtype_ MEGDNN_COMMA float MEGDNN_COMMA float)} MIDOUT_END();      \
    }                                                                             \
    if (data_type == DataType::DEFAULT) {                                         \
        MIDOUT_BEGIN(megdnn_fb_reduce_dtype, midout_iv(2)){cb_by_dtype(           \
                dtype_, kern_func,                                                \
                dtype_ MEGDNN_COMMA dtype_ MEGDNN_COMMA dtype_)} MIDOUT_END();    \
    }
#else
#define cb_by_data_type(dtype_, data_type, kern_func)                             \
    if (data_type == DataType::FLOAT_O32xC32) {                                   \
        MIDOUT_BEGIN(megdnn_fb_reduce_dtype, midout_iv(0)){cb_by_dtype(           \
                dtype_, kern_func,                                                \
                dtype_ MEGDNN_COMMA float MEGDNN_COMMA float)} MIDOUT_END();      \
    }                                                                             \
    if (data_type == DataType::DEFAULT) {                                         \
        MIDOUT_BEGIN(megdnn_fb_reduce_dtype, midout_iv(1)){cb_by_dtype(           \
                dtype_, kern_func,                                                \
                dtype_ MEGDNN_COMMA dtype_ MEGDNN_COMMA dtype_)} MIDOUT_END();    \
    }
#endif

#define cb_by_c(dtype_, C)                                                            \
    if (C == 1) {                                                                     \
        MIDOUT_BEGIN(megdnn_fb_reduce_c, midout_iv(0)){cb_by_data_type(               \
                dtype_, param().data_type,                                            \
                reduce_exec_C1(A MEGDNN_COMMA B MEGDNN_COMMA op))} MIDOUT_END();      \
    } else {                                                                          \
        MIDOUT_BEGIN(megdnn_fb_reduce_c, midout_iv(1)){cb_by_data_type(               \
                dtype_, param().data_type,                                            \
                reduce_exec(A MEGDNN_COMMA B MEGDNN_COMMA C MEGDNN_COMMA              \
                            op))} MIDOUT_END();                                       \
    }

#define cb_all(dtype_) cb_by_c(dtype_, C)

    MEGDNN_FOREACH_COMPUTING_DTYPE(cb_all);

#undef cb_all
#undef cb_by_c
#undef cb_by_data_type
#undef cb_by_op

    naive::ReduceForwardImpl::exec(src, dst, workspace);
}

bool ReduceImpl::exec_optimized(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    size_t A, B, C;
    reduce::get_ABC(src.layout, A, B, C, param().axis);
    bool execed = false;
    using Mode = param::Reduce::Mode;

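// DISPATCH_FUNC instantiates the vectorized reducers from reducer.h. For
// C == 1 the reduced axis is innermost and contiguous, and very short rows
// (B == 2, 3, 4) are routed to dedicated ExecC1SmallB kernels; the trailing
// bool template parameter appears to distinguish the C == 1 specialization
// from the general one.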
#define DISPATCH_FUNC(Reducer, dtype, ctype, comp_type)                          \
    if (C == 1) {                                                                \
        using _Reducer = Reducer<dtype, ctype, comp_type, true>;                 \
        using _ReducerC1SmallB = Reducer<dtype, ctype, comp_type, false>;        \
        std::function<void(                                                      \
                const ctype*, ctype*, DType, size_t, size_t, size_t,             \
                _megdnn_workspace)>                                              \
                do_reduce = Exec<_Reducer, true>::do_reduce;                     \
        if (B == 2)                                                              \
            do_reduce = ExecC1SmallB<_ReducerC1SmallB, ctype, 2>::do_reduce;     \
        if (B == 3)                                                              \
            do_reduce = ExecC1SmallB<_ReducerC1SmallB, ctype, 3>::do_reduce;     \
        if (B == 4)                                                              \
            do_reduce = ExecC1SmallB<_ReducerC1SmallB, ctype, 4>::do_reduce;     \
        MIDOUT_BEGIN(                                                            \
                megdnn_fallback_reduce_optimized, ctype, dtype, comp_type,       \
                midout_iv(0)) {                                                  \
            MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce(                              \
                    reinterpret_cast<ctype*>(src.raw_ptr()),                     \
                    reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C,  \
                    workspace));                                                 \
            execed = true;                                                       \
        }                                                                        \
        MIDOUT_END();                                                            \
    } else {                                                                     \
        using _Reducer = Reducer<dtype, ctype, comp_type, false>;                \
        std::function<void(                                                      \
                const ctype*, ctype*, DType, size_t, size_t, size_t,             \
                _megdnn_workspace)>                                              \
                do_reduce = Exec<_Reducer, false>::do_reduce;                    \
        MIDOUT_BEGIN(                                                            \
                megdnn_fallback_reduce_optimized, ctype, dtype, comp_type,       \
                midout_iv(1)) {                                                  \
            MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce(                              \
                    reinterpret_cast<ctype*>(src.raw_ptr()),                     \
                    reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C,  \
                    workspace));                                                 \
            execed = true;                                                       \
        }                                                                        \
        MIDOUT_END();                                                            \
    }

#define DISPATCH_MODE_QUANTIZED(dtype, ctype, comp_type)         \
    switch (param().mode) {                                      \
        case Mode::MEAN:                                         \
            DISPATCH_FUNC(MeanReducer, dtype, ctype, comp_type); \
            break;                                               \
        case Mode::MAX:                                          \
            DISPATCH_FUNC(maxReducer, dtype, ctype, ctype);      \
            break;                                               \
        case Mode::MIN:                                          \
            DISPATCH_FUNC(minReducer, dtype, ctype, ctype);      \
            break;                                               \
        default:                                                 \
            break;                                               \
    }

#define DISPATCH_MODE_FLOAT(dtype, ctype, comp_type)             \
    switch (param().mode) {                                      \
        case Mode::MEAN:                                         \
            DISPATCH_FUNC(MeanReducer, dtype, ctype, comp_type); \
            break;                                               \
        case Mode::MAX:                                          \
            DISPATCH_FUNC(maxReducer, dtype, ctype, ctype);      \
            break;                                               \
        case Mode::MIN:                                          \
            DISPATCH_FUNC(minReducer, dtype, ctype, ctype);      \
            break;                                               \
        case Mode::SUM:                                          \
            DISPATCH_FUNC(SumReducer, dtype, ctype, ctype);      \
            break;                                               \
        case Mode::SUM_SQR:                                      \
            DISPATCH_FUNC(SumSqrReducer, dtype, ctype, ctype);   \
            break;                                               \
        case Mode::PRODUCT:                                      \
            DISPATCH_FUNC(ProductReducer, dtype, ctype, ctype);  \
            break;                                               \
        default:                                                 \
            break;                                               \
    }

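    // The optimized path only covers contiguous QuantizedS8 and Float32 inputs
    // with DataType::DEFAULT; anything else leaves execed == false, so the
    // caller falls back to exec_fallback / the naive implementation.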
    if (src.layout.is_contiguous() &&
        src.layout.dtype.category() == DTypeCategory::QUANTIZED &&
        param().data_type == param::Reduce::DataType::DEFAULT) {
        DType src_type = src.layout.dtype;
        if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
            DISPATCH_MODE_QUANTIZED(dt_qint8, int8_t, int32_t)
        }
    } else if (
            src.layout.is_contiguous() &&
            src.layout.dtype.category() == DTypeCategory::FLOAT &&
            param().data_type == param::Reduce::DataType::DEFAULT) {
        DType src_type = src.layout.dtype;
        if (src.layout.dtype.enumv() == DTypeEnum::Float32) {
            DISPATCH_MODE_FLOAT(dt_float32, float, float)
        }
    }
    return execed;
#undef DISPATCH_FUNC
#undef DISPATCH_MODE_QUANTIZED
#undef DISPATCH_MODE_FLOAT
}

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen