
conv_bias.cpp 16 kB

/**
 * \file dnn/src/common/conv_bias.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/common/conv_bias.h"
#include "src/common/utils.h"
#include "src/common/opr_delegate.h"

namespace megdnn {
namespace {
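//! Shared validity checks for ConvBiasForward::check_exec and its
//! noncontiguous variant: verifies src/filter dtype compatibility, the bias
//! scale for quantized dtypes, the workspace size, and that the bias and z
//! layouts match dst (or a per-channel broadcast of it) for the given format.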
void do_check_exec_common(
        ConvBiasForward* opr, const TensorLayout& src,
        const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst,
        size_t workspace_in_bytes,
        const ConvBiasForward::PreprocessedFilter* preprocessed_filter) {
    megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) ||
                  (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
                   filter.dtype.enumv() == DTypeEnum::QuantizedS4));
    // check compatibility of bias's scale
    if (src.dtype.category() == DTypeCategory::QUANTIZED) {
        if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
            float scale_expected = mul_scale(src.dtype, filter.dtype);
            float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
            megdnn_assert(std::abs(scale_expected - scale_bias) < 1e-6,
                          "scale_src: %f scale_filter: %f scale_bias: %f",
                          get_scale(src.dtype), get_scale(filter.dtype),
                          scale_bias);
        } else {
            megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
        }
    }
    megdnn_assert_contiguous(bias);
    auto required_workspace_in_bytes = opr->get_workspace_in_bytes(
            src, filter, bias, z, dst, preprocessed_filter);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes,
                  "workspace has size %zu, but %zu is required",
                  workspace_in_bytes, required_workspace_in_bytes);
    if (bias.ndim != 0) {
        //! bias.layout == dst.layout failed, no assert information
        auto check_eq = [](const TensorLayout& bias, const TensorLayout& dst) {
            if (dst.dtype.category() == DTypeCategory::QUANTIZED) {
                return bias.eq_shape(dst);
            } else {
                return bias.eq_layout(dst);
            }
        };
        if (check_eq(bias, dst)) {
            return;
        }
        if (opr->param().format == param::ConvBias::Format::NCHW ||
            opr->param().format == param::ConvBias::Format::NCHW4_NCHW) {
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == 1);
        } else if (opr->param().format == param::ConvBias::Format::NHWC ||
                   opr->param().format ==
                           param::ConvBias::Format::NCHW4_NHWC) {
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == 1);
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
        } else if (opr->param().format == param::ConvBias::Format::NCHW4 ||
                   opr->param().format == param::ConvBias::Format::NCHW44 ||
                   opr->param().format ==
                           param::ConvBias::Format::NCHW44_DOT ||
                   opr->param().format ==
                           param::ConvBias::Format::NCHW32_NCHW4) {
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == 1);
            megdnn_assert(bias.shape[4] == 4);
        } else if (opr->param().format == param::ConvBias::Format::NCHW8 ||
                   opr->param().format == param::ConvBias::Format::NCHW88) {
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == 1);
            megdnn_assert(bias.shape[4] == 8);
        } else if (opr->param().format == param::ConvBias::Format::NCHW32 ||
                   opr->param().format ==
                           param::ConvBias::Format::NCHW4_NCHW32) {
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == 1);
            megdnn_assert(bias.shape[4] == 32);
        } else if (opr->param().format == param::ConvBias::Format::CHWN4) {
            megdnn_assert(bias.shape[0] == dst.shape[0], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[1] == 1);
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == 1);
            megdnn_assert(bias.shape[4] == 4);
        } else if (opr->param().format == param::ConvBias::Format::NCHW64) {
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[2] == 1);
            megdnn_assert(bias.shape[3] == 1);
            megdnn_assert(bias.shape[4] == 64);
        } else {
            megdnn_assert(opr->param().format ==
                          param::ConvBias::Format::NHWCD4);
            megdnn_assert(bias.shape[0] == 1);
            megdnn_assert(bias.shape[1] == 1);
            megdnn_assert(bias.shape[2] == dst.shape[2], "bias:%s, dst:%s",
                          bias.to_string().c_str(), dst.to_string().c_str());
            megdnn_assert(bias.shape[3] == 1);
            megdnn_assert(bias.shape[4] == 4);
        }
    }
    if (z.ndim != 0) {
        megdnn_assert(opr->param().format !=
                      param::ConvBias::Format::NCHW4_NCHW32);
        megdnn_assert(opr->param().format !=
                      param::ConvBias::Format::NCHW32_NCHW4);
        megdnn_assert(z.dtype.enumv() == dst.dtype.enumv());
        megdnn_assert(z.eq_shape(dst));
    }
}
}  // namespace

void ConvBiasForward::deduce_dtype(DType src, DType filter, DType /* bias */,
                                   DType /* z */, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}

void ConvBiasForward::deduce_layout(const TensorLayout& src,
                                    const TensorLayout& filter,
                                    const TensorLayout& /* bias */,
                                    const TensorLayout& /* z */,
                                    TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}
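//! Runs the common operand checks, then validates the full forward layout
//! (src, filter, dst) and returns the canonized filter meta.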
ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst, size_t workspace_in_bytes,
        const PreprocessedFilter* preprocessed_filter) {
    do_check_exec_common(this, src, filter, bias, z, dst, workspace_in_bytes,
                         preprocessed_filter);
    auto ret = check_layout_fwd(src, filter, dst);
    return ret;
}
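//! Same common checks as check_exec, but dst is only required to match the
//! shape deduced from src and filter, so a noncontiguous dst layout is
//! accepted.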
ConvBiasForward::CanonizedFilterMeta
ConvBiasForward::check_exec_allow_noncontiguous(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst, size_t workspace_in_bytes,
        const PreprocessedFilter* preprocessed_filter) {
    do_check_exec_common(this, src, filter, bias, z, dst, workspace_in_bytes,
                         preprocessed_filter);
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;
    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_shape(dst_expected, dst);
    return ret;
}
template <typename T>
struct NCHWParamTrait;
template <typename T>
struct NCHW44ParamTrait;

std::string ConvBias::WinogradParam::to_string() const {
    return ssprintf("%u:%u:%u", channel_block_size, output_block_size,
                    tile_size);
}
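//! Builds the canonical algorithm name "<category>:<base>:<param>", where
//! the category is taken from the param trait for the given format (NCHW or
//! NCHW44); any other format throws.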
template <typename T>
std::string ConvBias::algo_name(const std::string& base, const T& p,
                                param::ConvBias::Format format) {
    if (format == param::ConvBias::Format::NCHW) {
        return ssprintf("%s:%s:%s", NCHWParamTrait<T>::category.c_str(),
                        base.c_str(), p.to_string().c_str());
    } else if (format == param::ConvBias::Format::NCHW44) {
        return ssprintf("%s:%s:%s", NCHW44ParamTrait<T>::category.c_str(),
                        base.c_str(), p.to_string().c_str());
    }
    megdnn_throw("Invalid format");
    return "";
}
#define FOREACH_CONV_BIAS_PARAM(cb) \
    cb(WinogradParam) cb(DirectParam) cb(MatmulParam) cb(DefaultParam)

#define cb(pt)                                  \
    template <>                                 \
    struct NCHWParamTrait<ConvBias::pt> {       \
        static const std::string category;      \
    };                                          \
    template <>                                 \
    struct NCHW44ParamTrait<ConvBias::pt> {     \
        static const std::string category;      \
    };
FOREACH_CONV_BIAS_PARAM(cb)
#undef cb

#define cb(pt, ct)                                                 \
    const std::string NCHWParamTrait<ConvBias::pt>::category = ct; \
    const std::string NCHW44ParamTrait<ConvBias::pt>::category = ct
cb(DirectParam, "DIRECT");
cb(MatmulParam, "MATMUL");
cb(DefaultParam, "DEFAULT");
#undef cb

const std::string NCHWParamTrait<ConvBias::WinogradParam>::category =
        "WINOGRAD";
const std::string NCHW44ParamTrait<ConvBias::WinogradParam>::category =
        "WINOGRAD_NCHW44";

#define cb(t)                                               \
    template std::string ConvBias::algo_name<ConvBias::t>( \
            const std::string& base, const ConvBias::t& p, \
            param::ConvBias::Format format);
FOREACH_CONV_BIAS_PARAM(cb)
#undef cb
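//! Parses an algorithm name of the form
//! "WINOGRAD[_NCHW44]:<base>:<channel_block>:<output_block>:<tile>" back
//! into a WinogradParam; returns INVALID_WINOGRAD_PARAM if the prefix does
//! not match or any parsed field is zero.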
ConvBias::WinogradParam ConvBias::parse_winograd_name(
        const std::string& algo_name) {
    ConvBias::WinogradParam ret = INVALID_WINOGRAD_PARAM;
    char base[128];
    char name[128];
    auto parse = [&](const std::string& algo_name,
                     const std::string& pre) -> auto {
        memset(name, 0, 128);
        sscanf(algo_name.c_str(), "%[^:]:%[^:]:%u:%u:%u", name, base,
               &(ret.channel_block_size), &(ret.output_block_size),
               &(ret.tile_size));
        if (strcmp(name, pre.c_str())) {
            ret = INVALID_WINOGRAD_PARAM;
            return false;
        }
        if (ret.tile_size == 0 || ret.output_block_size == 0 ||
            ret.channel_block_size == 0) {
            ret = INVALID_WINOGRAD_PARAM;
            return false;
        }
        return true;
    };
    if (parse(algo_name, "WINOGRAD_NCHW44")) {
        return ret;
    } else {
        parse(algo_name, "WINOGRAD");
        return ret;
    }
}
constexpr ConvBias::WinogradParam ConvBias::INVALID_WINOGRAD_PARAM;
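//! Applies the bias add and the nonlinearity selected in args to the raw
//! convolution output, fusing both into a single elemwise op where possible;
//! quantized tensors go through ElemwiseMultiType, float tensors through
//! Elemwise, and IDENTITY without bias falls back to a TypeCvt if needed.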
void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
                               const TensorND* conv_dst_tensor,
                               const TensorND* dst_tensor,
                               const TensorND* bias_tensor) {
    using NonlineMode = param::ConvBias::NonlineMode;
    switch (args.nonlineMode) {
#define cb(_mode)                                                          \
    case NonlineMode::_mode: {                                             \
        if (conv_dst_tensor->layout.dtype.category() !=                    \
            DTypeCategory::QUANTIZED) {                                    \
            auto nonlinear = handle->create_operator<ElemwiseForward>();   \
            if (bias_tensor->layout.ndim > 0) {                            \
                nonlinear->param().mode =                                  \
                        Elemwise::Param::Mode::FUSE_ADD_##_mode;           \
                nonlinear->exec({*conv_dst_tensor, *bias_tensor},          \
                                *dst_tensor);                              \
            } else {                                                       \
                nonlinear->param().mode = Elemwise::Param::Mode::_mode;    \
                nonlinear->exec({*conv_dst_tensor}, *dst_tensor);          \
            }                                                              \
        } else {                                                           \
            auto nonlinear = handle->create_operator<ElemwiseMultiType>(); \
            if (bias_tensor->layout.ndim > 0) {                            \
                nonlinear->param().mode =                                  \
                        ElemwiseMultiType::Param::Mode::QFUSE_ADD_##_mode; \
                nonlinear->exec({*conv_dst_tensor, *bias_tensor},          \
                                *dst_tensor);                              \
            } else {                                                       \
                nonlinear->param().mode =                                  \
                        ElemwiseMultiType::Param::Mode::Q##_mode;          \
                nonlinear->exec({*conv_dst_tensor}, *dst_tensor);          \
            }                                                              \
        }                                                                  \
        break;                                                             \
    }
        cb(RELU);
        cb(H_SWISH);
#undef cb
        case NonlineMode::SIGMOID: {
            megdnn_assert(conv_dst_tensor->layout.dtype.category() !=
                          DTypeCategory::QUANTIZED);
            auto nonlinear = handle->create_operator<ElemwiseForward>();
            if (bias_tensor->layout.ndim > 0) {
                nonlinear->param().mode =
                        Elemwise::Param::Mode::FUSE_ADD_SIGMOID;
                nonlinear->exec({*conv_dst_tensor, *bias_tensor},
                                *conv_dst_tensor);
            } else {
                nonlinear->param().mode = Elemwise::Param::Mode::SIGMOID;
                nonlinear->exec({*conv_dst_tensor}, *conv_dst_tensor);
            }
            break;
        }
        case NonlineMode::IDENTITY: {
            if (bias_tensor->layout.ndim > 0) {
                if (dst_tensor->layout.dtype.category() ==
                    DTypeCategory::QUANTIZED) {
                    auto nonlinear =
                            handle->create_operator<ElemwiseMultiType>();
                    nonlinear->param().mode =
                            ElemwiseMultiType::Param::Mode::QADD;
                    nonlinear->exec({*conv_dst_tensor, *bias_tensor},
                                    *dst_tensor);
                } else {
                    auto nonlinear = handle->create_operator<Elemwise>();
                    nonlinear->param().mode = Elemwise::Param::Mode::ADD;
                    nonlinear->exec({*conv_dst_tensor, *bias_tensor},
                                    *dst_tensor);
                }
            } else {
                if (conv_dst_tensor->layout.dtype !=
                    dst_tensor->layout.dtype) {
                    handle->create_operator<TypeCvt>()->exec(
                            {*conv_dst_tensor}, *dst_tensor);
                }
            }
            break;
        }
        default:
            megdnn_assert(false);
    }
}
}  // namespace megdnn

// vim: syntax=cpp.doxygen
