You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cudnn_wrapper.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. /**
  2. * \file dnn/src/cuda/cudnn_wrapper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/cudnn_wrapper.h"
  12. #include "src/common/utils.h"
  13. #include "src/cuda/utils.h"
  14. namespace {
  15. using namespace megdnn;
  16. cudnnDataType_t to_cudnn_dtype(DType type,
  17. const param::Convolution::Format format = {}) {
  18. switch (type.enumv()) {
  19. case DTypeEnum::Float32:
  20. return CUDNN_DATA_FLOAT;
  21. case DTypeEnum::Float16:
  22. return CUDNN_DATA_HALF;
  23. #if CUDNN_MAJOR >= 7
  24. case DTypeEnum::Int32:
  25. case DTypeEnum::QuantizedS32:
  26. return CUDNN_DATA_INT32;
  27. #endif
  28. #if CUDNN_MAJOR >= 6
  29. case DTypeEnum::QuantizedS8: {
  30. if (format == param::Convolution::Format::NCHW4)
  31. return CUDNN_DATA_INT8x4;
  32. #if CUDNN_VERSION >= 7500
  33. else if (format == param::Convolution::Format::NCHW32)
  34. return CUDNN_DATA_INT8x32;
  35. #endif
  36. else
  37. return CUDNN_DATA_INT8;
  38. }
  39. case DTypeEnum::Int8: {
  40. if (format == param::Convolution::Format::NCHW4)
  41. return CUDNN_DATA_INT8x4;
  42. #if CUDNN_VERSION >= 7500
  43. else if (format == param::Convolution::Format::NCHW32)
  44. return CUDNN_DATA_INT8x32;
  45. #endif
  46. else
  47. return CUDNN_DATA_INT8;
  48. }
  49. #endif
  50. default:
  51. #if CUDNN_MAJOR >= 6
  52. megdnn_throw(megdnn_mangle("dtype must be float16/float32/int8/int32"));
  53. #else
  54. megdnn_throw(megdnn_mangle("dtype must be float16/float32"));
  55. #endif
  56. }
  57. }
  58. cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) {
  59. switch (format) {
  60. case param::Convolution::Format::NCHW:
  61. return CUDNN_TENSOR_NCHW;
  62. #if CUDNN_MAJOR >= 7
  63. case param::Convolution::Format::NCHW4:
  64. case param::Convolution::Format::NCHW32:
  65. return CUDNN_TENSOR_NCHW_VECT_C;
  66. #endif
  67. case param::Convolution::Format::NHWC:
  68. return CUDNN_TENSOR_NHWC;
  69. default:
  70. megdnn_assert_internal(0);
  71. }
  72. }
  73. } // namespace
  74. namespace megdnn {
  75. namespace cuda {
  76. cudnnDataType_t get_compute_type_fp16(
  77. param::Convolution::ComputeMode comp_mode) {
  78. using Param = param::Convolution;
  79. cudnnDataType_t compute_type;
  80. if (comp_mode == Param::ComputeMode::DEFAULT) {
  81. // TRUE_HALF_CONFIG
  82. if (is_compute_capability_required(5, 3)) {
  83. compute_type = CUDNN_DATA_HALF;
  84. } else {
  85. auto&& device_prop = current_device_prop();
  86. int major = device_prop.major, minor = device_prop.minor;
  87. MEGDNN_MARK_USED_VAR(major);
  88. MEGDNN_MARK_USED_VAR(minor);
  89. megdnn_log_warn(
  90. "TRUE_HALF_CONFIG only supported on architectures with "
  91. "true fp16 support, i.e., compute capability 5.3 and "
  92. "later (got %d.%d). Use PSEUDO_HALF_CONFIG instead",
  93. major, minor);
  94. compute_type = CUDNN_DATA_FLOAT;
  95. }
  96. } else {
  97. megdnn_assert(comp_mode == Param::ComputeMode::FLOAT32);
  98. // PSEUDO_HALF_CONFIG
  99. compute_type = CUDNN_DATA_FLOAT;
  100. }
  101. return compute_type;
  102. }
  103. TensorDesc::TensorDesc() {
  104. cudnn_check(cudnnCreateTensorDescriptor(&desc));
  105. }
  106. TensorDesc::~TensorDesc() {
  107. cudnn_check(cudnnDestroyTensorDescriptor(desc));
  108. }
  109. void TensorDesc::set(const TensorLayout& layout,
  110. const param::Convolution::Format format) {
  111. // Layout can be not contiguous; group conv needs it.
  112. // megdnn_assert_contiguous(layout);
  113. if (format == param::Convolution::Format::NCHW4 ||
  114. format == param::Convolution::Format::NCHW32)
  115. megdnn_assert_eq_size_t(layout.ndim, 5_z);
  116. else
  117. megdnn_assert_eq_size_t(layout.ndim, 4_z);
  118. size_t c_pos, spatial_pos;
  119. if (format == param::Convolution::Format::NCHW ||
  120. format == param::Convolution::Format::NCHW4 ||
  121. format == param::Convolution::Format::NCHW32) {
  122. c_pos = 1;
  123. spatial_pos = 2;
  124. } else {
  125. megdnn_assert(format == param::Convolution::Format::NHWC);
  126. c_pos = 3;
  127. spatial_pos = 1;
  128. }
  129. if (format == param::Convolution::Format::NCHW4) {
  130. megdnn_assert(layout.is_physical_contiguous());
  131. cudnn_check(cudnnSetTensor4dDescriptor(
  132. desc, to_cudnn_format(format),
  133. to_cudnn_dtype(layout.dtype, format), layout.shape[0],
  134. layout.shape[c_pos] * 4, layout.shape[spatial_pos + 0],
  135. layout.shape[spatial_pos + 1]));
  136. } else if (format == param::Convolution::Format::NCHW32) {
  137. megdnn_assert(layout.is_physical_contiguous());
  138. cudnn_check(cudnnSetTensor4dDescriptor(
  139. desc, to_cudnn_format(format),
  140. to_cudnn_dtype(layout.dtype, format), layout.shape[0],
  141. layout.shape[c_pos] * 32, layout.shape[spatial_pos + 0],
  142. layout.shape[spatial_pos + 1]));
  143. } else {
  144. cudnn_check(cudnnSetTensor4dDescriptorEx(
  145. desc, to_cudnn_dtype(layout.dtype), layout.shape[0],
  146. layout.shape[c_pos], layout.shape[spatial_pos + 0],
  147. layout.shape[spatial_pos + 1], layout.stride[0],
  148. layout.stride[c_pos], layout.stride[spatial_pos + 0],
  149. layout.stride[spatial_pos + 1]));
  150. }
  151. }
  152. template <typename Param>
  153. FilterDesc<Param>::FilterDesc() {
  154. cudnn_check(cudnnCreateFilterDescriptor(&desc));
  155. }
  156. template <typename Param>
  157. FilterDesc<Param>::~FilterDesc() {
  158. cudnn_check(cudnnDestroyFilterDescriptor(desc));
  159. }
  160. template <typename Param>
  161. void FilterDesc<Param>::set(
  162. const typename ConvolutionBase<Param>::CanonizedFilterMeta&
  163. filter_meta) {
  164. megdnn_assert(filter_meta.spatial_ndim == 2);
  165. #if CUDNN_VERSION < 7500
  166. megdnn_assert(filter_meta.dilation[0] == 1 && filter_meta.dilation[1] == 1);
  167. #endif
  168. #if CUDNN_MAJOR <= 6
  169. megdnn_assert(filter_meta.group == 1);
  170. #endif
  171. // cuDNN version 6 or below filter_meta.group always is 1.
  172. // So it is compatible for all cuDNN versions.
  173. cudnn_check(cudnnSetFilter4dDescriptor(
  174. desc, to_cudnn_dtype(filter_meta.dtype, filter_meta.format),
  175. to_cudnn_format(filter_meta.format),
  176. filter_meta.ocpg * filter_meta.group, // cudnn 6 group always be 1
  177. filter_meta.icpg, filter_meta.spatial[0], filter_meta.spatial[1]));
  178. }
  179. template class FilterDesc<param::Convolution>;
  180. template class FilterDesc<param::ConvBias>;
  181. ConvDesc::ConvDesc() {
  182. cudnn_check(cudnnCreateConvolutionDescriptor(&desc));
  183. #if CUDNN_VERSION >= 7000
  184. // cudnn enables tensor core when tensors have dataType =
  185. // CUDNN_DATA_HALF, so it should be safe to enable globally
  186. cudnn_check(cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH));
  187. #endif
  188. }
  189. ConvDesc::~ConvDesc() {
  190. cudnn_check(cudnnDestroyConvolutionDescriptor(desc));
  191. }
  192. void ConvDesc::set(DType data_type, const param::Convolution& param,
  193. const size_t nr_group) {
  194. using Param = param::Convolution;
  195. cudnnConvolutionMode_t mode;
  196. switch (param.mode) {
  197. case Param::Mode::CROSS_CORRELATION:
  198. mode = CUDNN_CROSS_CORRELATION;
  199. break;
  200. case Param::Mode::CONVOLUTION:
  201. mode = CUDNN_CONVOLUTION;
  202. break;
  203. default:
  204. megdnn_throw(megdnn_mangle("conv mode must be conv or xcorr."));
  205. }
  206. cudnnDataType_t compute_type;
  207. MEGDNN_MARK_USED_VAR(compute_type);
  208. if (data_type.enumv() == DTypeEnum::Float32) {
  209. // FLOAT_CONFIG
  210. compute_type = CUDNN_DATA_FLOAT;
  211. } else if (data_type.enumv() == DTypeEnum::Float16) {
  212. auto comp_mode = param.compute_mode;
  213. compute_type = get_compute_type_fp16(comp_mode);
  214. #if CUDNN_MAJOR >= 7
  215. } else if (data_type.category() == DTypeCategory::INT ||
  216. data_type.category() == DTypeCategory::QUANTIZED) {
  217. compute_type = CUDNN_DATA_INT32;
  218. #endif
  219. } else {
  220. megdnn_throw(megdnn_mangle("unspport data type for conv bias"));
  221. }
  222. #if CUDNN_MAJOR >= 7
  223. cudnn_check(cudnnSetConvolutionGroupCount(desc, nr_group));
  224. #else
  225. megdnn_assert(nr_group == 1);
  226. #endif
  227. #if CUDNN_MAJOR >= 6
  228. cudnn_check(cudnnSetConvolution2dDescriptor(
  229. desc, param.pad_h, param.pad_w, param.stride_h, param.stride_w,
  230. param.dilate_h, param.dilate_w, mode, compute_type));
  231. #else
  232. cudnn_check(cudnnSetConvolution2dDescriptor(
  233. desc, param.pad_h, param.pad_w, param.stride_h, param.stride_w,
  234. param.dilate_h, param.dilate_w, mode));
  235. #endif
  236. }
  237. PoolingDesc::PoolingDesc() {
  238. cudnn_check(cudnnCreatePoolingDescriptor(&desc));
  239. }
  240. PoolingDesc::~PoolingDesc() {
  241. cudnn_check(cudnnDestroyPoolingDescriptor(desc));
  242. }
  243. void PoolingDesc::set(const param::Pooling& param) {
  244. cudnnPoolingMode_t mode;
  245. switch (param.mode) {
  246. case param::Pooling::Mode::MAX:
  247. mode = CUDNN_POOLING_MAX;
  248. break;
  249. case param::Pooling::Mode::AVERAGE:
  250. mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
  251. break;
  252. case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
  253. mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
  254. break;
  255. }
  256. cudnn_check(cudnnSetPooling2dDescriptor(
  257. desc, mode, CUDNN_NOT_PROPAGATE_NAN, param.window_h, param.window_w,
  258. param.pad_h, param.pad_w, param.stride_h, param.stride_w));
  259. }
  260. LRNDesc::LRNDesc() {
  261. cudnn_check(cudnnCreateLRNDescriptor(&desc));
  262. }
  263. LRNDesc::~LRNDesc() {
  264. cudnn_check(cudnnDestroyLRNDescriptor(desc));
  265. }
  266. void LRNDesc::set(const param::LRN& param) {
  267. megdnn_assert(param.n & 1, "n is %u", param.n);
  268. megdnn_assert(param.n >= CUDNN_LRN_MIN_N, "n is %u, CUDNN_LRN_MIN_N is %d",
  269. param.n, CUDNN_LRN_MIN_N);
  270. megdnn_assert(param.n <= CUDNN_LRN_MAX_N, "n is %u, CUDNN_LRN_MAX_N is %d",
  271. param.n, CUDNN_LRN_MAX_N);
  272. megdnn_assert(param.k >= CUDNN_LRN_MIN_K, "k is %f, CUDNN_LRN_MIN_K is %lf",
  273. param.k, CUDNN_LRN_MIN_K);
  274. megdnn_assert(param.beta >= CUDNN_LRN_MIN_BETA,
  275. "beta is %f, CUDNN_LRN_MIN_BETA is %lf", param.beta,
  276. CUDNN_LRN_MIN_BETA);
  277. // Note that alpha is divided by n in the cudnn implementation,
  278. // so we have to multiply alpha by n ahead of time.
  279. cudnn_check(cudnnSetLRNDescriptor(desc, param.n, param.alpha * param.n,
  280. param.beta, param.k));
  281. }
  282. BNParamDesc::BNParamDesc() {
  283. cudnn_check(cudnnCreateTensorDescriptor(&desc));
  284. }
  285. void BNParamDesc::set(const cudnnTensorDescriptor_t xDesc,
  286. cudnnBatchNormMode_t mode) {
  287. cudnn_check(cudnnDeriveBNTensorDescriptor(desc, xDesc, mode));
  288. }
  289. BNParamDesc::~BNParamDesc() {
  290. cudnn_check(cudnnDestroyTensorDescriptor(desc));
  291. }
  292. Tensor3DDesc::Tensor3DDesc() {
  293. cudnn_check(cudnnCreateTensorDescriptor(&desc));
  294. }
  295. Tensor3DDesc::~Tensor3DDesc() {
  296. cudnn_check(cudnnDestroyTensorDescriptor(desc));
  297. }
  //! Narrow a size_t to int with a plain static_cast. No range check is done,
  //! so values that do not fit in int are converted as static_cast dictates.
  int sc(const size_t x) {
      const int narrowed = static_cast<int>(x);
      return narrowed;
  }
  301. void Tensor3DDesc::set(const TensorLayout& layout, bool is_ndhwc) {
  302. megdnn_assert_eq_size_t(layout.ndim, 5_z);
  303. size_t c_pos, spatial_pos;
  304. if (is_ndhwc) {
  305. c_pos = 4;
  306. spatial_pos = 1;
  307. } else { // ncdhw
  308. c_pos = 1;
  309. spatial_pos = 2;
  310. }
  311. const int dimA[] = {sc(layout.shape[0]), sc(layout.shape[c_pos]),
  312. sc(layout.shape[spatial_pos + 0]),
  313. sc(layout.shape[spatial_pos + 1]),
  314. sc(layout.shape[spatial_pos + 2])};
  315. const int strideA[] = {sc(layout.stride[0]), sc(layout.stride[c_pos]),
  316. sc(layout.stride[spatial_pos + 0]),
  317. sc(layout.stride[spatial_pos + 1]),
  318. sc(layout.stride[spatial_pos + 2])};
  319. cudnn_check(cudnnSetTensorNdDescriptor(desc, to_cudnn_dtype(layout.dtype),
  320. 5, dimA, strideA));
  321. }
  322. Filter3DDesc::Filter3DDesc() {
  323. cudnn_check(cudnnCreateFilterDescriptor(&desc));
  324. }
  325. Filter3DDesc::~Filter3DDesc() {
  326. cudnn_check(cudnnDestroyFilterDescriptor(desc));
  327. }
  328. void Filter3DDesc::set(
  329. const Convolution3DBase::CanonizedFilterMeta& filter_meta) {
  330. megdnn_assert(filter_meta.spatial_ndim == 3);
  331. #if CUDNN_MAJOR <= 6
  332. megdnn_assert(filter_meta.group == 1);
  333. #endif
  334. // cuDNN version 6 or below filter_meta.group always is 1.
  335. // So it is compatible for all cuDNN versions.
  336. const int filterDimA[] = {
  337. sc(filter_meta.ocpg *
  338. filter_meta.group), // cudnn 6 group always be 1
  339. sc(filter_meta.icpg), sc(filter_meta.spatial[0]),
  340. sc(filter_meta.spatial[1]), sc(filter_meta.spatial[2])};
  341. cudnn_check(cudnnSetFilterNdDescriptor(
  342. desc, to_cudnn_dtype(DType::from_enum(filter_meta.dtype_enum)),
  343. CUDNN_TENSOR_NCHW, 5, filterDimA));
  344. }
  345. Conv3DDesc::Conv3DDesc() {
  346. cudnn_check(cudnnCreateConvolutionDescriptor(&desc));
  347. #if CUDNN_MAJOR >= 7
  348. // cudnn enables tensor core when tensors have dataType = CUDNN_DATA_HALF,
  349. // so it should be safe to enable globally
  350. cudnn_check(cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH));
  351. #endif
  352. }
  353. Conv3DDesc::~Conv3DDesc() {
  354. cudnn_check(cudnnDestroyConvolutionDescriptor(desc));
  355. }
  356. void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) {
  357. cudnnConvolutionMode_t mode;
  358. switch (param.mode) {
  359. case param::Convolution3D::Mode::CROSS_CORRELATION:
  360. mode = CUDNN_CROSS_CORRELATION;
  361. break;
  362. case param::Convolution3D::Mode::CONVOLUTION:
  363. mode = CUDNN_CONVOLUTION;
  364. break;
  365. default:
  366. megdnn_throw(megdnn_mangle("conv mode must be conv or xcorr."));
  367. }
  368. #if CUDNN_MAJOR >= 7
  369. cudnn_check(cudnnSetConvolutionGroupCount(desc, nr_group));
  370. #else
  371. megdnn_assert(nr_group == 1);
  372. #endif
  373. const int padA[] = {sc(param.pad_d), sc(param.pad_h), sc(param.pad_w)},
  374. filterStrideA[] = {sc(param.stride_d), sc(param.stride_h),
  375. sc(param.stride_w)},
  376. dilationA[] = {sc(param.dilate_d), sc(param.dilate_h),
  377. sc(param.dilate_w)};
  378. // not use true half
  379. // in CUDNN_MAJOR < 6, all elements in dilA shoule be 1
  380. cudnn_check(cudnnSetConvolutionNdDescriptor(
  381. desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT));
  382. }
  383. } // namespace cuda
  384. } // namespace megdnn
  385. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台