You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

mlir_gen.cpp 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. /**
  2. * \file src/jit/impl/mlir/mlir_gen.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain_build_config.h"
  13. #if MGB_JIT && MGB_JIT_MLIR
  14. #include "./mlir_gen.h"
  15. #include "./ir/each_mode.h"
  16. #include "./ir/types.h"
  17. #include "megbrain/jit/mlir/ir/dialect.h"
  18. #include "megbrain/jit/mlir/ir/utils.h"
  19. #include "megbrain/opr/basic_arith.h"
  20. #include "megdnn/dtype.h"
  21. #include <mlir/Dialect/Affine/IR/AffineOps.h>
  22. #include <mlir/Dialect/StandardOps/IR/Ops.h>
  23. #include <mlir/IR/Attributes.h>
  24. #include <mlir/IR/Builders.h>
  25. #include <mlir/IR/Function.h>
  26. #include <mlir/IR/MLIRContext.h>
  27. #include <mlir/IR/Module.h>
  28. #include <mlir/IR/StandardTypes.h>
  29. #include <mlir/IR/Types.h>
  30. #include <mlir/IR/Value.h>
  31. #include <mlir/IR/Verifier.h>
  32. #include <mlir/Support/LogicalResult.h>
  33. #include <llvm/ADT/ScopedHashTable.h>
  34. #include <llvm/Support/raw_ostream.h>
  35. using namespace mgb;
  36. using namespace jit;
  37. namespace {
  38. class MLIRGenImpl {
  39. public:
  40. MLIRGenImpl(mlir::MLIRContext& context) : m_builder(&context) {}
  41. std::pair<llvm::StringRef, mlir::OwningModuleRef> gen(
  42. const InternalGraph& internal_graph,
  43. const JITExecutor::Args& args) {
  44. mlir::ModuleOp module =
  45. mlir::ModuleOp::create(m_builder.getUnknownLoc());
  46. //! Create main routine function
  47. auto func_op = gen_func_op(internal_graph, args);
  48. module.push_back(func_op);
  49. if (mlir::failed(mlir::verify(module))) {
  50. module.emitError("module verification error");
  51. return {};
  52. }
  53. return {func_op.getName(), module};
  54. }
  55. private:
  56. mlir::OpBuilder m_builder;
  57. llvm::ScopedHashTable<mlir::StringRef, mlir::Value> m_symbol_table;
  58. mlir::FuncOp gen_func_op(const InternalGraph& internal_graph,
  59. const JITExecutor::Args& args) {
  60. llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(
  61. m_symbol_table);
  62. std::vector<mlir::Type> func_args;
  63. for (auto&& arg : args.inputs) {
  64. func_args.push_back(get_type(arg.from->layout()));
  65. }
  66. for (auto&& arg : args.outputs) {
  67. func_args.push_back(get_type(arg.from->layout()));
  68. }
  69. //! the last arg is nr_elements
  70. func_args.push_back(m_builder.getIndexType());
  71. auto func_type = m_builder.getFunctionType(func_args, llvm::None);
  72. //! function name maybe renamed in later pass
  73. mlir::FuncOp func_op = mlir::FuncOp::create(m_builder.getUnknownLoc(),
  74. "func", func_type);
  75. if (!func_op)
  76. return nullptr;
  77. func_op.setAttr("llvm.emit_c_interface",
  78. mlir::UnitAttr::get(m_builder.getContext()));
  79. auto& entry_block = *func_op.addEntryBlock();
  80. size_t idx = 0;
  81. for (auto&& input : args.inputs) {
  82. if (mlir::failed(declare(internal_graph.placeholders()[input.idx]
  83. ->output(0)
  84. ->name(),
  85. entry_block.getArgument(idx)))) {
  86. return nullptr;
  87. }
  88. idx++;
  89. }
  90. for (auto&& output : args.outputs) {
  91. if (mlir::failed(declare(output.from->name(),
  92. entry_block.getArgument(idx)))) {
  93. return nullptr;
  94. }
  95. idx++;
  96. }
  97. m_builder.setInsertionPointToStart(&entry_block);
  98. if (mlir::failed(gen_func_body(internal_graph, args))) {
  99. func_op.erase();
  100. return nullptr;
  101. }
  102. dialect::ReturnOp return_op;
  103. if (!return_op) {
  104. m_builder.create<dialect::ReturnOp>(m_builder.getUnknownLoc());
  105. }
  106. std::string op_content = mlir_type_to_string(func_op);
  107. func_op.setName(
  108. ssprintf("jit_mlir_%" PRIx64,
  109. XXHash{}.update(op_content.data(), op_content.size())
  110. .digest()));
  111. return func_op;
  112. }
  113. mlir::LogicalResult gen_func_body(const InternalGraph& internal_graph,
  114. const JITExecutor::Args& args) {
  115. llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(
  116. m_symbol_table);
  117. cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
  118. if (opr->same_type<JITPlaceholder>()) {
  119. return;
  120. } else if (opr->same_type<opr::ImmutableTensor>()) {
  121. auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
  122. if (imm.valid()) {
  123. auto dtype = imm->dtype();
  124. float scalar_value;
  125. if (dtype == dtype::Float32()) {
  126. scalar_value = imm->get<float>();
  127. } else {
  128. mgb_throw(InternalError,
  129. "mlir backend currently only support f32 "
  130. "dtype, but got %s",
  131. dtype.name());
  132. }
  133. auto&& out = m_builder.create<dialect::ConstantScalarOp>(
  134. m_builder.getUnknownLoc(), m_builder.getF32Type(),
  135. m_builder.getF32FloatAttr(scalar_value));
  136. mgb_assert(mlir::succeeded(
  137. declare(opr->output(0)->name(), out)));
  138. }
  139. } else if (opr->same_type<opr::Elemwise>()) {
  140. auto&& out = gen_elemwise(opr->cast_final<opr::Elemwise>());
  141. mgb_assert(
  142. mlir::succeeded(declare(opr->output(0)->name(), out)));
  143. return;
  144. } else if (opr->same_type<opr::TypeCvt>()) {
  145. auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
  146. mgb_assert(
  147. mlir::succeeded(declare(opr->output(0)->name(), out)));
  148. }
  149. }}
  150. .add(internal_graph.output());
  151. m_builder.create<dialect::AssignOp>(m_builder.getUnknownLoc(),
  152. get(internal_graph.output()),
  153. get(args.outputs[0].from));
  154. return mlir::success();
  155. }
  156. mlir::Value gen_elemwise(const opr::Elemwise& opr) {
  157. llvm::SmallVector<mlir::Value, 4> operands;
  158. for (size_t i = 0; i < opr.input().size(); i++) {
  159. operands.push_back(get(opr.input(i)));
  160. }
  161. mlir::Type res_type = deduce_elemwise_res_type(operands);
  162. return m_builder.create<dialect::Elemwise>(
  163. m_builder.getUnknownLoc(), res_type, mlir::ValueRange(operands),
  164. opr.param().mode);
  165. }
  166. mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
  167. auto shape = get(opr.input(0))
  168. .getType()
  169. .dyn_cast_or_null<mlir::MemRefType>()
  170. .getShape();
  171. auto res_type = mlir::MemRefType::get(
  172. shape,
  173. megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
  174. return m_builder.create<dialect::TypeCvt>(
  175. m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
  176. opr.input(0)->dtype(), opr.param());
  177. }
  178. mlir::Type get_type(const TensorLayout& layout) {
  179. return layout_to_mlir_type(layout, m_builder);
  180. }
  181. mlir::Value get(const VarNode* var) {
  182. if (auto ret = m_symbol_table.lookup(var->name())) {
  183. return ret;
  184. }
  185. mgb_throw(InternalError, "Unknown var: %s", var->cname());
  186. }
  187. mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
  188. if (m_symbol_table.count(var)) {
  189. return mlir::failure();
  190. }
  191. m_symbol_table.insert(var, value);
  192. return mlir::success();
  193. }
  194. };
  195. } // namespace
  196. std::pair<llvm::StringRef, mlir::OwningModuleRef> mgb::jit::mlir_gen(
  197. mlir::MLIRContext& context,
  198. const mgb::jit::InternalGraph& internal_graph,
  199. const mgb::jit::JITExecutor::Args& args) {
  200. return MLIRGenImpl(context).gen(internal_graph, args);
  201. }
  202. #endif // MGB_JIT && MGB_JIT_MLIR
  203. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台