You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comp_node_env.cpp 16 kB

Lines 1–466
  1. /**
  2. * \file src/core/impl/comp_node_env.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/comp_node_env.h"
  12. #include "megbrain/exception.h"
  13. #include "megbrain/system.h"
  14. #include "megbrain/utils/metahelper.h"
  15. #include "megbrain/version_symbol.h"
  16. #include "megdnn/version.h"
  17. #if MGB_CUDA
  18. #include "megcore_cuda.h"
  19. #if MGB_ENABLE_DEBUG_UTIL
  20. #include <nvToolsExtCudaRt.h>
  21. #endif
  22. #endif
  23. #if MGB_ROCM
  24. #include "hcc_detail/hcc_defs_prologue.h"
  25. #include "megcore_rocm.h"
  26. #endif
  27. #if MGB_CAMBRICON
  28. #include "megcore_cambricon.h"
  29. #endif
  30. #if MGB_ATLAS
  31. #include "acl/acl.h"
  32. #include "megcore_atlas.h"
  33. #endif
  34. using namespace mgb;
  35. /* =================== MegDNNHandle =================== */
  36. MGB_TYPEINFO_OBJ_IMPL(MegDNNHandle);
  37. int MegDNNHandle::sm_default_dbg_level = 0;
  38. MegDNNHandle& MegDNNHandle::get(const CompNodeEnv& env) {
  39. auto maker = [&]() { return std::make_shared<MegDNNHandle>(env); };
  40. return env.get_user_data<MegDNNHandle>(maker);
  41. }
  42. MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) {
  43. auto megdnn_version = megdnn::get_version();
  44. mgb_throw_if(
  45. megdnn_version.major != MEGDNN_MAJOR ||
  46. megdnn_version.minor < MEGDNN_MINOR,
  47. SystemError,
  48. "incompatible megdnn version: compiled with %d.%d, get %d.%d.%d "
  49. "at runtime",
  50. MEGDNN_MAJOR, MEGDNN_MINOR, megdnn_version.major,
  51. megdnn_version.minor, megdnn_version.patch);
  52. bool init = false;
  53. #if MGB_CUDA
  54. if (env.property().type == CompNode::DeviceType::CUDA) {
  55. megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCUDA,
  56. env.cuda_env().device, 0);
  57. megcore::createComputingHandleWithCUDAContext(&m_comp_hdl, m_dev_hdl, 0,
  58. {env.cuda_env().stream, make_async_error_info(env)});
  59. init = true;
  60. }
  61. #endif
  62. #if MGB_ROCM
  63. if (env.property().type == CompNode::DeviceType::ROCM) {
  64. megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformROCM,
  65. env.rocm_env().device, 0);
  66. megcore::createComputingHandleWithROCMContext(
  67. &m_comp_hdl, m_dev_hdl, 0,
  68. {env.rocm_env().stream, make_async_error_info(env)});
  69. init = true;
  70. }
  71. #endif
  72. #if MGB_CAMBRICON
  73. if (env.property().type == CompNode::DeviceType::CAMBRICON) {
  74. CompNodeEnv::CnrtEnv::init_status.init();
  75. megcore::createDeviceHandleWithGlobalInitStatus(
  76. &m_dev_hdl, env.cnrt_env().device, 0, true);
  77. megcore::createComputingHandleWithCambriconContext(
  78. &m_comp_hdl, m_dev_hdl, 0, {env.cnrt_env().queue});
  79. init = true;
  80. }
  81. #endif
  82. #if MGB_ATLAS
  83. if (env.property().type == CompNode::DeviceType::ATLAS) {
  84. CompNodeEnv::AtlasEnv::init_status.init();
  85. megcore::createAtlasDeviceHandleWithGlobalInitStatus(
  86. &m_dev_hdl, env.atlas_env().device, 0, true);
  87. megcore::createComputingHandleWithAtlasContext(
  88. &m_comp_hdl, m_dev_hdl, 0, {env.atlas_env().stream});
  89. init = true;
  90. }
  91. #endif
  92. if (env.property().type == CompNode::DeviceType::CPU) {
  93. megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCPU);
  94. megcoreCreateComputingHandleWithCPUDispatcher(&m_comp_hdl, m_dev_hdl,
  95. env.cpu_env().dispatcher);
  96. init = true;
  97. }
  98. mgb_assert(init);
  99. int level = sm_default_dbg_level;
  100. if (auto set = MGB_GETENV("MGB_USE_MEGDNN_DBG")) {
  101. level = std::stol(set);
  102. mgb_log_warn("use megdnn handle with debug level: %d", level);
  103. }
  104. // handle may have been implemented when device type is cadence.
  105. if (!m_megdnn_handle) {
  106. m_megdnn_handle = megdnn::Handle::make(m_comp_hdl, level);
  107. }
  108. }
  109. MegDNNHandle::~MegDNNHandle() noexcept {
  110. m_megdnn_handle.reset();
  111. #if MGB_NEED_MEGDNN_ASYNC_ERROR
  112. m_async_error_info_devptr.reset();
  113. #endif
  114. if (m_comp_hdl) {
  115. megcoreDestroyComputingHandle(m_comp_hdl);
  116. }
  117. if (m_dev_hdl) {
  118. megcoreDestroyDeviceHandle(m_dev_hdl);
  119. }
  120. }
  121. #if MGB_NEED_MEGDNN_ASYNC_ERROR
  122. megcore::AsyncErrorInfo* MegDNNHandle::make_async_error_info(
  123. const CompNodeEnv& env) {
  124. auto cn = env.comp_node();
  125. auto del = [cn](megcore::AsyncErrorInfo* ptr) {
  126. if (ptr) {
  127. cn.free_device(ptr);
  128. }
  129. };
  130. megcore::AsyncErrorInfo zero_info{0, nullptr, "", {0, 0, 0, 0}};
  131. auto ptr = static_cast<megcore::AsyncErrorInfo*>(
  132. env.comp_node().alloc_device(sizeof(zero_info)));
  133. cn.copy_to_device(ptr, &zero_info, sizeof(zero_info));
  134. cn.sync();
  135. m_async_error_info_devptr = {ptr, del};
  136. return m_async_error_info_devptr.get();
  137. }
  138. #endif
  139. /* =================== misc =================== */
  140. #if MGB_CUDA
  141. void mgb::_on_cuda_error(const char* expr, cudaError_t err, const char* file,
  142. const char* func, int line) {
  143. mgb_throw(CudaError, "cuda error %d: %s (%s at %s:%s:%d)", int(err),
  144. cudaGetErrorString(err), expr, file, func, line);
  145. }
  146. void mgb::_on_cuda_cu_error(const char* expr, CUresult err, const char* file,
  147. const char* func, int line) {
  148. const char* msg;
  149. cuGetErrorName(err, &msg);
  150. mgb_throw(CudaError, "cuda error %d: %s (%s at %s:%s:%d)", int(err), msg,
  151. expr, file, func, line);
  152. }
  153. void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node,
  154. const ContinuationCtx<cudaStream_t>& cont) {
  155. m_comp_node = comp_node;
  156. mgb_assert(!m_user_data_container && !m_async_init_need_wait);
  157. m_cuda_env.device = dev;
  158. m_property.type = DeviceType::CUDA;
  159. MGB_CUDA_CHECK(cudaGetDeviceProperties(&m_cuda_env.device_prop, dev));
  160. {
  161. auto&& prop = m_cuda_env.device_prop;
  162. m_property.mem_alignment =
  163. std::max(prop.textureAlignment, prop.texturePitchAlignment);
  164. }
  165. std::atomic_bool tid_set{false};
  166. auto worker = [this, cont, &tid_set]() {
  167. sys::set_thread_name("async_cuda_init");
  168. m_async_init_tid = std::this_thread::get_id();
  169. tid_set.store(true);
  170. bool stream_done = false;
  171. MGB_MARK_USED_VAR(stream_done);
  172. MGB_TRY {
  173. m_cuda_env.activate();
  174. MGB_CUDA_CHECK(cudaStreamCreateWithFlags(&m_cuda_env.stream,
  175. cudaStreamNonBlocking));
  176. stream_done = true;
  177. m_user_data_container = std::make_unique<UserDataContainer>();
  178. #if MGB_ENABLE_DEBUG_UTIL
  179. nvtxNameCudaStreamA(m_cuda_env.stream,
  180. m_comp_node.to_string().c_str());
  181. #endif
  182. cont.next(m_cuda_env.stream);
  183. // megdnn is initialized here; must be placed after cont.next()
  184. // which handles comp node init
  185. mgb_assert(
  186. m_property.mem_alignment ==
  187. MegDNNHandle::get(*this).handle()->alignment_requirement());
  188. }
  189. MGB_CATCH(std::exception & exc, {
  190. mgb_log_error("async cuda init failed: %s", exc.what());
  191. if (stream_done) {
  192. cudaStreamDestroy(m_cuda_env.stream);
  193. }
  194. cont.err(exc);
  195. throw;
  196. })
  197. };
  198. m_async_init_need_wait = true;
  199. m_async_init_future = std::async(std::launch::async, worker);
  200. while (!tid_set.load())
  201. std::this_thread::yield();
  202. mgb_assert(m_async_init_tid != std::this_thread::get_id());
  203. }
  204. #endif
  205. #if MGB_ATLAS
  206. void mgb::_on_atlas_error(const char* expr, int err, const char* file,
  207. const char* func, int line) {
  208. mgb_throw(AtlasError, "atlas error %d: %s (%s at %s:%s:%d)", int(err),
  209. megcore::atlas::get_error_str(err), expr, file, func, line);
  210. }
