
utils.cpp (5.8 kB)
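This is the CUDA utility source for MegEngine's dnn module. It provides human-readable error reporting for the CUDA runtime, cuDNN, cuBLAS, cuSOLVER, the CUDA driver API, and CUTLASS, plus a small per-device cache of `cudaDeviceProp` and helpers for querying compute capability and architecture names.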

```cpp
/**
 * \file dnn/src/cuda/utils.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

#include "src/common/utils.h"
#include "src/cuda/handle.h"
#include "src/cuda/int_fastdiv.cuh"

#include <mutex>

using namespace megdnn;
using namespace cuda;

namespace {

//! cached result of cudaGetDeviceProperties(), filled lazily per device
struct DevicePropRec {
    bool init = false;
    cudaDeviceProp prop;
    std::mutex mtx;
};
constexpr int MAX_NR_DEVICE = 32;
DevicePropRec device_prop_rec[MAX_NR_DEVICE];

const char* cublasGetErrorString(cublasStatus_t error) {
    switch (error) {
        case CUBLAS_STATUS_SUCCESS:
            return "CUBLAS_STATUS_SUCCESS";
        case CUBLAS_STATUS_NOT_INITIALIZED:
            return "CUBLAS_STATUS_NOT_INITIALIZED";
        case CUBLAS_STATUS_ALLOC_FAILED:
            return "CUBLAS_STATUS_ALLOC_FAILED";
        case CUBLAS_STATUS_INVALID_VALUE:
            return "CUBLAS_STATUS_INVALID_VALUE";
        case CUBLAS_STATUS_ARCH_MISMATCH:
            return "CUBLAS_STATUS_ARCH_MISMATCH";
        case CUBLAS_STATUS_MAPPING_ERROR:
            return "CUBLAS_STATUS_MAPPING_ERROR";
        case CUBLAS_STATUS_EXECUTION_FAILED:
            return "CUBLAS_STATUS_EXECUTION_FAILED";
        case CUBLAS_STATUS_INTERNAL_ERROR:
            return "CUBLAS_STATUS_INTERNAL_ERROR";
        case CUBLAS_STATUS_LICENSE_ERROR:
            return "CUBLAS_STATUS_LICENSE_ERROR";
        case CUBLAS_STATUS_NOT_SUPPORTED:
            return "CUBLAS_STATUS_NOT_SUPPORTED";
    }
    return "Unknown CUBLAS error";
}

}  // anonymous namespace

void cuda::__throw_cuda_error__(cudaError_t err, const char* msg) {
    auto s = ssprintf("cuda error %s(%d) occurred; expr: %s",
                      cudaGetErrorString(err), int(err), msg);
    megdnn_throw(s.c_str());
}

void cuda::__throw_cudnn_error__(cudnnStatus_t err, const char* msg) {
    auto s = ssprintf("cudnn error %s(%d) occurred; expr: %s",
                      cudnnGetErrorString(err), int(err), msg);
    megdnn_throw(s.c_str());
}

void cuda::__throw_cublas_error__(cublasStatus_t err, const char* msg) {
    auto s = ssprintf("cublas error %s(%d) occurred; expr: %s",
                      cublasGetErrorString(err), int(err), msg);
    megdnn_throw(s.c_str());
}

void cuda::__throw_cusolver_error__(cusolverStatus_t err, const char* msg) {
    auto s = ssprintf("cusolver error %d occurred; expr: %s", int(err), msg);
    megdnn_throw(s.c_str());
}

void cuda::__throw_cuda_driver_error__(CUresult err, const char* msg) {
    auto s = ssprintf("cuda driver error %d occurred; expr: %s", int(err), msg);
    megdnn_throw(s.c_str());
}

void cuda::__throw_cutlass_error__(cutlass::Status err, const char* msg) {
    auto s = ssprintf("cutlass error %s(%d) occurred; expr: %s",
                      cutlass::cutlassGetStatusString(err), int(err), msg);
    megdnn_throw(s.c_str());
}

void cuda::report_error(const char* msg) {
    megdnn_throw(msg);
    // keep msg "used" in builds where megdnn_throw expands to a no-op
    MEGDNN_MARK_USED_VAR(msg);
}

uint32_t cuda::safe_size_in_kern(size_t size) {
    if (!size || size > Uint32Fastdiv::MAX_DIVIDEND) {
        megdnn_throw(ssprintf(
                "invalid size for element-wise kernel: %zu; "
                "max supported size is %u",
                size, Uint32Fastdiv::MAX_DIVIDEND));
    }
    return size;
}

const cudaDeviceProp& cuda::current_device_prop() {
    int dev;
    cuda_check(cudaGetDevice(&dev));
    return *(cuda::get_device_prop(dev));
}

const cudaDeviceProp* cuda::get_device_prop(int device) {
    megdnn_assert(device < MAX_NR_DEVICE, "device number too large: %d",
                  device);
    megdnn_assert(device >= 0, "device number must not be negative, got %d",
                  device);
    auto&& rec = device_prop_rec[device];
    // double-checked locking: take the per-device mutex only on first query
    if (!rec.init) {
        std::lock_guard<std::mutex> lock(rec.mtx);
        if (!rec.init) {
            cuda_check(cudaGetDeviceProperties(&rec.prop, device));
            rec.init = true;
        }
    }
    return &(rec.prop);
}

bool cuda::is_compute_capability_required(int major, int minor) {
    auto&& device_prop = cuda::current_device_prop();
    return device_prop.major > major ||
           (device_prop.major == major && device_prop.minor >= minor);
}

bool cuda::is_compute_capability_equalto(int major, int minor) {
    auto&& device_prop = cuda::current_device_prop();
    return device_prop.major == major && device_prop.minor == minor;
}

size_t cuda::max_batch_x_channel_size() {
    return current_device_prop().maxGridSize[2];
}

uint32_t cuda::param_buffer_start_address() {
    auto&& device_prop = current_device_prop();
    int cap = 10 * device_prop.major + device_prop.minor;
    // maxwell and pascal: 0x140
    if (cap >= 50 && cap < 70)
        return 0x140;
    // volta ~ ampere: 0x160
    else if (cap >= 70)
        return 0x160;
    megdnn_throw(
            ssprintf("unsupported cuda compute capability %d", cap).c_str());
}

const char* cuda::current_device_arch_name() {
    auto&& device_prop = current_device_prop();
    int cap = 10 * device_prop.major + device_prop.minor;
    if (cap >= 50 && cap < 60)
        return "maxwell";
    else if (cap >= 60 && cap < 70)
        return "pascal";
    else if (cap >= 70 && cap < 75)
        return "volta";
    else if (cap >= 75 && cap < 80)
        return "turing";
    else if (cap >= 80)
        return "ampere";
    megdnn_throw(
            ssprintf("unsupported cuda compute capability %d", cap).c_str());
}

// vim: syntax=cpp.doxygen
```
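The `__throw_*_error__` functions are not meant to be called directly; they are the slow path behind checking macros such as `cuda_check`, which the file itself uses around `cudaGetDevice` and `cudaGetDeviceProperties`. The real macro lives in `src/cuda/utils.cuh`; below is a minimal sketch of the pattern under a hypothetical name (`my_cuda_check`), so the actual definition may differ:

```cpp
#include <cuda_runtime.h>
#include "src/cuda/utils.h"

// Hypothetical sketch of a cuda_check-style macro: evaluate the expression
// once, and on failure raise a megdnn exception that embeds both the error
// string and the stringified expression.
#define my_cuda_check(expr)                                    \
    do {                                                       \
        cudaError_t _err = (expr);                             \
        if (_err != cudaSuccess)                               \
            ::megdnn::cuda::__throw_cuda_error__(_err, #expr); \
    } while (0)

// Usage: on failure this throws with a message like
//   "cuda error ...(N) occurred; expr: cudaSetDevice(0)"
// my_cuda_check(cudaSetDevice(0));
```

The `do { ... } while (0)` wrapper makes the macro behave like a single statement, so it composes safely with `if`/`else`.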
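Note that despite its name, `is_compute_capability_required(major, minor)` is a query, not an assertion: it returns whether the current device's compute capability is at least `major.minor`. A hypothetical caller (not from the repository) might use it for architecture-dependent dispatch; this assumes a valid current CUDA device, since the helpers call `cudaGetDevice` internally:

```cpp
#include <cstdio>
#include "src/cuda/utils.h"

// Hypothetical example: pick a code path based on the current device's
// compute capability (sm_70, i.e. Volta, introduced tensor cores).
void print_arch_info() {
    if (megdnn::cuda::is_compute_capability_required(7, 0)) {
        std::printf("device arch %s: sm_70+ path\n",
                    megdnn::cuda::current_device_arch_name());
    } else {
        std::printf("device arch %s: pre-Volta path\n",
                    megdnn::cuda::current_device_arch_name());
    }
}
```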

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.