You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

handle.cpp 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. /**
  2. * \file dnn/src/cuda/handle.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/handle_impl.h"
  12. #include "src/common/version_symbol.h"
  13. #include "src/cuda/handle.h"
  14. #include "src/cuda/utils.h"
  15. #include "megdnn/common.h"
  16. #include <cuda.h>
  17. #include <cstring>
// Stringify helpers used only to build the compile-time cuDNN version banner.
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define CUDNN_VERSION_STR STR(CUDNN_MAJOR) "." STR(CUDNN_MINOR) "." STR(CUDNN_PATCHLEVEL)
// Emit the cuDNN version in the build log so mismatched builds are easy to spot.
#pragma message "compile with cuDNN " CUDNN_VERSION_STR " "
// Hard build-time ban on the cuDNN 5.1.x series.
static_assert(!(CUDNN_MAJOR == 5 && CUDNN_MINOR == 1),
"cuDNN 5.1.x series has bugs. Use 5.0.x instead.");
// Keep the helper macros file-local.
#undef STR
#undef STR_HELPER
  26. namespace megdnn {
  27. namespace cuda {
/**
 * \brief Construct a CUDA handle on top of a megcore computing handle.
 *
 * Resolves the device, validates library versions, creates the
 * cuDNN/cuBLAS (and, on CUDA >= 10.1, cuBLASLt) handles bound to the
 * computing handle's stream, and uploads a small block of constant
 * scalars to device memory.
 */
HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
        HandleImplHelper(comp_handle, HandleType::CUDA)
{
    // Get megcore device handle and resolve the device id; a negative id
    // means "unspecified", so fall back to the current CUDA device.
    megcoreDeviceHandle_t dev_handle;
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    int dev_id;
    megcoreGetDeviceID(dev_handle, &dev_id);
    if (dev_id < 0) {
        cuda_check(cudaGetDevice(&dev_id));
    }
    m_device_id = dev_id;
    m_device_prop = get_device_prop(dev_id);
    // Refuse to run against a cuDNN runtime that differs from the one we
    // were compiled with.
    megdnn_assert(CUDNN_VERSION == cudnnGetVersion(),
            "cudnn version mismatch: compiled with %d; detected %zu at runtime",
            CUDNN_VERSION, cudnnGetVersion());
#if CUDA_VERSION >= 10010
    // cuBLASLt is only usable from CUDA 10.1 (10010) onward.
    megdnn_assert(cublasLtGetVersion() >= 10010,
            "cuda library version is too low to run cublasLt");
#endif
#if CUDNN_VERSION >= 8000
    // cuDNN 8 JIT-compiles PTX at runtime; without a persistent cache the
    // (very slow) compilation is repeated on every run, so warn the user.
    if (!MGB_GETENV("CUDA_CACHE_PATH")) {
        megdnn_log_warn(R"(
Cudnn8 will jit ptx code with cache. You can set
CUDA_CACHE_MAXSIZE and CUDA_CACHE_PATH environment var to avoid repeat jit(very slow).
For example `export CUDA_CACHE_MAXSIZE=2147483647` and `export CUDA_CACHE_PATH=/data/.cuda_cache`)");
    }
#endif
    cudnn_check(cudnnCreate(&m_cudnn_handle));
    cublas_check(cublasCreate(&m_cublas_handle));
#if CUDA_VERSION >= 10010
    cublas_check(cublasLtCreate(&m_cublasLt_handle));
#endif
    megcore::getCUDAContext(comp_handle, &m_megcore_context);
    // Set stream for cuDNN and cublas handles so all their work is ordered
    // on the stream owned by the computing handle.
    cudnn_check(cudnnSetStream(m_cudnn_handle, stream()));
    cublas_check(cublasSetStream(m_cublas_handle, stream()));
    // Note that all cublas scalars (alpha, beta) and scalar results such as dot
    // output resides at device side.
    cublas_check(cublasSetPointerMode(m_cublas_handle,
                CUBLAS_POINTER_MODE_DEVICE));
    // init const scalars: upload 0/1 constants once for use as device-side
    // scalar arguments (see pointer mode above).
    cuda_check(cudaMalloc(&m_const_scalars, sizeof(ConstScalars)));
    ConstScalars const_scalars_val;
    const_scalars_val.init();
    cuda_check(cudaMemcpyAsync(m_const_scalars, &const_scalars_val,
                sizeof(ConstScalars), cudaMemcpyHostToDevice, stream()));
    // const_scalars_val is a stack variable: synchronize so the async copy
    // finishes before it goes out of scope.
    cuda_check(cudaStreamSynchronize(stream()));
    // check tk1: "GK20A" is the Tegra K1 GPU's device name.
    m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
    // cusolver is created lazily — see initialize_cusolver().
    m_cusolver_handle = nullptr;
}
/// Release all library handles and the device-side constant-scalar block.
HandleImpl::~HandleImpl() noexcept {
    cudnn_check(cudnnDestroy(m_cudnn_handle));
    cublas_check(cublasDestroy(m_cublas_handle));
#if CUDA_VERSION >= 10010
    cublas_check(cublasLtDestroy(m_cublasLt_handle));
#endif
    // cusolver is initialized lazily, so it may never have been created.
    if (m_cusolver_handle) {
        cusolver_check(cusolverDnDestroy(m_cusolver_handle));
    }
    cuda_check(cudaFree(m_const_scalars));
}
  92. void HandleImpl::ConstScalars::init() {
  93. f16[0].megdnn_x = 0; f16[1].megdnn_x = 1;
  94. f32[0] = 0; f32[1] = 1;
  95. i32[0] = 0; i32[1] = 1;
  96. }
  97. size_t HandleImpl::alignment_requirement() const {
  98. auto &&prop = m_device_prop;
  99. return std::max(prop->textureAlignment, prop->texturePitchAlignment);
  100. }
  101. bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
  102. // is contiguous or can be hold by
  103. // relayout::param::try_copy_2d/try_copy_last_contig
  104. return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
  105. }
/// Lazily create the cusolver handle and bind it to this handle's stream.
/// Called only when cusolver is first needed (m_cusolver_handle starts null).
void HandleImpl::initialize_cusolver() {
    cusolver_check(cusolverDnCreate(&m_cusolver_handle));
    cusolver_check(cusolverDnSetStream(m_cusolver_handle, stream()));
}
  110. size_t HandleImpl::image2d_pitch_alignment() const {
  111. size_t align = device_prop().texturePitchAlignment;
  112. return align;
  113. }
  114. HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
  115. return HandleVendorType::CUDA;
  116. }
  117. } // namespace cuda
  118. } // namespace megdnn
// Embed the CUDA / cuDNN versions this library was built against —
// presumably consumed by the runtime version check declared in
// src/common/version_symbol.h (included above); confirm against that header.
MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);

// vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台