// Out-of-class definition of the static AtlasEnv::init_status. The
// MegDNNHandle constructor calls init_status.init() before creating Atlas
// megcore device handles.
CompNodeEnv::AtlasEnv::InitStatus CompNodeEnv::AtlasEnv::init_status;
  212. void CompNodeEnv::init_atlas(CompNode comp_node, const AtlasEnv& env) {
  213. m_comp_node = comp_node;
  214. m_atlas_env = env;
  215. m_property.type = DeviceType::ATLAS;
  216. m_property.mem_alignment = 64;
  217. m_atlas_env.activate();
  218. MGB_ATLAS_CHECK(aclrtCreateStream(&m_atlas_env.stream));
  219. m_user_data_container = std::make_unique<UserDataContainer>();
  220. mgb_assert(m_property.mem_alignment ==
  221. MegDNNHandle::get(*this).handle()->alignment_requirement());
  222. }
  223. #endif
  224. #if MGB_ROCM
  225. void mgb::_on_hip_error(const char* expr, hipError_t err, const char* file,
  226. const char* func, int line) {
  227. mgb_throw(ROCmError, "rocm error %d: %s (%s at %s:%s:%d)", int(err),
  228. hipGetErrorString(err), expr, file, func, line);
  229. }
  230. void CompNodeEnv::init_rocm_async(int dev, CompNode comp_node,
  231. const ContinuationCtx<hipStream_t>& cont) {
  232. m_comp_node = comp_node;
  233. mgb_assert(!m_user_data_container && !m_async_init_need_wait);
  234. m_rocm_env.device = dev;
  235. m_property.type = DeviceType::ROCM;
  236. MGB_ROCM_CHECK(hipGetDeviceProperties(&m_rocm_env.device_prop, dev));
  237. {
  238. auto&& prop = m_rocm_env.device_prop;
  239. MGB_MARK_USED_VAR(prop);
  240. //! FIXME: no texure alignment in device property
  241. m_property.mem_alignment = 1u;
  242. }
  243. std::atomic_bool tid_set{false};
  244. auto worker = [this, cont, &tid_set]() {
  245. sys::set_thread_name("async_rocm_init");
  246. m_async_init_tid = std::this_thread::get_id();
  247. tid_set.store(true);
  248. bool stream_done = false;
  249. MGB_MARK_USED_VAR(stream_done);
  250. MGB_TRY {
  251. m_rocm_env.activate();
  252. MGB_ROCM_CHECK(hipStreamCreateWithFlags(&m_rocm_env.stream,
  253. hipStreamNonBlocking));
  254. stream_done = true;
  255. m_user_data_container = std::make_unique<UserDataContainer>();
  256. cont.next(m_rocm_env.stream);
  257. // megdnn is initialized here; must be placed after cont.next()
  258. // which handles comp node init
  259. mgb_assert(
  260. m_property.mem_alignment ==
  261. MegDNNHandle::get(*this).handle()->alignment_requirement());
  262. }
  263. MGB_CATCH(std::exception & exc, {
  264. mgb_log_error("async rocm init failed: %s", exc.what());
  265. if (stream_done) {
  266. hipStreamDestroy(m_rocm_env.stream);
  267. }
  268. cont.err(exc);
  269. throw;
  270. })
  271. };
  272. m_async_init_need_wait = true;
  273. m_async_init_future = std::async(std::launch::async, worker);
  274. while (!tid_set.load())
  275. std::this_thread::yield();
  276. mgb_assert(m_async_init_tid != std::this_thread::get_id());
  277. }
  278. #endif
  279. #if MGB_CAMBRICON
  280. const char* mgb::cnml_get_error_string(cnmlStatus_t err) {
  281. switch (err) {
  282. #define cb(_err) \
  283. case _err: \
  284. return #_err
  285. cb(CNML_STATUS_SUCCESS);
  286. cb(CNML_STATUS_NODEVICE);
  287. cb(CNML_STATUS_DOMAINERR);
  288. cb(CNML_STATUS_INVALIDARG);
  289. cb(CNML_STATUS_LENGTHERR);
  290. cb(CNML_STATUS_OUTOFRANGE);
  291. cb(CNML_STATUS_RANGEERR);
  292. cb(CNML_STATUS_OVERFLOWERR);
  293. cb(CNML_STATUS_UNDERFLOWERR);
  294. cb(CNML_STATUS_INVALIDPARAM);
  295. cb(CNML_STATUS_BADALLOC);
  296. cb(CNML_STATUS_BADTYPEID);
  297. cb(CNML_STATUS_BADCAST);
  298. cb(CNML_STATUS_UNSUPPORT);
  299. #undef cb
  300. }
  301. return "Unknown CNML error";
  302. }
  303. void mgb::_on_cnrt_error(const char* expr, cnrtRet_t err, const char* file,
  304. const char* func, int line) {
  305. mgb_throw(CnrtError, "cnrt error %d: %s (%s at %s:%s:%d)", int(err),
  306. cnrtGetErrorStr(err), expr, file, func, line);
  307. }
  308. void mgb::_on_cndev_error(const char* expr, cndevRet_t err, const char* file,
  309. const char* func, int line) {
  310. mgb_throw(CndevError, "cndev error %d: %s (%s at %s:%s:%d)", int(err),
  311. cndevGetErrorString(err), expr, file, func, line);
  312. }
  313. void mgb::_on_cnml_error(const char* expr, cnmlStatus_t err, const char* file,
  314. const char* func, int line) {
  315. mgb_throw(CnmlError, "cnml error %d: %s (%s at %s:%s:%d)", int(err),
  316. cnml_get_error_string(err), expr, file, func, line);
  317. }
  318. #endif
  319. void CompNodeEnv::init_cpu(const CpuEnv& env, CompNode comp_node) {
  320. m_comp_node = comp_node;
  321. mgb_assert(!m_user_data_container);
  322. m_property.type = DeviceType::CPU;
  323. m_cpu_env = env;
  324. m_user_data_container = std::make_unique<UserDataContainer>();
  325. m_property.mem_alignment =
  326. MegDNNHandle::get(*this).handle()->alignment_requirement();
  327. }
  328. #if MGB_CAMBRICON
  329. void CompNodeEnv::init_cnrt(int dev, CompNode comp_node,
  330. const ContinuationCtx<cnrtQueue_t>& cont) {
  331. m_comp_node = comp_node;
  332. m_cnrt_env.device = dev;
  333. m_property.type = DeviceType::CAMBRICON;
  334. MGB_CNRT_CHECK(cnrtGetDeviceInfo(&m_cnrt_env.device_info, dev));
  335. // FIXME: doc doesn't describe the aligment requirement for device memory
  336. // address
  337. m_property.mem_alignment = 1u;
  338. // ensure exception safe
  339. bool queue_created = false;
  340. MGB_MARK_USED_VAR(queue_created);
  341. MGB_TRY {
  342. m_cnrt_env.activate();
  343. MGB_CNRT_CHECK(cnrtCreateQueue(&m_cnrt_env.queue));
  344. queue_created = true;
  345. m_user_data_container = std::make_unique<UserDataContainer>();
  346. cont.next(m_cnrt_env.queue);
  347. // TODO: initialize megdnn handle
  348. mgb_assert(m_property.mem_alignment ==
  349. MegDNNHandle::get(*this).handle()->alignment_requirement());
  350. }
  351. MGB_CATCH(std::exception & exc, {
  352. mgb_log_error("cnrt init failed: %s", exc.what());
  353. if (queue_created) {
  354. MGB_CNRT_CHECK(cnrtDestroyQueue(m_cnrt_env.queue));
  355. }
  356. cont.err(exc);
  357. throw;
  358. })
  359. }
