You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cudnn_wrapper.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. /**
  2. * \file dnn/src/cuda/cudnn_wrapper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/cudnn_wrapper.h"
  12. #include "src/common/utils.h"
  13. #include "src/cuda/utils.h"
  14. namespace {
  15. using namespace megdnn;
  16. cudnnDataType_t to_cudnn_dtype(
  17. DType type, const param::Convolution::Format format = {}) {
  18. switch (type.enumv()) {
  19. case DTypeEnum::Float32:
  20. return CUDNN_DATA_FLOAT;
  21. case DTypeEnum::Float16:
  22. return CUDNN_DATA_HALF;
  23. #if CUDNN_MAJOR >= 7
  24. case DTypeEnum::Int32:
  25. case DTypeEnum::QuantizedS32:
  26. return CUDNN_DATA_INT32;
  27. #endif
  28. #if CUDNN_MAJOR >= 6
  29. case DTypeEnum::QuantizedS8: {
  30. if (format == param::Convolution::Format::NCHW4)
  31. return CUDNN_DATA_INT8x4;
  32. #if CUDNN_VERSION >= 7500
  33. else if (format == param::Convolution::Format::NCHW32)
  34. return CUDNN_DATA_INT8x32;
  35. #endif
  36. else
  37. return CUDNN_DATA_INT8;
  38. }
  39. case DTypeEnum::Int8: {
  40. if (format == param::Convolution::Format::NCHW4)
  41. return CUDNN_DATA_INT8x4;
  42. #if CUDNN_VERSION >= 7500
  43. else if (format == param::Convolution::Format::NCHW32)
  44. return CUDNN_DATA_INT8x32;
  45. #endif
  46. else
  47. return CUDNN_DATA_INT8;
  48. }
  49. #endif
  50. default:
  51. #if CUDNN_MAJOR >= 6
  52. megdnn_throw("dtype must be float16/float32/int8/int32");
  53. #else
  54. megdnn_throw("dtype must be float16/float32");
  55. #endif
  56. }
  57. }
  58. cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) {
  59. switch (format) {
  60. case param::Convolution::Format::NCHW:
  61. return CUDNN_TENSOR_NCHW;
  62. #if CUDNN_MAJOR >= 7
  63. case param::Convolution::Format::NCHW4:
  64. case param::Convolution::Format::NCHW32:
  65. return CUDNN_TENSOR_NCHW_VECT_C;
  66. #endif
  67. case param::Convolution::Format::NHWC:
  68. return CUDNN_TENSOR_NHWC;
  69. default:
  70. megdnn_assert_internal(0);
  71. }
  72. }
  73. } // namespace
  74. namespace megdnn {
  75. namespace cuda {
  76. cudnnDataType_t get_compute_type_fp16(param::Convolution::ComputeMode comp_mode) {
  77. using Param = param::Convolution;
  78. cudnnDataType_t compute_type;
  79. if (comp_mode == Param::ComputeMode::DEFAULT) {
  80. // TRUE_HALF_CONFIG
  81. if (is_compute_capability_required(5, 3)) {
  82. compute_type = CUDNN_DATA_HALF;
  83. } else {
  84. auto&& device_prop = current_device_prop();
  85. int major = device_prop.major, minor = device_prop.minor;
  86. MEGDNN_MARK_USED_VAR(major);
  87. MEGDNN_MARK_USED_VAR(minor);
  88. megdnn_log_warn(
  89. "TRUE_HALF_CONFIG only supported on architectures with "
  90. "true fp16 support, i.e., compute capability 5.3 and "
  91. "later (got %d.%d). Use PSEUDO_HALF_CONFIG instead",
  92. major, minor);
  93. compute_type = CUDNN_DATA_FLOAT;
  94. }
  95. } else {
  96. megdnn_assert(comp_mode == Param::ComputeMode::FLOAT32);
  97. // PSEUDO_HALF_CONFIG
  98. compute_type = CUDNN_DATA_FLOAT;
  99. }
  100. return compute_type;
  101. }
  102. TensorDesc::TensorDesc() {
  103. cudnn_check(cudnnCreateTensorDescriptor(&desc));
  104. }
  105. TensorDesc::~TensorDesc() {
  106. cudnn_check(cudnnDestroyTensorDescriptor(desc));
  107. }
  108. void TensorDesc::set(
  109. const TensorLayout& layout, const param::Convolution::Format format) {
  110. // Layout can be not contiguous; group conv needs it.
  111. // megdnn_assert_contiguous(layout);
  112. if (format == param::Convolution::Format::NCHW4 ||
  113. format == param::Convolution::Format::NCHW32)
  114. megdnn_assert_eq_size_t(layout.ndim, 5_z);
  115. else
  116. megdnn_assert_eq_size_t(layout.ndim, 4_z);
  117. size_t c_pos, spatial_pos;
  118. if (format == param::Convolution::Format::NCHW ||
  119. format == param::Convolution::Format::NCHW4 ||
  120. format == param::Convolution::Format::NCHW32) {
  121. c_pos = 1;
  122. spatial_pos = 2;
  123. } else {
  124. megdnn_assert(format == param::Convolution::Format::NHWC);
  125. c_pos = 3;
  126. spatial_pos = 1;
  127. }
  128. if (format == param::Convolution::Format::NCHW4) {
  129. megdnn_assert(layout.is_physical_contiguous());
  130. cudnn_check(cudnnSetTensor4dDescriptor(
  131. desc, to_cudnn_format(format), to_cudnn_dtype(layout.dtype, format),
  132. layout.shape[0], layout.shape[c_pos] * 4, layout.shape[spatial_pos + 0],
  133. layout.shape[spatial_pos + 1]));
  134. } else if (format == param::Convolution::Format::NCHW32) {
  135. megdnn_assert(layout.is_physical_contiguous());
  136. cudnn_check(cudnnSetTensor4dDescriptor(
  137. desc, to_cudnn_format(format), to_cudnn_dtype(layout.dtype, format),
  138. layout.shape[0], layout.shape[c_pos] * 32,
  139. layout.shape[spatial_pos + 0], layout.shape[spatial_pos + 1]));
  140. } else {
  141. cudnn_check(cudnnSetTensor4dDescriptorEx(
  142. desc, to_cudnn_dtype(layout.dtype), layout.shape[0],
  143. layout.shape[c_pos], layout.shape[spatial_pos + 0],
  144. layout.shape[spatial_pos + 1], layout.stride[0], layout.stride[c_pos],
  145. layout.stride[spatial_pos + 0], layout.stride[spatial_pos + 1]));
  146. }
  147. }
  148. std::string TensorDesc::to_string() {
  149. cudnnDataType_t data_type;
  150. int n;
  151. int c;
  152. int h;
  153. int w;
  154. int n_stride;
  155. int c_stride;
  156. int h_stride;
  157. int w_stride;
  158. cudnn_check(cudnnGetTensor4dDescriptor(
  159. desc, &data_type, &n, &c, &h, &w, &n_stride, &c_stride, &h_stride,
  160. &w_stride));
  161. return ssprintf(
  162. "<dtype_%d, %d,%d,%d,%d(%d,%d,%d,%d)>", data_type, n, c, h, w, n_stride,
  163. c_stride, h_stride, w_stride);
  164. }
  165. template <typename Param>
  166. FilterDesc<Param>::FilterDesc() {
  167. cudnn_check(cudnnCreateFilterDescriptor(&desc));
  168. }
  169. template <typename Param>
  170. FilterDesc<Param>::~FilterDesc() {
  171. cudnn_check(cudnnDestroyFilterDescriptor(desc));
  172. }
  173. template <typename Param>
  174. std::string FilterDesc<Param>::to_string() {
  175. cudnnDataType_t data_type;
  176. cudnnTensorFormat_t format;
  177. int k;
  178. int c;
  179. int h;
  180. int w;
  181. cudnn_check(cudnnGetFilter4dDescriptor(desc, &data_type, &format, &k, &c, &h, &w));
  182. return ssprintf(
  183. "<dtype_%d, format_%d, %d,%d,%d,%d>", data_type, format, k, c, h, w);
  184. }
  185. template <typename Param>
  186. void FilterDesc<Param>::set(
  187. const typename ConvolutionBase<Param>::CanonizedFilterMeta& filter_meta) {
  188. megdnn_assert(filter_meta.spatial_ndim == 2);
  189. #if CUDNN_VERSION < 7500
  190. megdnn_assert(filter_meta.dilation[0] == 1 && filter_meta.dilation[1] == 1);
  191. #endif
  192. #if CUDNN_MAJOR <= 6
  193. megdnn_assert(filter_meta.group == 1);
  194. #endif
  195. auto filter_format = filter_meta.format;
  196. if (filter_format == param::ConvBias::Format::NCHW4_NCHW) {
  197. filter_format = param::ConvBias::Format::NCHW4;
  198. }
  199. // cuDNN version 6 or below filter_meta.group always is 1.
  200. // So it is compatible for all cuDNN versions.
  201. cudnn_check(cudnnSetFilter4dDescriptor(
  202. desc, to_cudnn_dtype(filter_meta.dtype, filter_format),
  203. to_cudnn_format(filter_format),
  204. filter_meta.ocpg * filter_meta.group, // cudnn 6 group always be 1
  205. filter_meta.icpg, filter_meta.spatial[0], filter_meta.spatial[1]));
  206. }
  207. template class FilterDesc<param::Convolution>;
  208. template class FilterDesc<param::ConvBias>;
  209. ConvDesc::ConvDesc() {
  210. cudnn_check(cudnnCreateConvolutionDescriptor(&desc));
  211. #if CUDNN_VERSION >= 7000
  212. // cudnn enables tensor core when tensors have dataType =
  213. // CUDNN_DATA_HALF, so it should be safe to enable globally
  214. cudnn_check(cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH));
  215. #endif
  216. }
  217. ConvDesc::~ConvDesc() {
  218. cudnn_check(cudnnDestroyConvolutionDescriptor(desc));
  219. }
  220. void ConvDesc::set(
  221. DType data_type, const param::Convolution& param, const size_t nr_group) {
  222. using Param = param::Convolution;
  223. cudnnConvolutionMode_t mode;
  224. switch (param.mode) {
  225. case Param::Mode::CROSS_CORRELATION:
  226. mode = CUDNN_CROSS_CORRELATION;
  227. break;
  228. case Param::Mode::CONVOLUTION:
  229. mode = CUDNN_CONVOLUTION;
  230. break;
  231. default:
  232. megdnn_throw("conv mode must be conv or xcorr.");
  233. }
  234. cudnnDataType_t compute_type;
  235. MEGDNN_MARK_USED_VAR(compute_type);
  236. if (data_type.enumv() == DTypeEnum::Float32) {
  237. // FLOAT_CONFIG
  238. compute_type = CUDNN_DATA_FLOAT;
  239. } else if (data_type.enumv() == DTypeEnum::Float16) {
  240. auto comp_mode = param.compute_mode;
  241. compute_type = get_compute_type_fp16(comp_mode);
  242. #if CUDNN_MAJOR >= 7
  243. } else if (
  244. data_type.category() == DTypeCategory::INT ||
  245. data_type.category() == DTypeCategory::QUANTIZED) {
  246. compute_type = CUDNN_DATA_INT32;
  247. #endif
  248. } else {
  249. megdnn_throw("unspport data type for conv bias");
  250. }
  251. #if CUDNN_MAJOR >= 7
  252. cudnn_check(cudnnSetConvolutionGroupCount(desc, nr_group));
  253. #else
  254. megdnn_assert(nr_group == 1);
  255. #endif
  256. #if CUDNN_MAJOR >= 6
  257. cudnn_check(cudnnSetConvolution2dDescriptor(
  258. desc, param.pad_h, param.pad_w, param.stride_h, param.stride_w,
  259. param.dilate_h, param.dilate_w, mode, compute_type));
  260. #else
  261. cudnn_check(cudnnSetConvolution2dDescriptor(
  262. desc, param.pad_h, param.pad_w, param.stride_h, param.stride_w,
  263. param.dilate_h, param.dilate_w, mode));
  264. #endif
  265. }
  266. LRNDesc::LRNDesc() {
  267. cudnn_check(cudnnCreateLRNDescriptor(&desc));
  268. }
  269. LRNDesc::~LRNDesc() {
  270. cudnn_check(cudnnDestroyLRNDescriptor(desc));
  271. }
  272. void LRNDesc::set(const param::LRN& param) {
  273. megdnn_assert(param.n & 1, "n is %u", param.n);
  274. megdnn_assert(
  275. param.n >= CUDNN_LRN_MIN_N, "n is %u, CUDNN_LRN_MIN_N is %d", param.n,
  276. CUDNN_LRN_MIN_N);
  277. megdnn_assert(
  278. param.n <= CUDNN_LRN_MAX_N, "n is %u, CUDNN_LRN_MAX_N is %d", param.n,
  279. CUDNN_LRN_MAX_N);
  280. megdnn_assert(
  281. param.k >= CUDNN_LRN_MIN_K, "k is %f, CUDNN_LRN_MIN_K is %lf", param.k,
  282. CUDNN_LRN_MIN_K);
  283. megdnn_assert(
  284. param.beta >= CUDNN_LRN_MIN_BETA, "beta is %f, CUDNN_LRN_MIN_BETA is %lf",
  285. param.beta, CUDNN_LRN_MIN_BETA);
  286. // Note that alpha is divided by n in the cudnn implementation,
  287. // so we have to multiply alpha by n ahead of time.
  288. cudnn_check(cudnnSetLRNDescriptor(
  289. desc, param.n, param.alpha * param.n, param.beta, param.k));
  290. }
  291. BNParamDesc::BNParamDesc() {
  292. cudnn_check(cudnnCreateTensorDescriptor(&desc));
  293. }
  294. void BNParamDesc::set(const cudnnTensorDescriptor_t xDesc, cudnnBatchNormMode_t mode) {
  295. cudnn_check(cudnnDeriveBNTensorDescriptor(desc, xDesc, mode));
  296. }
  297. BNParamDesc::~BNParamDesc() {
  298. cudnn_check(cudnnDestroyTensorDescriptor(desc));
  299. }
  300. Tensor3DDesc::Tensor3DDesc() {
  301. cudnn_check(cudnnCreateTensorDescriptor(&desc));
  302. }
  303. Tensor3DDesc::~Tensor3DDesc() {
  304. cudnn_check(cudnnDestroyTensorDescriptor(desc));
  305. }
  306. int sc(const size_t x) {
  307. return static_cast<int>(x);
  308. }
  309. void Tensor3DDesc::set(const TensorLayout& layout, bool is_ndhwc) {
  310. megdnn_assert_eq_size_t(layout.ndim, 5_z);
  311. size_t c_pos, spatial_pos;
  312. if (is_ndhwc) {
  313. c_pos = 4;
  314. spatial_pos = 1;
  315. } else { // ncdhw
  316. c_pos = 1;
  317. spatial_pos = 2;
  318. }
  319. const int dimA[] = {
  320. sc(layout.shape[0]), sc(layout.shape[c_pos]),
  321. sc(layout.shape[spatial_pos + 0]), sc(layout.shape[spatial_pos + 1]),
  322. sc(layout.shape[spatial_pos + 2])};
  323. const int strideA[] = {
  324. sc(layout.stride[0]), sc(layout.stride[c_pos]),
  325. sc(layout.stride[spatial_pos + 0]), sc(layout.stride[spatial_pos + 1]),
  326. sc(layout.stride[spatial_pos + 2])};
  327. cudnn_check(cudnnSetTensorNdDescriptor(
  328. desc, to_cudnn_dtype(layout.dtype), 5, dimA, strideA));
  329. }
  330. Filter3DDesc::Filter3DDesc() {
  331. cudnn_check(cudnnCreateFilterDescriptor(&desc));
  332. }
  333. Filter3DDesc::~Filter3DDesc() {
  334. cudnn_check(cudnnDestroyFilterDescriptor(desc));
  335. }
  336. void Filter3DDesc::set(const Convolution3DBase::CanonizedFilterMeta& filter_meta) {
  337. megdnn_assert(filter_meta.spatial_ndim == 3);
  338. #if CUDNN_MAJOR <= 6
  339. megdnn_assert(filter_meta.group == 1);
  340. #endif
  341. // cuDNN version 6 or below filter_meta.group always is 1.
  342. // So it is compatible for all cuDNN versions.
  343. const int filterDimA[] = {
  344. sc(filter_meta.ocpg * filter_meta.group), // cudnn 6 group always be 1
  345. sc(filter_meta.icpg), sc(filter_meta.spatial[0]),
  346. sc(filter_meta.spatial[1]), sc(filter_meta.spatial[2])};
  347. cudnn_check(cudnnSetFilterNdDescriptor(
  348. desc, to_cudnn_dtype(DType::from_enum(filter_meta.dtype_enum)),
  349. CUDNN_TENSOR_NCHW, 5, filterDimA));
  350. }
  351. Conv3DDesc::Conv3DDesc() {
  352. cudnn_check(cudnnCreateConvolutionDescriptor(&desc));
  353. #if CUDNN_MAJOR >= 7
  354. // cudnn enables tensor core when tensors have dataType = CUDNN_DATA_HALF,
  355. // so it should be safe to enable globally
  356. cudnn_check(cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH));
  357. #endif
  358. }
  359. Conv3DDesc::~Conv3DDesc() {
  360. cudnn_check(cudnnDestroyConvolutionDescriptor(desc));
  361. }
  362. void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) {
  363. cudnnConvolutionMode_t mode;
  364. switch (param.mode) {
  365. case param::Convolution3D::Mode::CROSS_CORRELATION:
  366. mode = CUDNN_CROSS_CORRELATION;
  367. break;
  368. case param::Convolution3D::Mode::CONVOLUTION:
  369. mode = CUDNN_CONVOLUTION;
  370. break;
  371. default:
  372. megdnn_throw("conv mode must be conv or xcorr.");
  373. }
  374. #if CUDNN_MAJOR >= 7
  375. cudnn_check(cudnnSetConvolutionGroupCount(desc, nr_group));
  376. #else
  377. megdnn_assert(nr_group == 1);
  378. #endif
  379. const int padA[] = {sc(param.pad_d), sc(param.pad_h), sc(param.pad_w)},
  380. filterStrideA[] =
  381. {sc(param.stride_d), sc(param.stride_h), sc(param.stride_w)},
  382. dilationA[] = {
  383. sc(param.dilate_d), sc(param.dilate_h), sc(param.dilate_w)};
  384. // not use true half
  385. // in CUDNN_MAJOR < 6, all elements in dilA shoule be 1
  386. cudnn_check(cudnnSetConvolutionNdDescriptor(
  387. desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT));
  388. }
  389. ////////////////////////// CudnnAlgoPack //////////////////////////
  390. #define V1(v) #v
  391. #define V(v) V1(v)
  392. #define DEF_NAME(NAME) \
  393. #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
  394. #define DEF_ALGO(NAME, PROD1, PROD2) \
  395. { \
  396. NAME, { DEF_NAME(NAME), PROD1, PROD2 } \
  397. }
  398. #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
  399. #pragma message "not latest cudnn"
  400. #endif
  401. const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
  402. CudnnAlgoPack::conv_bwd_data_algos() {
  403. static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
  404. algos = {
  405. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false),
  406. #if CUDNN_VERSION == 8004
  407. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, true),
  408. #else
  409. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false),
  410. #endif
  411. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true),
  412. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true),
  413. #if CUDNN_MAJOR >= 5
  414. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true, true),
  415. #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
  416. DEF_ALGO(
  417. CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true, false),
  418. #endif
  419. #endif
  420. };
  421. return algos;
  422. }
  423. const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr>
  424. CudnnAlgoPack::conv_bwd_flt_algos() {
  425. static const std::unordered_map<
  426. cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr>
  427. algos =
  428. { DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false),
  429. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false),
  430. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true, true),
  431. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false),
  432. #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1)
  433. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true, false),
  434. #if CUDNN_MAJOR >= 6
  435. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true, true),
  436. #endif
  437. #endif
  438. };
  439. return algos;
  440. }
  441. const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> CudnnAlgoPack::
  442. conv_fwd_algos() {
  443. static const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr>
  444. algos = {
  445. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false),
  446. #if CUDNN_VERSION == 8004
  447. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true),
  448. #else
  449. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false),
  450. #endif
  451. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true, false),
  452. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true, false),
  453. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true, true),
  454. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true),
  455. #if CUDNN_MAJOR >= 5
  456. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true, false),
  457. #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
  458. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true, false),
  459. #endif
  460. #endif
  461. };
  462. return algos;
  463. }
  464. const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
  465. CudnnAlgoPack::conv3d_bwd_data_algos() {
  466. static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
  467. algos = {
  468. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false),
  469. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false),
  470. DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true),
  471. };
  472. return algos;
  473. } // namespace cuda
  474. const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr>
  475. CudnnAlgoPack::conv3d_bwd_flt_algos() {
  476. #pragma message \
  477. "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc"
  478. static const std::unordered_map<
  479. cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr>
  480. algos = {
  481. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false),
  482. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false),
  483. DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false),
  484. };
  485. return algos;
  486. }
  487. const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> CudnnAlgoPack::
  488. conv3d_fwd_algos() {
  489. static const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr>
  490. algos = {
  491. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false),
  492. #if CUDNN_VERSION == 8004
  493. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true),
  494. #else
  495. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false),
  496. #endif
  497. DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true),
  498. };
  499. return algos;
  500. }
  501. #undef DEF_ALGO
  502. #undef DEF_NAME
  503. #undef V
  504. #undef V1
  505. } // namespace cuda
  506. } // namespace megdnn
  507. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。