You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cudnn.cpp 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. /**
  2. * \file dnn/src/cuda/convolution3d/backward_data/cudnn.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./algo.h"
  12. #include "src/cuda/utils.h"
  13. #include "src/cuda/cudnn_wrapper.h"
  14. #include "src/cuda/convolution3d/helper.h"
  15. using namespace megdnn;
  16. using namespace cuda;
  17. using namespace convolution3d;
  18. bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
  19. const SizeArgs &args) const {
  20. CUDNNBwdDataDescs D;
  21. if (!is_cudnn_supported(args.as_fwd_args()))
  22. return false;
  23. args.init_desc(D);
  24. size_t workspace_size;
  25. auto& cudnn = args.handle->cudnn();
  26. auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
  27. args.handle->cudnn_handle(),
  28. D.filter_desc.desc,
  29. D.diff_desc.desc,
  30. D.conv_desc.desc,
  31. D.grad_desc.desc,
  32. m_cudnn_enum,
  33. &workspace_size);
  34. return status == CUDNN_STATUS_SUCCESS;
  35. }
  36. size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
  37. const SizeArgs &args) const {
  38. CUDNNBwdDataDescs D;
  39. args.init_desc(D);
  40. size_t workspace_size;
  41. auto& cudnn = args.handle->cudnn();
  42. auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
  43. args.handle->cudnn_handle(),
  44. D.filter_desc.desc,
  45. D.diff_desc.desc,
  46. D.conv_desc.desc,
  47. D.grad_desc.desc,
  48. m_cudnn_enum,
  49. &workspace_size);
  50. megdnn_assert(status == CUDNN_STATUS_SUCCESS,
  51. "conv bwd_data get workspace failed: %s; info: %s",
  52. cudnnGetErrorString(status), args.to_string().c_str());
  53. return workspace_size;
  54. }
  55. void Convolution3DBackwardDataImpl::AlgoCUDNN::exec(
  56. const ExecArgs &args) const {
  57. CUDNNBwdDataDescs D;
  58. args.init_desc(D);
  59. float alpha = 1.0f, beta = 0.0f;
  60. auto status = cudnnConvolutionBackwardData(args.handle->cudnn_handle(),
  61. &alpha,
  62. D.filter_desc.desc, args.filter_tensor->raw_ptr,
  63. D.diff_desc.desc, args.diff_tensor->raw_ptr,
  64. D.conv_desc.desc,
  65. m_cudnn_enum,
  66. args.workspace.raw_ptr,
  67. args.workspace.size,
  68. &beta,
  69. D.grad_desc.desc,
  70. args.grad_tensor->raw_ptr);
  71. megdnn_assert(status == CUDNN_STATUS_SUCCESS,
  72. "conv bwd_data failed: %s; info: %s",
  73. cudnnGetErrorString(status), args.to_string().c_str());
  74. }
  75. void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() {
  76. for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) {
  77. cudnn.push_back(algo.first);
  78. }
  79. }
  80. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台