// Out-of-class definition of the static CnrtEnv::init_status. The
// MegDNNHandle constructor calls init_status.init() before creating Cambricon
// megcore device handles.
CompNodeEnv::CnrtEnv::InitStatus CompNodeEnv::CnrtEnv::init_status;
  361. #endif
  362. void CompNodeEnv::fini() {
  363. ensure_async_init_finished();
  364. m_user_data_container.reset();
  365. #if MGB_CUDA
  366. if (m_property.type == DeviceType::CUDA) {
  367. m_cuda_env.activate();
  368. MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream));
  369. }
  370. #endif
  371. #if MGB_ROCM
  372. if (m_property.type == DeviceType::ROCM) {
  373. m_rocm_env.activate();
  374. MGB_ROCM_CHECK(hipStreamDestroy(m_rocm_env.stream));
  375. }
  376. #endif
  377. #if MGB_CAMBRICON
  378. if (m_property.type == DeviceType::CAMBRICON) {
  379. m_cnrt_env.activate();
  380. MGB_CNRT_CHECK(cnrtDestroyQueue(m_cnrt_env.queue));
  381. }
  382. #endif
  383. #if MGB_ATLAS
  384. if (m_property.type == DeviceType::ATLAS) {
  385. m_atlas_env.activate();
  386. MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream));
  387. }
  388. #endif
  389. }
  390. #if MGB_ENABLE_COMP_NODE_ASYNC_INIT
  391. void CompNodeEnv::wait_async_init() {
  392. if (std::this_thread::get_id() == m_async_init_tid)
  393. return;
  394. MGB_LOCK_GUARD(m_async_init_mtx);
  395. if (m_async_init_need_wait.load()) {
  396. m_async_init_future.wait();
  397. m_async_init_need_wait.store(false);
  398. m_async_init_future.get();
  399. }
  400. }
  401. #endif
  402. void CompNodeEnv::on_bad_device_type(DeviceType expected) const {
  403. mgb_throw(MegBrainError, "bad device type: expected=%d actual=%d",
  404. static_cast<int>(expected), static_cast<int>(m_property.type));
  405. }
// Presumably emits a symbol encoding the megdnn version (MAJOR.MINOR.PATCH)
// this file was compiled against — see megbrain/version_symbol.h to confirm.
MGB_VERSION_SYMBOL3(MEGDNN, MEGDNN_MAJOR, MEGDNN_MINOR, MEGDNN_PATCH);
  407. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。