You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it may include dashes ('-') and can be up to 35 characters long.

heuristic_cache.cpp 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. /**
  2. * \file dnn/src/common/heuristic_cache.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/heuristic_cache.h"
  13. #include "src/common/utils.h"
  14. #include "src/naive/handle.h"
  15. #if MEGDNN_WITH_CUDA
  16. #include "src/cuda/utils.h"
  17. #endif
  18. #if MEGDNN_WITH_ROCM
  19. #include "hcc_detail/hcc_defs_prologue.h"
  20. #include "megcore_rocm.h"
  21. #include "src/rocm/utils.h"
  22. #endif
  23. using namespace megdnn;
  24. HeuristicCache& HeuristicCache::instance() {
  25. static HeuristicCache ins;
  26. return ins;
  27. }
/**
 * Build (and memoize) the two-part storage key for this cache entry.
 *
 * The key has two strings:
 *   - category (`ctg`): platform/handle description — handle type, device
 *     properties and runtime version where available, plus the operator type;
 *   - input (`inp`): serialized input tensor layouts (shape, stride, dtype,
 *     format) followed by the raw operator-param bytes.
 *
 * NOTE(review): `ctg`/`inp` are `auto&&` references to `m_category`/`m_input`
 * which are written inside this `const` method — presumably those members are
 * declared `mutable` so results are memoized; confirm against the header.
 */
HeuristicCache::KeyStorage HeuristicCache::Key::build_key_storage() const {
    auto&& ctg = m_category;
    auto&& inp = m_input;
    // Both parts already built on an earlier call: reuse the memoized strings.
    if (!m_category.empty() && !m_input.empty())
        return {ctg, inp};
    // Rough capacity hint to avoid repeated reallocation while appending.
    inp.reserve(sizeof(TensorLayout) * 3 * m_inp_layouts_size + m_param_size);
    // Serialize each input layout as "shape;stride;dtype;format|".
    for (size_t i = 0; i < m_inp_layouts_size; i++) {
        auto&& ly = m_inp_layouts_ptr[i];
        // Comma-separated shape dims.
        for (size_t j = 0; j < ly.ndim; j++) {
            if (j)
                inp.push_back(',');
            inp.append(std::to_string(ly.shape[j]));
        }
        inp.push_back(';');
        // Comma-separated strides.
        for (size_t j = 0; j < ly.ndim; j++) {
            if (j)
                inp.push_back(',');
            inp.append(std::to_string(ly.stride[j]));
        }
        inp.push_back(';');
        inp.append(ly.dtype.name());
        inp.push_back(';');
        inp.append(ly.format.to_string().c_str());
        inp.push_back('|');
    }
    // Append the operator param as raw bytes (opaque binary blob).
    if (m_param_size) {
        inp.append(reinterpret_cast<const char*>(m_param_ptr), m_param_size);
    }
    // Category part: handle type id plus a backend-specific device signature.
    ctg = "plat:";
    ctg.append(std::to_string(static_cast<uint32_t>(m_handle->type())));
    switch (m_handle->type()) {
#if MEGDNN_WITH_CUDA
        case Handle::HandleType::CUDA: {
            int cuda_rt = -1;
            cuda_check(cudaRuntimeGetVersion(&cuda_rt));
            // cudaRuntimeGetVersion reports e.g. 10020; keep the major only.
            cuda_rt /= 1000;
            auto&& handle = static_cast<megdnn::cuda::HandleImpl*>(m_handle);
            auto&& prop = handle->device_prop();
            ctg.append(ssprintf(";dev=%s;cap=%d.%d;runtime=%d;",
                                prop.name, prop.major, prop.minor, cuda_rt));
            break;
        }
#endif
#if MEGDNN_WITH_ROCM
        case Handle::HandleType::ROCM: {
            auto&& handle = static_cast<megdnn::rocm::HandleImpl*>(m_handle);
            auto&& prop = handle->device_prop();
            int drv = -1, hip_rt = -1;
            hip_check(hipDriverGetVersion(&drv));
            hip_check(hipRuntimeGetVersion(&hip_rt));
            ctg.append(ssprintf(";dev=%s;cap=%d.%d,drv=%d;runtime=%d;",
                                prop.name, prop.major, prop.minor, drv, hip_rt));
            break;
        }
#endif
        // All CPU backends share one signature: the dispatcher thread count,
        // which can change the heuristic choice.
        case Handle::HandleType::FALLBACK:
#if MEGDNN_X86
        case Handle::HandleType::X86:
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        case Handle::HandleType::ARM_COMMON:
#endif
#if MEGDNN_AARCH64
        case Handle::HandleType::AARCH64:
#endif
#if MEGDNN_ARMV7
        case Handle::HandleType::ARMV7:
#endif
        {
            size_t nr_threads =
                    static_cast<megdnn::naive::HandleImpl*>(m_handle)
                            ->megcore_dispatcher()
                            ->nr_threads();
            ctg.append(";");
            ctg.append(std::to_string(nr_threads));
            ctg.append(";");
            break;
        }
        default:
            ctg.append(";");
    }
    // Finally tag the operator type so different oprs never collide.
    ctg.append(std::to_string(m_opr_type));
    return {ctg, inp};
}
  112. void HeuristicCache::put(const Key& key, Result& result) {
  113. MEGDNN_LOCK_GUARD(m_mtx);
  114. if (result.policy.algo.valid())
  115. m_heuristic_cache[key.build_key_storage()] = result;
  116. }
  117. HeuristicCache::Result HeuristicCache::get(const Key& key) {
  118. MEGDNN_LOCK_GUARD(m_mtx);
  119. KeyStorage ks = key.build_key_storage();
  120. auto iter = m_heuristic_cache.find(ks);
  121. if (iter == m_heuristic_cache.end()) {
  122. return {};
  123. } else {
  124. return iter->second;
  125. }
  126. }
/// Remove every cached heuristic entry (thread-safe via m_mtx).
void HeuristicCache::clear() {
    MEGDNN_LOCK_GUARD(m_mtx);
    m_heuristic_cache.clear();
}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台