
compiler_cuda.cpp

/**
 * \file src/jit/impl/halide/compiler_cuda.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./compiler_cuda.h"

#if MGB_JIT_HALIDE && MGB_CUDA

#include "../nvrtc/compiler_cuda.h"
#include "./ast_hl.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/jit/utils.h"
#include "megbrain/utils/timer.h"

#include <HalideRuntimeCuda.h>

using namespace mgb;
using namespace jit;
using namespace Halide;
/* =================== HalideCudaTargetTrait ==================== */

struct HalideCudaTargetTrait::UserData
        : public HalideExecutable::TargetTraitUserData {
    DeviceProp dev_prop;  //!< device properties used to generate the schedule
                          //!< for the func
    Halide::Pipeline pipeline;
    std::mutex mtx;
};

HalideCudaTargetTrait::FeatureSet HalideCudaTargetTrait::features(
        CompNode comp_node) const {
    FeatureSet set;
    set.set(Target::CUDA);
    auto&& prop = CompNodeEnv::from_comp_node(comp_node).cuda_env().device_prop;
    auto in = [ver = prop.major * 10 + prop.minor](int low, int high) {
        return ver >= low && ver < high;
    };
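    // map the device compute capability (encoded as major * 10 + minor) onto
    // the nearest Halide CUDACapability target feature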
    if (in(30, 32)) {
        set.set(Target::CUDACapability30);
    } else if (in(32, 35)) {
        set.set(Target::CUDACapability32);
    } else if (in(35, 40)) {
        set.set(Target::CUDACapability35);
    } else if (in(50, 61)) {
        set.set(Target::CUDACapability50);
    } else if (in(61, 70)) {
        set.set(Target::CUDACapability61);
    } else {
        mgb_log_warn(
                "cuda capability (%d.%d) not supported by Halide; falling "
                "back to compute capability 6.1",
                prop.major, prop.minor);
        set.set(Target::CUDACapability61);
    }
    return set;
}
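// compile the scheduled pipeline to an object file, link it together with the
// runtime-override lib from cuda_runtime_lib() into a shared object, then
// resolve the generated argv-style entry point and the Halide CUDA device
// hooks from the loaded library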
HalideCudaTargetTrait::FunctionHandle HalideCudaTargetTrait::compile_and_load(
        CompNode comp_node, Halide::Target target,
        const HalideExecutable& hl_exec) {
    auto&& dev_prop = get_dev_prop(comp_node);
    auto func_name = next_kernel_name();
    auto&& helper = ExecutableHelper::get();
    auto make_ud =
            [&]() -> std::unique_ptr<HalideExecutable::TargetTraitUserData> {
        auto ret = std::make_unique<UserData>();
        ret->dev_prop = dev_prop;
        ret->pipeline =
                gen_halide_pipeline_schedule(hl_exec.halide_output(), dev_prop);
        return ret;
    };
    auto ud = static_cast<UserData*>(user_data(hl_exec, make_ud));
    // since a halide func and its schedule are coupled, we would need to copy
    // the func to use a different schedule
    mgb_throw_if(dev_prop.max_threads_per_block !=
                         ud->dev_prop.max_threads_per_block,
                 InternalError,
                 "halide on multiple devices with different "
                 "max_threads_per_block is currently not supported");
    auto&& pipeline = ud->pipeline;
    auto halide_inputs = hl_exec.halide_inputs();
    RealTimer timer;
    {
        // pipeline compilation does not appear to be thread-safe
        MGB_LOCK_GUARD(ud->mtx);
        pipeline.compile_to_object(helper.realpath(func_name + ".o"),
                                   halide_inputs, func_name, target);
        if (ExecutableHelper::keep_interm()) {
            pipeline.compile_to_lowered_stmt(
                    helper.realpath(func_name + ".stmt"), halide_inputs, Text,
                    target);
        }
    }
    auto time_compile = timer.get_msecs_reset();
    FunctionHandle ret;
    ret.init_uctx_map();
    auto obj_name = func_name + ".o";
    ret.dl_handle = helper.link_and_load(
            {HalideCudaCompiler::cuda_runtime_lib(), obj_name},
            func_name + ".so");
    helper.remove_interm(obj_name);
    helper.resolve_func(ret.get_device_interface, ret.dl_handle,
                        "halide_cuda_device_interface");
    helper.resolve_func(ret.execute, ret.dl_handle, func_name + "_argv");
    helper.resolve_func(ret.device_release, ret.dl_handle,
                        "halide_cuda_device_release");
    auto time_link = timer.get_msecs_reset();
    mgb_log("Halide CUDA JIT: compile %s for %s: time_compile=%.3fms "
            "time_link=%.3fms",
            func_name.c_str(), comp_node.to_string().c_str(), time_compile,
            time_link);
    return ret;
}
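// the pointer returned here reaches generated kernels as Halide's
// user_context argument; the overrides in cuda_runtime_lib() below cast it
// back to recover MegEngine's CUcontext and CUstream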
void* HalideCudaTargetTrait::get_user_context(CompNode comp_node) {
    return &(get_dev_prop(comp_node).ctx);
}

HalideCudaTargetTrait::DeviceProp& HalideCudaTargetTrait::get_dev_prop(
        CompNode comp_node) {
    MGB_LOCK_GUARD(m_mtx);
    auto&& ret = m_cn2prop[comp_node];
    if (ret.max_threads_per_block == -1) {
        auto&& env = CompNodeEnv::from_comp_node(comp_node).cuda_env();
        comp_node.activate();
        MGB_CUDA_CU_CHECK(cuCtxGetCurrent(&(ret.ctx.ctx)));
        ret.ctx.strm = env.stream;
        ret.max_threads_per_block = env.device_prop.maxThreadsPerBlock;
    }
    return ret;
}
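// scheduling strategy: inline every purely elementwise func so it is computed
// on the fly, keep inputs, broadcasts and reductions as materialization
// points, then fuse all output dimensions into one var and distribute it over
// CUDA blocks and threads with GuardWithIf tail handling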
Halide::Pipeline HalideCudaTargetTrait::gen_halide_pipeline_schedule(
        const ast_hl::AstNodePtr& dst_output, const DeviceProp& device_prop) {
#if 1
    using namespace ast_hl;
    // traverse the graph, inlining everything except inputs, broadcasts and
    // reduces, and collect the reduce oprs for later scheduling
    std::unordered_set<AstNodePtr> visited;
    std::queue<AstNodePtr> q;
    for (auto inp : dst_output->m_inputs) {
        q.push(inp);
    }
    std::unordered_set<ReduceOp*> reduce_set;
    while (!q.empty()) {
        auto top = q.front();
        if (visited.count(top)) {
            q.pop();
            continue;
        }
        for (auto inp : top->m_inputs) {
            q.push(inp);
        }
        if (!top->same_type<InputDevValueOp>() && !top->same_type<ReduceOp>() &&
            !top->same_type<InputHostValueShapeOp>() &&
            !top->same_type<BroadcastOp>()) {
            top->m_func.compute_inline();
        }
        if (auto reduce_opr = try_cast_as_op<ReduceOp>(top.get())) {
            reduce_set.insert(reduce_opr);
        }
        visited.insert(top);
        q.pop();
    }
    std::vector<Func> outputs;
    // schedule each collected reduction (and its pre-reduce computation, if
    // defined) at the thread loop of the consumer func
    auto process_reduce = [&](Func f, Var tx) {
        for (auto&& reduce_opr : reduce_set) {
            if (reduce_opr->m_comp.defined()) {
                reduce_opr->m_comp.compute_at(f, tx);
            }
            reduce_opr->m_func.compute_at(f, tx);
        }
    };
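    // elementwise-like output: bound every dim, fuse all dims into a single
    // var, then split it across gpu_blocks/gpu_threads, adding an unrolled
    // outer loop when the tensor exceeds max_blocks * max_threads elements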
    auto schedule_elemwise_like = [&process_reduce, &outputs,
                                   &device_prop](const AstNodePtr& output) {
        auto& f = output->m_func;
        auto vars = f.args();
        auto&& layout = output->m_layout;
        size_t total_nr_elems = layout.total_nr_elems();
        mgb_assert(vars.size() == layout.ndim);
        for (int i = layout.ndim - 1; i >= 0; i--) {
            f.bound(vars[layout.ndim - 1 - i], 0, static_cast<int>(layout[i]));
        }
        Var fused = vars[0];
        for (size_t i = 1; i < vars.size(); i++) {
            output->m_func.fuse(fused, vars[i], fused);
        }
        const int max_blocks = 65536;
        const int max_threads_num = device_prop.max_threads_per_block;
        bool need_block_split =
                total_nr_elems >
                static_cast<size_t>(max_blocks * max_threads_num);
        const int bt = max_blocks * max_threads_num;
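        // e.g. with max_threads_per_block = 1024, bt = 65536 * 1024 =
        // 67108864, so the single-level split covers up to ~67M elements;
        // larger tensors get an extra unrolled outer loop xo, and each thread
        // then handles multiple elements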
        if (need_block_split) {
            Var xo, xi;
            Var bx, tx;
            f.split(fused, xo, xi, bt, TailStrategy::GuardWithIf);
            f.split(xi, bx, tx, Expr{max_threads_num},
                    TailStrategy::GuardWithIf);
            f.reorder(xo, tx, bx);
            f.unroll(xo);
            f.gpu_threads(tx);
            f.gpu_blocks(bx);
            process_reduce(f, tx);
        } else {
            Var bx, tx;
            f.split(fused, bx, tx, max_threads_num, TailStrategy::GuardWithIf);
            f.gpu_threads(tx);
            f.gpu_blocks(bx);
            process_reduce(f, tx);
        }
        outputs.push_back(f);
    };
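    // reduce output: wrap the reduction in a pure trampoline Func (real_out)
    // that simply forwards its value; the GPU schedule goes on real_out, and
    // the reduction itself is computed at real_out's thread loop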
    auto schedule_reduce = [&process_reduce, &outputs,
                            &device_prop](const AstNodePtr& output) {
        auto& f = output->m_func;
        auto& c = try_cast_as_op<ReduceOp>(output.get())->m_comp;
        auto vars = f.args();
        std::vector<Expr> exprs;
        Func real_out;
        for (auto var : vars) {
            exprs.emplace_back(var);
        }
        real_out(vars) = f(exprs);
        auto layout = output->m_layout;
        size_t total_nr_elems = layout.total_nr_elems();
        for (int i = layout.ndim - 1; i >= 0; i--) {
            real_out.bound(vars[layout.ndim - 1 - i], 0,
                           static_cast<int>(layout[i]));
        }
        Var fused = vars[0];
        for (size_t i = 1; i < vars.size(); i++) {
            real_out.fuse(fused, vars[i], fused);
        }
        const int max_blocks = 65536;
        const int max_threads_num = device_prop.max_threads_per_block;
        bool need_block_split =
                total_nr_elems >
                static_cast<size_t>(max_blocks * max_threads_num);
        const int bt = max_blocks * max_threads_num;
        if (need_block_split) {
            Var xo, xi;
            Var bx, tx;
            real_out.split(fused, xo, xi, bt, TailStrategy::GuardWithIf);
            real_out.split(xi, bx, tx, Expr{max_threads_num},
                           TailStrategy::GuardWithIf);
            real_out.reorder(xo, tx, bx);
            real_out.unroll(xo);
            real_out.gpu_threads(tx);
            real_out.gpu_blocks(bx);
            f.compute_at(real_out, tx);
            if (c.defined())
                c.compute_at(real_out, tx);
            process_reduce(real_out, tx);
        } else {
            Var bx, tx;
            real_out.split(fused, bx, tx, max_threads_num,
                           TailStrategy::GuardWithIf);
            real_out.gpu_threads(tx);
            real_out.gpu_blocks(bx);
            f.compute_at(real_out, tx);
            if (c.defined())
                c.compute_at(real_out, tx);
            process_reduce(real_out, tx);
        }
        outputs.push_back(real_out);
    };
    if (dst_output->same_type<ReduceOp>()) {
        schedule_reduce(dst_output);
    } else {
        schedule_elemwise_like(dst_output);
    }
    return Pipeline(outputs);
#else
    return Pipeline(dst_output->m_func);
#endif
}
/* ==================== HalideCudaCompiler ===================== */

std::unique_ptr<Executable> HalideCudaCompiler::do_compile(
        const InternalGraph& graph, const JITExecutor::Args& args) {
    return std::make_unique<HalideExecutable>(m_trait, graph, args);
}
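// Halide's CUDA runtime routes context and stream acquisition through the
// overridable hooks halide_cuda_acquire_context / halide_cuda_get_stream /
// halide_cuda_release_context (see HalideRuntimeCuda.h); linking the snippet
// below into the generated .so makes kernels run on the CUcontext and CUstream
// carried in via user_context instead of a context created by Halide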
const std::string& HalideCudaCompiler::cuda_runtime_lib() {
    static const char* const source = R"(
#include <cuda.h>
#include <cstdio>
#include <cstdlib>

namespace {
struct HalideUserContext {
    CUcontext ctx;
    CUstream strm;
};
HalideUserContext* check_user_context(void* user_context) {
    if (!user_context) {
        fprintf(stderr, "user_context not provided\n");
        abort();
    }
    return static_cast<HalideUserContext*>(user_context);
}
}  // anonymous namespace

extern "C" int halide_cuda_acquire_context(void* user_context, CUcontext* ctx,
                                           bool create) {
    if (!user_context && !create) {
        // called from halide_cuda_cleanup()
        return 1;
    }
    *ctx = check_user_context(user_context)->ctx;
    return 0;
}

extern "C" int halide_cuda_release_context(void* user_context) {
    return 0;
}

extern "C" int halide_cuda_get_stream(void* user_context, CUcontext ctx,
                                      CUstream* stream) {
    *stream = check_user_context(user_context)->strm;
    return 0;
}
)";
    static std::string name =
            ExecutableHelper::get().compile_cpp_source_secondary(
                    source, "halide_cuda_runtime_override");
    return name;
}

#endif  // MGB_JIT_HALIDE && MGB_CUDA

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.